init

vw-agentic-rag/service/graph/message_trimmer.py (new file, 270 lines)
"""
Conversation history trimming utilities for managing context length.
"""
import logging
from typing import List, Optional, Sequence, Tuple

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
from langchain_core.messages.utils import trim_messages, count_tokens_approximately

logger = logging.getLogger(__name__)


class ConversationTrimmer:
    """
    Manages conversation history to prevent exceeding LLM context limits.
    """

    def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
        """
        Initialize the conversation trimmer.

        Args:
            max_context_length: Maximum context length for conversation history (in tokens)
            preserve_system: Whether to always preserve system messages
        """
        self.max_context_length = max_context_length
        self.preserve_system = preserve_system
        # Reserve tokens for response generation (85% for history, 15% for the response)
        self.history_token_limit = int(max_context_length * 0.85)
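        # For example, with the default max_context_length of 96000 tokens,
        # the history budget works out to int(96000 * 0.85) == 81600 tokens.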

    def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
        """
        Trim conversation history to fit within token limits.

        Args:
            messages: List of conversation messages

        Returns:
            Trimmed list of messages
        """
        if not messages:
            return list(messages)

        try:
            # Convert to a list for processing
            message_list = list(messages)

            # First, try multi-round tool call optimization
            optimized_messages = self._optimize_multi_round_tool_calls(message_list)

            # Check whether the optimization alone is sufficient
            try:
                token_count = count_tokens_approximately(optimized_messages)
                if token_count <= self.history_token_limit:
                    original_count = len(message_list)
                    optimized_count = len(optimized_messages)
                    if optimized_count < original_count:
                        logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
                    return optimized_messages
            except Exception:
                # If token counting fails, fall through to LangChain trimming
                pass

            # Still too long, so use LangChain's trim_messages utility
            trimmed_messages = trim_messages(
                optimized_messages,
                strategy="last",  # Keep the most recent messages
                token_counter=count_tokens_approximately,
                max_tokens=self.history_token_limit,
                start_on="human",  # Ensure a valid conversation start
                end_on=("human", "tool", "ai"),  # Allow ending on human, tool, or AI messages
                include_system=self.preserve_system,  # Preserve system messages
                allow_partial=False,  # Don't split individual messages
            )

            original_count = len(messages)
            trimmed_count = len(trimmed_messages)

            if trimmed_count < original_count:
                logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")

            return trimmed_messages

        except Exception as e:
            logger.error(f"Error trimming conversation history: {e}")
            # Fallback: keep the last N messages
            return self._fallback_trim(list(messages))
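
    # A minimal usage sketch (the message contents are hypothetical):
    #
    #   trimmer = ConversationTrimmer(max_context_length=96000)
    #   history = [SystemMessage(content="You are helpful."),
    #              HumanMessage(content="Hello")]
    #   trimmed = trimmer.trim_conversation_history(history)
    #   # Two short messages fit the budget, so trimmed == history.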

    def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
        """
        Optimize conversation history by removing older tool call results in multi-round scenarios.
        This reduces token usage while preserving conversation context.

        Strategy:
        1. Always preserve system messages
        2. Always preserve the original user query
        3. Keep the most recent AI-Tool message pairs (for context continuity)
        4. Remove older ToolMessage content, which typically contains large JSON responses

        Args:
            messages: List of conversation messages

        Returns:
            Optimized list of messages
        """
        if len(messages) <= 4:  # Too short to optimize
            return list(messages)

        # Identify message patterns
        tool_rounds = self._identify_tool_rounds(messages)

        if len(tool_rounds) <= 1:  # Single or no tool round, so no optimization needed
            return list(messages)

        logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")

        # Build the optimized message list
        optimized = []

        # Always preserve system messages
        for msg in messages:
            if isinstance(msg, SystemMessage):
                optimized.append(msg)

        # Preserve the initial user query (first human message after system)
        first_human_msg = None
        for msg in messages:
            if isinstance(msg, HumanMessage):
                first_human_msg = msg
                optimized.append(msg)
                break

        # Keep the most recent tool round (preserving context for the next round)
        # and everything after it, skipping messages that were already added above
        latest_round_start, _ = tool_rounds[-1]
        for msg in messages[latest_round_start:]:
            if isinstance(msg, SystemMessage) or msg is first_human_msg:
                continue
            optimized.append(msg)

        logger.info(
            f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages "
            f"(removed {len(tool_rounds) - 1} older tool rounds)"
        )
        return optimized
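
    # Illustration with a hypothetical sequence: given
    #   [System, Human, AI(tool_calls), Tool, AI(tool_calls), Tool, AI]
    # the method keeps
    #   [System, Human, AI(tool_calls), Tool, AI]
    # where the surviving AI(tool_calls)/Tool pair is the second (latest) round.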

    def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
        """
        Identify tool calling rounds in the message sequence.

        A tool round typically consists of:
        - An AI message with tool_calls
        - One or more ToolMessage responses

        Returns:
            List of (start_index, end_index) tuples, one per tool round
        """
        rounds = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            # Look for an AI message with tool calls
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                round_start = i
                round_end = i

                # Find the end of this tool round (consecutive ToolMessages)
                j = i + 1
                while j < len(messages) and isinstance(messages[j], ToolMessage):
                    round_end = j
                    j += 1

                # Only count it as a tool round if at least one ToolMessage followed
                if round_end > round_start:
                    rounds.append((round_start, round_end))
                    i = round_end + 1
                else:
                    i += 1
            else:
                i += 1

        return rounds
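
    # Example indices (hypothetical sequence): for
    #   0: Human, 1: AI(tool_calls), 2: Tool, 3: Tool, 4: AI, 5: AI(tool_calls), 6: Tool
    # this returns [(1, 3), (5, 6)].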

    def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
        """
        Fallback trimming based on message count.

        Args:
            messages: List of conversation messages
            max_messages: Maximum number of messages to keep

        Returns:
            Trimmed list of messages
        """
        if len(messages) <= max_messages:
            return list(messages)

        # Preserve system messages if any exist
        system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
        other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]

        # Keep the most recent non-system messages within the remaining budget
        budget = max(max_messages - len(system_messages), 0)
        recent_messages = other_messages[-budget:] if budget else []

        result = system_messages + recent_messages
        logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")

        return result
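
    # Budget example: with 35 messages including 1 system message and the
    # default max_messages of 20, this keeps the system message plus the
    # 19 most recent non-system messages.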

    def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
        """
        Check if conversation history should be trimmed.

        Strategy:
        1. Always trim if there are multiple tool rounds from previous conversation turns
        2. Also trim if approaching the token limit

        Args:
            messages: List of conversation messages

        Returns:
            True if trimming is needed
        """
        try:
            # Convert to a list for processing
            message_list = list(messages)

            # Multiple tool rounds always trigger trimming to drop old tool results
            tool_rounds = self._identify_tool_rounds(message_list)
            if len(tool_rounds) > 1:
                logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
                return True

            # Otherwise fall back to the token-count check
            token_count = count_tokens_approximately(message_list)
            return token_count > self.history_token_limit
        except Exception:
            # Fallback to a message-count heuristic
            return len(messages) > 30
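
    # Typical caller pattern (a sketch; the surrounding graph node and its
    # `state` dict are assumptions, not part of this module):
    #
    #   if trimmer.should_trim(state["messages"]):
    #       state["messages"] = trimmer.trim_conversation_history(state["messages"])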


def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
    """
    Create a conversation trimmer with config-based settings.

    Args:
        max_context_length: Override for the maximum context length

    Returns:
        ConversationTrimmer instance
    """
    # If max_context_length is provided, use it directly
    if max_context_length is not None:
        return ConversationTrimmer(
            max_context_length=max_context_length,
            preserve_system=True,
        )

    # Otherwise read the limit from config, falling back to the default if unavailable
    try:
        from ..config import get_config
        config = get_config()
        effective_max_context_length = config.get_max_context_length()
    except (RuntimeError, AttributeError):
        effective_max_context_length = 96000

    return ConversationTrimmer(
        max_context_length=effective_max_context_length,
        preserve_system=True,
    )
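

# A minimal end-to-end sketch, assuming langchain_core is installed; the
# message contents are hypothetical. Guarded so it only runs when the module
# is executed directly.
if __name__ == "__main__":
    trimmer = create_conversation_trimmer(max_context_length=8000)
    history = [
        SystemMessage(content="You are a retrieval-augmented assistant."),
        HumanMessage(content="Summarise the troubleshooting steps."),
    ]
    if trimmer.should_trim(history):
        history = trimmer.trim_conversation_history(history)
    print(f"{len(history)} messages within budget")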