This commit is contained in:
2025-09-26 17:15:54 +08:00
commit db0e5965ec
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1,113 @@
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import logging
from .postgresql_memory import get_memory_manager, get_checkpointer
from ..graph.state import TurnState, Message
logger = logging.getLogger(__name__)
class InMemoryStore:
"""Simple in-memory store with TTL for conversation history"""
def __init__(self, ttl_days: float = 7.0):
self.ttl_days = ttl_days
self.store: Dict[str, Dict[str, Any]] = {}
def _is_expired(self, timestamp: datetime) -> bool:
"""Check if a record has expired"""
return datetime.now() - timestamp > timedelta(days=self.ttl_days)
def _cleanup_expired(self) -> None:
"""Remove expired records"""
expired_keys = []
for session_id, data in self.store.items():
if self._is_expired(data.get("last_updated", datetime.min)):
expired_keys.append(session_id)
for key in expired_keys:
del self.store[key]
logger.info(f"Cleaned up expired session: {key}")
def get(self, session_id: str) -> Optional[TurnState]:
"""Get conversation state for a session"""
self._cleanup_expired()
if session_id not in self.store:
return None
data = self.store[session_id]
if self._is_expired(data.get("last_updated", datetime.min)):
del self.store[session_id]
return None
try:
# Reconstruct TurnState from stored data
state_data = data["state"]
return TurnState(**state_data)
except Exception as e:
logger.error(f"Failed to deserialize state for session {session_id}: {e}")
return None
def put(self, session_id: str, state: TurnState) -> None:
"""Store conversation state for a session"""
try:
self.store[session_id] = {
"state": state.model_dump(),
"last_updated": datetime.now()
}
logger.debug(f"Stored state for session: {session_id}")
except Exception as e:
logger.error(f"Failed to store state for session {session_id}: {e}")
def trim(self, session_id: str, max_messages: int = 20) -> None:
"""Trim old messages to stay within token limits"""
state = self.get(session_id)
if not state:
return
if len(state.messages) > max_messages:
# Keep system message (if any) and recent user/assistant pairs
trimmed_messages = state.messages[-max_messages:]
# Try to preserve complete conversation turns
if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
trimmed_messages = trimmed_messages[1:]
state.messages = trimmed_messages
self.put(session_id, state)
logger.info(f"Trimmed messages for session {session_id} to {len(trimmed_messages)}")
def create_new_session(self, session_id: str) -> TurnState:
"""Create a new conversation session"""
state = TurnState(session_id=session_id)
self.put(session_id, state)
return state
def add_message(self, session_id: str, message: Message) -> None:
"""Add a message to the conversation history"""
state = self.get(session_id)
if not state:
state = self.create_new_session(session_id)
state.messages.append(message)
self.put(session_id, state)
def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
"""Get formatted conversation history for prompts"""
state = self.get(session_id)
if not state or not state.messages:
return ""
# Get recent messages, keeping complete turns
recent_messages = state.messages[-(max_turns * 2):]
history_parts = []
for msg in recent_messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content}")
elif msg.role == "assistant" and not msg.tool_call_id:
history_parts.append(f"Assistant: {msg.content}")
return "\n".join(history_parts)