# catonline_ai/vw-agentic-rag/service/graph/message_trimmer.py
"""
Conversation history trimming utilities for managing context length.
"""
import logging
from typing import List, Optional, Sequence, Tuple
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
logger = logging.getLogger(__name__)
class ConversationTrimmer:
"""
Manages conversation history to prevent exceeding LLM context limits.
"""
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
"""
Initialize the conversation trimmer.
Args:
max_context_length: Maximum context length for conversation history (in tokens)
preserve_system: Whether to always preserve system messages
"""
self.max_context_length = max_context_length
self.preserve_system = preserve_system
# Reserve tokens for response generation (use 85% for history, 15% for response)
self.history_token_limit = int(max_context_length * 0.85)
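        # e.g. with the 96000-token default: int(96000 * 0.85) = 81600 tokens for history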
def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
"""
Trim conversation history to fit within token limits.
Args:
messages: List of conversation messages
Returns:
Trimmed list of messages
"""
if not messages:
return list(messages)
try:
# Convert to list for processing
message_list = list(messages)
# First, try multi-round tool call optimization
optimized_messages = self._optimize_multi_round_tool_calls(message_list)
# Check if optimization is sufficient
try:
token_count = count_tokens_approximately(optimized_messages)
if token_count <= self.history_token_limit:
original_count = len(message_list)
optimized_count = len(optimized_messages)
if optimized_count < original_count:
logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
return optimized_messages
except Exception:
# If token counting fails, continue with LangChain trimming
pass
# If still too long, use LangChain's trim_messages utility
trimmed_messages = trim_messages(
optimized_messages,
strategy="last", # Keep most recent messages
token_counter=count_tokens_approximately,
max_tokens=self.history_token_limit,
start_on="human", # Ensure valid conversation start
end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages
include_system=self.preserve_system, # Preserve system messages
allow_partial=False # Don't split individual messages
)
original_count = len(messages)
trimmed_count = len(trimmed_messages)
if trimmed_count < original_count:
logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")
return trimmed_messages
except Exception as e:
logger.error(f"Error trimming conversation history: {e}")
# Fallback: keep last N messages
return self._fallback_trim(list(messages))
def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
"""
Optimize conversation history by removing older tool call results in multi-round scenarios.
This reduces token usage while preserving conversation context.
Strategy:
1. Always preserve system messages
2. Always preserve the original user query
        3. Keep the most recent tool round and everything after it (for context continuity)
        4. Remove older ToolMessage content, which typically contains large JSON responses
Args:
messages: List of conversation messages
Returns:
Optimized list of messages
"""
        if len(messages) <= 4:  # Too short to optimize
            return list(messages)
        # Identify message patterns
        tool_rounds = self._identify_tool_rounds(messages)
        if len(tool_rounds) <= 1:  # Single or no tool round, no optimization needed
            return list(messages)
logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")
# Build optimized message list
optimized = []
# Always preserve system messages
for msg in messages:
if isinstance(msg, SystemMessage):
optimized.append(msg)
        # Preserve the initial user query (first human message)
        for msg in messages:
            if isinstance(msg, HumanMessage):
                optimized.append(msg)
                break
        # Keep everything from the start of the latest tool round to the end of
        # the conversation, so the newest tool context and any trailing messages
        # (e.g. the final AI answer or the current user question) are preserved
        latest_round_start, _ = tool_rounds[-1]
        for msg in messages[latest_round_start:]:
            # System messages were already added above
            if not isinstance(msg, SystemMessage):
                optimized.append(msg)
        logger.info(
            f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages "
            f"(removed {len(tool_rounds) - 1} older tool rounds)"
        )
        return optimized
def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
"""
Identify tool calling rounds in the message sequence.
A tool round typically consists of:
- AI message with tool_calls
- One or more ToolMessage responses
Returns:
List of (start_index, end_index) tuples for each tool round
"""
rounds = []
i = 0
while i < len(messages):
msg = messages[i]
# Look for AI message with tool calls
if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
round_start = i
round_end = i
# Find the end of this tool round (look for consecutive ToolMessages)
j = i + 1
while j < len(messages) and isinstance(messages[j], ToolMessage):
round_end = j
j += 1
# Only consider it a tool round if we found at least one ToolMessage
if round_end > round_start:
rounds.append((round_start, round_end))
i = round_end + 1
else:
i += 1
else:
i += 1
return rounds
def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
"""
Fallback trimming based on message count.
Args:
messages: List of conversation messages
max_messages: Maximum number of messages to keep
Returns:
Trimmed list of messages
"""
        if len(messages) <= max_messages:
            return list(messages)
        # Preserve system messages if present
        system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
        other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
        # Keep the most recent non-system messages within the remaining budget
        keep_count = max(max_messages - len(system_messages), 0)
        recent_messages = other_messages[-keep_count:] if keep_count else []
        result = system_messages + recent_messages
        logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
        return result
def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
"""
Check if conversation history should be trimmed.
Strategy:
1. Always trim if there are multiple tool rounds from previous conversation turns
        2. Also trim if the history is approaching the token limit
Args:
messages: List of conversation messages
Returns:
True if trimming is needed
"""
try:
# Convert to list for processing
message_list = list(messages)
# Check for multiple tool rounds - if found, always trim to remove old tool results
tool_rounds = self._identify_tool_rounds(message_list)
if len(tool_rounds) > 1:
logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
return True
# Also check token count for traditional trimming
token_count = count_tokens_approximately(message_list)
return token_count > self.history_token_limit
except Exception:
# Fallback to message count
return len(messages) > 30
def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
"""
Create a conversation trimmer with config-based settings.
Args:
max_context_length: Override for maximum context length
Returns:
ConversationTrimmer instance
"""
# If max_context_length is provided, use it directly
if max_context_length is not None:
return ConversationTrimmer(
max_context_length=max_context_length,
preserve_system=True
)
    # Try to read the limit from config, falling back to the default if unavailable
    try:
        from ..config import get_config
        config = get_config()
        effective_max_context_length = config.get_max_context_length()
    except (ImportError, RuntimeError, AttributeError):
        effective_max_context_length = 96000
return ConversationTrimmer(
max_context_length=effective_max_context_length,
preserve_system=True
)
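# A minimal usage sketch (not part of the service wiring): builds a two-round
# tool conversation and runs it through the trimmer. Requires langchain_core;
# the tool name, arguments, and payloads are hypothetical, for illustration only.
if __name__ == "__main__":
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the weather in Paris and Lyon?"),
        AIMessage(content="", tool_calls=[
            {"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"},
        ]),
        ToolMessage(content='{"temp_c": 21}', tool_call_id="call_1"),
        AIMessage(content="", tool_calls=[
            {"name": "get_weather", "args": {"city": "Lyon"}, "id": "call_2"},
        ]),
        ToolMessage(content='{"temp_c": 24}', tool_call_id="call_2"),
        AIMessage(content="Paris is 21C and Lyon is 24C."),
    ]
    trimmer = create_conversation_trimmer(max_context_length=96000)
    if trimmer.should_trim(history):  # True here: two tool rounds are present
        history = trimmer.trim_conversation_history(history)
    print(f"{len(history)} messages after trimming")  # keeps only the latest round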