from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import logging

from .postgresql_memory import get_memory_manager, get_checkpointer
from ..graph.state import TurnState, Message

logger = logging.getLogger(__name__)


class InMemoryStore:
    """Simple in-memory store with TTL for conversation history.

    Each session is kept in a plain dict as
    ``{"state": <state.model_dump()>, "last_updated": <datetime>}``, so data
    lives only inside this instance / process.

    Args:
        ttl_days: Age in days (fractions allowed) after which a session that
            has not been written via :meth:`put` is treated as expired.
    """

    def __init__(self, ttl_days: float = 7.0):
        self.ttl_days = ttl_days
        # session_id -> {"state": serialized TurnState, "last_updated": datetime}
        self.store: Dict[str, Dict[str, Any]] = {}

    def _is_expired(self, timestamp: datetime) -> bool:
        """Return True if *timestamp* is older than the configured TTL.

        Timestamps are naive local times from ``datetime.now()``.
        """
        return datetime.now() - timestamp > timedelta(days=self.ttl_days)

    def _cleanup_expired(self) -> None:
        """Remove every expired session from the store."""
        # Collect first and delete afterwards: a dict must not change size
        # while it is being iterated.
        expired_keys = [
            session_id
            for session_id, data in self.store.items()
            if self._is_expired(data.get("last_updated", datetime.min))
        ]
        for key in expired_keys:
            del self.store[key]
            logger.info("Cleaned up expired session: %s", key)

    def get(self, session_id: str) -> Optional[TurnState]:
        """Get conversation state for a session.

        Also sweeps all expired sessions out of the store.

        Returns:
            A ``TurnState`` rebuilt from the stored snapshot, or ``None`` if
            the session is unknown, expired, or fails to deserialize.
        """
        self._cleanup_expired()

        data = self.store.get(session_id)
        if data is None:
            return None

        # Re-check in case the session crossed the TTL boundary since the sweep.
        if self._is_expired(data.get("last_updated", datetime.min)):
            del self.store[session_id]
            return None

        try:
            # Reconstruct TurnState from stored data
            return TurnState(**data["state"])
        except Exception as e:
            logger.error("Failed to deserialize state for session %s: %s", session_id, e)
            return None

    def put(self, session_id: str, state: TurnState) -> None:
        """Store a snapshot of *state* for a session and refresh its timestamp.

        If serializing *state* raises, the error is logged and any previously
        stored snapshot for the session is left unchanged.
        """
        try:
            snapshot = state.model_dump()
        except Exception as e:
            logger.error("Failed to store state for session %s: %s", session_id, e)
            return
        self.store[session_id] = {"state": snapshot, "last_updated": datetime.now()}
        logger.debug("Stored state for session: %s", session_id)

    def trim(self, session_id: str, max_messages: int = 20) -> None:
        """Trim old messages so at most *max_messages* remain.

        A leading ``"system"`` message, if present, is kept and counts toward
        the limit. The most recent remaining messages fill the rest. If the
        kept window would start with an assistant reply whose user prompt was
        cut off, that reply is dropped too.

        Args:
            session_id: Session to trim. Unknown/expired sessions are ignored.
            max_messages: Maximum number of messages to keep. Values <= 0
                remove all messages.
        """
        state = self.get(session_id)
        if state is None:
            return

        messages = state.messages
        if len(messages) <= max_messages:
            return

        # Keep the system message (if any) so instructions survive trimming.
        keep_system = max_messages > 0 and messages[0].role == "system"
        head = messages[:1] if keep_system else []
        budget = max_messages - len(head)
        tail = messages[len(head):]
        # Guard budget == 0: tail[-0:] would return the whole list.
        recent = tail[-budget:] if budget > 0 else []

        # Try to preserve complete conversation turns
        if len(recent) > 1 and recent[0].role == "assistant":
            recent = recent[1:]

        trimmed_messages = head + recent
        state.messages = trimmed_messages
        self.put(session_id, state)
        logger.info("Trimmed messages for session %s to %s", session_id, len(trimmed_messages))

    def create_new_session(self, session_id: str) -> TurnState:
        """Create, store, and return a fresh conversation session."""
        state = TurnState(session_id=session_id)
        self.put(session_id, state)
        return state

    def add_message(self, session_id: str, message: Message) -> None:
        """Append *message* to a session's history, creating the session if needed."""
        state = self.get(session_id)
        if state is None:
            state = self.create_new_session(session_id)
        state.messages.append(message)
        self.put(session_id, state)

    def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
        """Get formatted conversation history for prompts.

        Looks at the last ``max_turns * 2`` messages and renders user messages
        as ``"User: ..."`` and assistant messages without a ``tool_call_id`` as
        ``"Assistant: ..."``, one per line. All other messages are skipped.

        Returns:
            The newline-joined history, or ``""`` if there is nothing to show
            or ``max_turns <= 0``.
        """
        state = self.get(session_id)
        if state is None or not state.messages:
            return ""
        # Guard max_turns <= 0: messages[-0:] would return the whole history.
        if max_turns <= 0:
            return ""

        # Get recent messages, keeping complete turns
        recent_messages = state.messages[-(max_turns * 2):]

        history_parts = []
        for msg in recent_messages:
            if msg.role == "user":
                history_parts.append(f"User: {msg.content}")
            elif msg.role == "assistant" and not msg.tool_call_id:
                # Assistant entries carrying a tool_call_id are left out of
                # the prompt text (presumably tool-call plumbing — confirm).
                history_parts.append(f"Assistant: {msg.content}")

        return "\n".join(history_parts)