""" Conversation history trimming utilities for managing context length. """ import logging from typing import List, Optional, Sequence, Tuple from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage from langchain_core.messages.utils import trim_messages, count_tokens_approximately logger = logging.getLogger(__name__) class ConversationTrimmer: """ Manages conversation history to prevent exceeding LLM context limits. """ def __init__(self, max_context_length: int = 96000, preserve_system: bool = True): """ Initialize the conversation trimmer. Args: max_context_length: Maximum context length for conversation history (in tokens) preserve_system: Whether to always preserve system messages """ self.max_context_length = max_context_length self.preserve_system = preserve_system # Reserve tokens for response generation (use 85% for history, 15% for response) self.history_token_limit = int(max_context_length * 0.85) def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]: """ Trim conversation history to fit within token limits. Args: messages: List of conversation messages Returns: Trimmed list of messages """ if not messages: return list(messages) try: # Convert to list for processing message_list = list(messages) # First, try multi-round tool call optimization optimized_messages = self._optimize_multi_round_tool_calls(message_list) # Check if optimization is sufficient try: token_count = count_tokens_approximately(optimized_messages) if token_count <= self.history_token_limit: original_count = len(message_list) optimized_count = len(optimized_messages) if optimized_count < original_count: logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages") return optimized_messages except Exception: # If token counting fails, continue with LangChain trimming pass # If still too long, use LangChain's trim_messages utility trimmed_messages = trim_messages( optimized_messages, strategy="last", # Keep most recent messages token_counter=count_tokens_approximately, max_tokens=self.history_token_limit, start_on="human", # Ensure valid conversation start end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages include_system=self.preserve_system, # Preserve system messages allow_partial=False # Don't split individual messages ) original_count = len(messages) trimmed_count = len(trimmed_messages) if trimmed_count < original_count: logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages") return trimmed_messages except Exception as e: logger.error(f"Error trimming conversation history: {e}") # Fallback: keep last N messages return self._fallback_trim(list(messages)) def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]: """ Optimize conversation history by removing older tool call results in multi-round scenarios. This reduces token usage while preserving conversation context. Strategy: 1. Always preserve system messages 2. Always preserve the original user query 3. Keep the most recent AI-Tool message pairs (for context continuity) 4. 

    def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
        """
        Optimize conversation history by removing older tool call results in
        multi-round scenarios. This reduces token usage while preserving
        conversation context.

        Strategy:
        1. Always preserve system messages
        2. Always preserve the original user query
        3. Keep the most recent AI-Tool message pairs (for context continuity)
        4. Remove older ToolMessage content, which typically contains large JSON responses

        Args:
            messages: List of conversation messages

        Returns:
            Optimized list of messages
        """
        if len(messages) <= 4:
            # Too short to optimize
            return list(messages)

        # Identify message patterns
        tool_rounds = self._identify_tool_rounds(messages)

        if len(tool_rounds) <= 1:
            # Single or no tool round, no optimization needed
            return list(messages)

        logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")

        # Build the optimized message list
        optimized = []

        # Always preserve system messages
        for msg in messages:
            if isinstance(msg, SystemMessage):
                optimized.append(msg)

        # Preserve the initial user query (first human message)
        for msg in messages:
            if isinstance(msg, HumanMessage):
                optimized.append(msg)
                break

        # Keep only the most recent tool round (preserves context for the next
        # round). A round spans an AIMessage plus its ToolMessages, so the
        # human message added above cannot be duplicated here; skip system
        # messages defensively.
        latest_round_start, latest_round_end = tool_rounds[-1]
        for i in range(latest_round_start, min(latest_round_end + 1, len(messages))):
            msg = messages[i]
            if not isinstance(msg, SystemMessage):
                optimized.append(msg)

        logger.info(
            f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages "
            f"(removed {len(tool_rounds) - 1} older tool rounds)"
        )
        return optimized

    def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
        """
        Identify tool calling rounds in the message sequence.

        A tool round typically consists of:
        - An AI message with tool_calls
        - One or more ToolMessage responses

        Returns:
            List of (start_index, end_index) tuples, one per tool round
        """
        rounds = []
        i = 0
        while i < len(messages):
            msg = messages[i]

            # Look for an AI message with tool calls
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                round_start = i
                round_end = i

                # Find the end of this tool round (consecutive ToolMessages)
                j = i + 1
                while j < len(messages) and isinstance(messages[j], ToolMessage):
                    round_end = j
                    j += 1

                # Only count it as a tool round if at least one ToolMessage followed
                if round_end > round_start:
                    rounds.append((round_start, round_end))
                    i = round_end + 1
                else:
                    i += 1
            else:
                i += 1

        return rounds

    def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
        """
        Fallback trimming based on message count.

        Args:
            messages: List of conversation messages
            max_messages: Maximum number of messages to keep

        Returns:
            Trimmed list of messages
        """
        if len(messages) <= max_messages:
            return list(messages)

        # Preserve system messages if any exist
        system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
        other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]

        # Keep the most recent non-system messages; guard against the edge case
        # where system messages alone fill the budget (a negative slice bound
        # would otherwise keep everything)
        keep = max(max_messages - len(system_messages), 0)
        recent_messages = other_messages[-keep:] if keep else []

        result = system_messages + recent_messages
        logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
        return result

    def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
        """
        Check if conversation history should be trimmed.

        Strategy:
        1. Always trim if there are multiple tool rounds from previous conversation turns
        2. Also trim if approaching the token limit

        Args:
            messages: List of conversation messages

        Returns:
            True if trimming is needed
        """
        try:
            message_list = list(messages)

            # Multiple tool rounds: always trim to drop old tool results
            tool_rounds = self._identify_tool_rounds(message_list)
            if len(tool_rounds) > 1:
                logger.info(
                    f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results"
                )
                return True

            # Otherwise fall back to the token-count check
            token_count = count_tokens_approximately(message_list)
            return token_count > self.history_token_limit
        except Exception:
            # Fallback to a message-count heuristic
            return len(messages) > 30
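

# Illustrative sketch (an assumption, not part of the original module): one
# way to wire the trimmer into a LangGraph prebuilt agent is through its
# pre-model-hook mechanism, returning the trimmed history under the
# "llm_input_messages" key so the stored graph state is left untouched:
#
#     trimmer = ConversationTrimmer(max_context_length=96000)
#
#     def pre_model_hook(state):
#         if trimmer.should_trim(state["messages"]):
#             return {"llm_input_messages": trimmer.trim_conversation_history(state["messages"])}
#         return {}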


def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
    """
    Create a conversation trimmer with config-based settings.

    Args:
        max_context_length: Override for the maximum context length

    Returns:
        ConversationTrimmer instance
    """
    # If max_context_length is provided, use it directly
    if max_context_length is not None:
        return ConversationTrimmer(
            max_context_length=max_context_length,
            preserve_system=True,
        )

    # Try to read the limit from config, falling back to the default if the
    # config module or value is unavailable
    try:
        from ..config import get_config

        config = get_config()
        effective_max_context_length = config.get_max_context_length()
    except (ImportError, RuntimeError, AttributeError):
        effective_max_context_length = 96000

    return ConversationTrimmer(
        max_context_length=effective_max_context_length,
        preserve_system=True,
    )
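

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. The message
# constructors come from langchain_core, and the tool-call payload shape
# ({"name", "args", "id"}) follows LangChain's ToolCall dict format; the
# weather tool and its JSON results are invented for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Build a conversation with two tool rounds so the multi-round
    # optimization has something to remove.
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What's the weather in Paris and Berlin?"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
        ),
        ToolMessage(content='{"temp_c": 18, "condition": "cloudy"}', tool_call_id="call_1"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Berlin"}, "id": "call_2"}],
        ),
        ToolMessage(content='{"temp_c": 14, "condition": "rain"}', tool_call_id="call_2"),
    ]

    # Passing an explicit limit skips the config lookup entirely.
    trimmer = create_conversation_trimmer(max_context_length=8000)
    if trimmer.should_trim(history):
        trimmed = trimmer.trim_conversation_history(history)
        # Expect the first tool round to be dropped: 6 -> 4 messages.
        print(f"Kept {len(trimmed)} of {len(history)} messages")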