114 lines
4.2 KiB
Python
114 lines
4.2 KiB
Python
from typing import Dict, Any, Optional
|
|
from datetime import datetime, timedelta
|
|
import logging
|
|
|
|
from .postgresql_memory import get_memory_manager, get_checkpointer
|
|
from ..graph.state import TurnState, Message
|
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered and configured independently.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class InMemoryStore:
    """Simple in-memory store with TTL for conversation history.

    Each session lives in ``self.store`` under its ``session_id`` as a dict
    with two keys:

    * ``"state"`` -- the ``TurnState`` serialized via ``model_dump()``;
    * ``"last_updated"`` -- the naive local ``datetime`` of the last ``put``.

    Records older than ``ttl_days`` are evicted lazily every time ``get`` runs
    (and therefore also from ``trim``, ``add_message`` and
    ``get_conversation_history``, which all go through ``get``). The class
    does no locking of its own.
    """

    def __init__(self, ttl_days: float = 7.0):
        """Create an empty store.

        Args:
            ttl_days: How long, in days, a session survives after its last
                write before it is considered expired.
        """
        self.ttl_days = ttl_days
        # session_id -> {"state": <model_dump() dict>, "last_updated": datetime}
        self.store: Dict[str, Dict[str, Any]] = {}

    def _is_expired(self, timestamp: datetime) -> bool:
        """Return True if *timestamp* is more than ``ttl_days`` in the past."""
        return datetime.now() - timestamp > timedelta(days=self.ttl_days)

    def _cleanup_expired(self) -> None:
        """Evict every session whose ``last_updated`` is older than the TTL.

        A record without ``last_updated`` is treated as dating from
        ``datetime.min`` and is therefore evicted under any realistic TTL.
        """
        # Collect first, delete afterwards: mutating a dict while iterating
        # over it raises RuntimeError.
        expired_keys = [
            session_id
            for session_id, data in self.store.items()
            if self._is_expired(data.get("last_updated", datetime.min))
        ]
        for key in expired_keys:
            del self.store[key]
            logger.info("Cleaned up expired session: %s", key)

    def get(self, session_id: str) -> Optional[TurnState]:
        """Return the stored conversation state for *session_id*.

        All expired sessions (not just this one) are purged first.

        Returns:
            A ``TurnState`` rebuilt from the stored dict, or ``None`` if the
            session is unknown, expired, or its stored data cannot be
            deserialized (the failure is logged).
        """
        self._cleanup_expired()

        data = self.store.get(session_id)
        if data is None:
            return None

        # Re-check: a little time may have elapsed since the sweep above.
        if self._is_expired(data.get("last_updated", datetime.min)):
            del self.store[session_id]
            return None

        try:
            # Rebuild the model from the dict produced by model_dump() in put().
            return TurnState(**data["state"])
        except Exception as e:
            logger.error("Failed to deserialize state for session %s: %s", session_id, e)
            return None

    def put(self, session_id: str, state: TurnState) -> None:
        """Serialize *state* and store it under *session_id*, stamping the time.

        Best-effort: if serialization fails, the error is logged and any
        previous record for the session is left untouched.
        """
        try:
            record = {
                "state": state.model_dump(),
                "last_updated": datetime.now(),
            }
        except Exception as e:
            logger.error("Failed to store state for session %s: %s", session_id, e)
            return
        self.store[session_id] = record
        logger.debug("Stored state for session: %s", session_id)

    def trim(self, session_id: str, max_messages: int = 20) -> None:
        """Drop old messages so that at most *max_messages* recent ones remain.

        Only the most recent messages are kept; no message (system or
        otherwise) is pinned. If more than one message is kept and the kept
        window starts with an assistant message, that single leading message
        is dropped as well. A ``max_messages`` of zero or less clears the
        history. Unknown or expired sessions are ignored.
        """
        state = self.get(session_id)
        if state is None:
            return

        keep = max(max_messages, 0)
        if len(state.messages) <= keep:
            return

        # Handle keep == 0 explicitly: messages[-0:] is the WHOLE list, not an
        # empty one, so slicing alone would silently keep everything.
        trimmed_messages = state.messages[-keep:] if keep else []

        # Try to preserve complete conversation turns: don't open the window
        # with an assistant reply whose prompting message was cut off.
        if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
            trimmed_messages = trimmed_messages[1:]

        state.messages = trimmed_messages
        self.put(session_id, state)
        logger.info(
            "Trimmed messages for session %s to %s", session_id, len(trimmed_messages)
        )

    def create_new_session(self, session_id: str) -> TurnState:
        """Create, store, and return an empty ``TurnState`` for *session_id*.

        Any existing record for the session is overwritten.
        """
        state = TurnState(session_id=session_id)
        self.put(session_id, state)
        return state

    def add_message(self, session_id: str, message: Message) -> None:
        """Append *message* to the session's history, creating the session if needed.

        A session whose stored state fails to deserialize is treated as
        missing and replaced by a fresh one.
        """
        state = self.get(session_id)
        if state is None:
            state = self.create_new_session(session_id)
        state.messages.append(message)
        self.put(session_id, state)

    def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
        """Format recent history as ``"User: ..."`` / ``"Assistant: ..."`` lines.

        Considers the last ``2 * max_turns`` messages (a turn is taken to be
        two messages). Only ``user`` messages and ``assistant`` messages
        without a ``tool_call_id`` are rendered; everything else is skipped.

        Returns:
            The rendered lines joined by newlines, or ``""`` when the session
            is unknown, has no messages, or ``max_turns`` is zero or less.
        """
        state = self.get(session_id)
        if state is None or not state.messages:
            return ""

        # messages[-0:] would return the whole list, and a negative count
        # would slice from the front, so non-positive values mean "nothing".
        if max_turns <= 0:
            return ""

        recent_messages = state.messages[-(max_turns * 2):]

        history_parts = []
        for msg in recent_messages:
            if msg.role == "user":
                history_parts.append(f"User: {msg.content}")
            elif msg.role == "assistant" and not msg.tool_call_id:
                history_parts.append(f"Assistant: {msg.content}")

        return "\n".join(history_parts)
|