"""
|
|
Conversation history trimming utilities for managing context length.
|
|
"""
|
|
import logging
|
|
from typing import List, Optional, Sequence, Tuple
|
|
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
|
|
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConversationTrimmer:
|
|
"""
|
|
Manages conversation history to prevent exceeding LLM context limits.
|
|
"""
|
|
|
|
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
|
|
"""
|
|
Initialize the conversation trimmer.
|
|
|
|
Args:
|
|
max_context_length: Maximum context length for conversation history (in tokens)
|
|
preserve_system: Whether to always preserve system messages
|
|
"""
|
|
self.max_context_length = max_context_length
|
|
self.preserve_system = preserve_system
|
|
# Reserve tokens for response generation (use 85% for history, 15% for response)
|
|
self.history_token_limit = int(max_context_length * 0.85)
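        # Worked example with the default limit: int(96000 * 0.85) == 81600
        # tokens are budgeted for history, leaving ~14400 for the reply.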

    def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
        """
        Trim conversation history to fit within token limits.

        Args:
            messages: List of conversation messages

        Returns:
            Trimmed list of messages
        """
        if not messages:
            return list(messages)

        try:
            # Convert to list for processing
            message_list = list(messages)

            # First, try multi-round tool call optimization
            optimized_messages = self._optimize_multi_round_tool_calls(message_list)

            # Check whether the optimization alone is sufficient
            try:
                token_count = count_tokens_approximately(optimized_messages)
                if token_count <= self.history_token_limit:
                    original_count = len(message_list)
                    optimized_count = len(optimized_messages)
                    if optimized_count < original_count:
                        logger.info(
                            f"Multi-round tool optimization: {original_count} -> {optimized_count} messages"
                        )
                    return optimized_messages
            except Exception:
                # If token counting fails, fall through to LangChain trimming
                pass

            # If still too long, use LangChain's trim_messages utility
            trimmed_messages = trim_messages(
                optimized_messages,
                strategy="last",  # Keep the most recent messages
                token_counter=count_tokens_approximately,
                max_tokens=self.history_token_limit,
                start_on="human",  # Ensure a valid conversation start
                end_on=("human", "tool", "ai"),  # Allow ending on human, tool, or AI messages
                include_system=self.preserve_system,  # Preserve system messages
                allow_partial=False,  # Don't split individual messages
            )

            original_count = len(messages)
            trimmed_count = len(trimmed_messages)

            if trimmed_count < original_count:
                logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")

            return trimmed_messages

        except Exception as e:
            logger.error(f"Error trimming conversation history: {e}")
            # Fallback: keep the last N messages
            return self._fallback_trim(list(messages))
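
    # Typical call site (a sketch only; `state` and its "messages" key are
    # hypothetical and depend on the surrounding agent framework):
    #
    #   trimmer = create_conversation_trimmer()
    #   if trimmer.should_trim(state["messages"]):
    #       state["messages"] = trimmer.trim_conversation_history(state["messages"])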

    def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
        """
        Optimize conversation history by removing older tool call results in
        multi-round scenarios. This reduces token usage while preserving
        conversation context.

        Strategy:
        1. Always preserve system messages
        2. Always preserve the original user query
        3. Keep the most recent AI-tool message round (for context continuity)
        4. Drop older rounds, whose ToolMessage content typically contains
           large JSON responses

        Args:
            messages: List of conversation messages

        Returns:
            Optimized list of messages
        """
        if len(messages) <= 4:  # Too short to optimize
            return list(messages)

        # Identify message patterns
        tool_rounds = self._identify_tool_rounds(messages)

        if len(tool_rounds) <= 1:  # Single or no tool round, nothing to optimize
            return list(messages)

        logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")

        # Build the optimized message list
        optimized: List[BaseMessage] = []

        # Always preserve system messages
        for msg in messages:
            if isinstance(msg, SystemMessage):
                optimized.append(msg)

        # Preserve the initial user query (first human message after system)
        first_human: Optional[HumanMessage] = None
        for msg in messages:
            if isinstance(msg, HumanMessage):
                first_human = msg
                optimized.append(msg)
                break

        # Keep everything from the start of the most recent tool round onward,
        # so the latest round and any messages following it survive intact
        latest_round_start, _ = tool_rounds[-1]
        for msg in messages[latest_round_start:]:
            # Skip messages that were already added above to avoid duplicates
            if isinstance(msg, SystemMessage) or msg is first_human:
                continue
            optimized.append(msg)

        logger.info(
            f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages "
            f"(removed {len(tool_rounds) - 1} older tool rounds)"
        )
        return optimized
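
    # Illustrative walk-through (hypothetical sequence): given
    #   [System, Human, AI(tool_calls), Tool, AI(tool_calls), Tool, Tool]
    # the method keeps
    #   [System, Human, AI(tool_calls), Tool, Tool]
    # i.e. the system prompt, the original query, and the latest round only.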

    def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
        """
        Identify tool calling rounds in the message sequence.

        A tool round typically consists of:
        - An AI message with tool_calls
        - One or more ToolMessage responses

        Returns:
            List of (start_index, end_index) tuples, one per tool round
        """
        rounds = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            # Look for an AI message with tool calls
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                round_start = i
                round_end = i

                # Find the end of this round by consuming consecutive ToolMessages
                j = i + 1
                while j < len(messages) and isinstance(messages[j], ToolMessage):
                    round_end = j
                    j += 1

                # Only count it as a tool round if at least one ToolMessage followed
                if round_end > round_start:
                    rounds.append((round_start, round_end))
                    i = round_end + 1
                else:
                    i += 1
            else:
                i += 1

        return rounds
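
    # Example (hypothetical sequence): for
    #   [System, Human, AI(tool_calls), Tool, AI(tool_calls), Tool, Tool]
    # this returns [(2, 3), (4, 6)]: each tuple spans one AI tool-call message
    # and its run of consecutive ToolMessage responses.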

    def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
        """
        Fallback trimming based on message count.

        Args:
            messages: List of conversation messages
            max_messages: Maximum number of messages to keep

        Returns:
            Trimmed list of messages
        """
        if len(messages) <= max_messages:
            return list(messages)

        # Preserve system messages if they exist
        system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
        other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]

        # Keep the most recent messages; clamp the budget so that a slice like
        # other_messages[-0:] (which would keep everything) can never occur
        keep_count = max(max_messages - len(system_messages), 0)
        recent_messages = other_messages[-keep_count:] if keep_count else []

        result = system_messages + recent_messages
        logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")

        return result
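
    # e.g. with max_messages=20 and two system messages, keep_count == 18, so
    # the 18 most recent non-system messages are kept alongside both system
    # messages; if system messages alone exceed the budget, only they remain.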

    def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
        """
        Check whether the conversation history should be trimmed.

        Strategy:
        1. Always trim if there are multiple tool rounds from previous conversation turns
        2. Also trim if approaching the token limit

        Args:
            messages: List of conversation messages

        Returns:
            True if trimming is needed
        """
        try:
            # Convert to list for processing
            message_list = list(messages)

            # Multiple tool rounds always warrant trimming to drop old tool results
            tool_rounds = self._identify_tool_rounds(message_list)
            if len(tool_rounds) > 1:
                logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
                return True

            # Otherwise fall back to the traditional token-count check
            token_count = count_tokens_approximately(message_list)
            return token_count > self.history_token_limit
        except Exception:
            # If token counting fails, use a simple message-count heuristic
            return len(messages) > 30


def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
    """
    Create a conversation trimmer with config-based settings.

    Args:
        max_context_length: Override for the maximum context length

    Returns:
        ConversationTrimmer instance
    """
    # If max_context_length is provided, use it directly
    if max_context_length is not None:
        return ConversationTrimmer(
            max_context_length=max_context_length,
            preserve_system=True,
        )

    # Try to read the limit from config; fall back to the default when the
    # config module is unavailable or not yet initialized
    try:
        from ..config import get_config

        config = get_config()
        effective_max_context_length = config.get_max_context_length()
    except (ImportError, RuntimeError, AttributeError):
        effective_max_context_length = 96000

    return ConversationTrimmer(
        max_context_length=effective_max_context_length,
        preserve_system=True,
    )
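

# Minimal smoke test: a sketch using synthetic messages (the tool name
# "get_weather" and the call ids are hypothetical, not part of any real API).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    demo_messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What's the weather in Paris and Lyon?"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
        ),
        ToolMessage(content='{"temp_c": 18}', tool_call_id="call_1"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Lyon"}, "id": "call_2"}],
        ),
        ToolMessage(content='{"temp_c": 21}', tool_call_id="call_2"),
    ]

    trimmer = create_conversation_trimmer(max_context_length=96000)
    print("should_trim:", trimmer.should_trim(demo_messages))  # True: two tool rounds
    trimmed = trimmer.trim_conversation_history(demo_messages)
    print(f"messages: {len(demo_messages)} -> {len(trimmed)}")  # 6 -> 4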