init
This commit is contained in:
113
vw-agentic-rag/service/memory/store.py
Normal file
113
vw-agentic-rag/service/memory/store.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from .postgresql_memory import get_memory_manager, get_checkpointer
|
||||
from ..graph.state import TurnState, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryStore:
|
||||
"""Simple in-memory store with TTL for conversation history"""
|
||||
|
||||
def __init__(self, ttl_days: float = 7.0):
|
||||
self.ttl_days = ttl_days
|
||||
self.store: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def _is_expired(self, timestamp: datetime) -> bool:
|
||||
"""Check if a record has expired"""
|
||||
return datetime.now() - timestamp > timedelta(days=self.ttl_days)
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired records"""
|
||||
expired_keys = []
|
||||
for session_id, data in self.store.items():
|
||||
if self._is_expired(data.get("last_updated", datetime.min)):
|
||||
expired_keys.append(session_id)
|
||||
|
||||
for key in expired_keys:
|
||||
del self.store[key]
|
||||
logger.info(f"Cleaned up expired session: {key}")
|
||||
|
||||
def get(self, session_id: str) -> Optional[TurnState]:
|
||||
"""Get conversation state for a session"""
|
||||
self._cleanup_expired()
|
||||
|
||||
if session_id not in self.store:
|
||||
return None
|
||||
|
||||
data = self.store[session_id]
|
||||
if self._is_expired(data.get("last_updated", datetime.min)):
|
||||
del self.store[session_id]
|
||||
return None
|
||||
|
||||
try:
|
||||
# Reconstruct TurnState from stored data
|
||||
state_data = data["state"]
|
||||
return TurnState(**state_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize state for session {session_id}: {e}")
|
||||
return None
|
||||
|
||||
def put(self, session_id: str, state: TurnState) -> None:
|
||||
"""Store conversation state for a session"""
|
||||
try:
|
||||
self.store[session_id] = {
|
||||
"state": state.model_dump(),
|
||||
"last_updated": datetime.now()
|
||||
}
|
||||
logger.debug(f"Stored state for session: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store state for session {session_id}: {e}")
|
||||
|
||||
def trim(self, session_id: str, max_messages: int = 20) -> None:
|
||||
"""Trim old messages to stay within token limits"""
|
||||
state = self.get(session_id)
|
||||
if not state:
|
||||
return
|
||||
|
||||
if len(state.messages) > max_messages:
|
||||
# Keep system message (if any) and recent user/assistant pairs
|
||||
trimmed_messages = state.messages[-max_messages:]
|
||||
|
||||
# Try to preserve complete conversation turns
|
||||
if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
|
||||
trimmed_messages = trimmed_messages[1:]
|
||||
|
||||
state.messages = trimmed_messages
|
||||
self.put(session_id, state)
|
||||
logger.info(f"Trimmed messages for session {session_id} to {len(trimmed_messages)}")
|
||||
|
||||
def create_new_session(self, session_id: str) -> TurnState:
|
||||
"""Create a new conversation session"""
|
||||
state = TurnState(session_id=session_id)
|
||||
self.put(session_id, state)
|
||||
return state
|
||||
|
||||
def add_message(self, session_id: str, message: Message) -> None:
|
||||
"""Add a message to the conversation history"""
|
||||
state = self.get(session_id)
|
||||
if not state:
|
||||
state = self.create_new_session(session_id)
|
||||
|
||||
state.messages.append(message)
|
||||
self.put(session_id, state)
|
||||
|
||||
def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
|
||||
"""Get formatted conversation history for prompts"""
|
||||
state = self.get(session_id)
|
||||
if not state or not state.messages:
|
||||
return ""
|
||||
|
||||
# Get recent messages, keeping complete turns
|
||||
recent_messages = state.messages[-(max_turns * 2):]
|
||||
|
||||
history_parts = []
|
||||
for msg in recent_messages:
|
||||
if msg.role == "user":
|
||||
history_parts.append(f"User: {msg.content}")
|
||||
elif msg.role == "assistant" and not msg.tool_call_id:
|
||||
history_parts.append(f"Assistant: {msg.content}")
|
||||
|
||||
return "\n".join(history_parts)
|
||||
Reference in New Issue
Block a user