"""Redis-backed conversation store for persistent chat sessions. Sessions are stored as JSON strings under the key `session:{session_id}`. The Redis TTL is refreshed on every write so active sessions stay alive. On expiry, `get_session` returns None — callers should create a new session. """ from __future__ import annotations import json import time import uuid from typing import Any from loguru import logger from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore class RedisConversationStore(ConversationStore): """Store conversation sessions in Redis with automatic TTL expiry. Each session is serialised as a JSON object at key ``session:{session_id}``. The TTL is reset on every write so sessions stay alive as long as they are active. """ # Prefix for all session keys to avoid collisions with other Redis consumers. _PREFIX = "session:" def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None: """Initialise the store with an existing Redis client and a TTL in seconds.""" self._redis = redis_client self._ttl = timeout_seconds # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _key(self, session_id: str) -> str: """Build the Redis key for a session.""" return f"{self._PREFIX}{session_id}" def _serialise(self, session: ConversationSession) -> str: """Serialise a ConversationSession to a JSON string.""" return json.dumps( { "session_id": session.session_id, "created_at": session.created_at, "updated_at": session.updated_at, "metadata": session.metadata, "messages": [ { "role": msg.role, "content": msg.content, "timestamp": msg.timestamp, "sources": msg.sources, } for msg in session.messages ], }, ensure_ascii=False, ) def _deserialise(self, raw: bytes | str) -> ConversationSession: """Deserialise a JSON string back into a ConversationSession.""" data = json.loads(raw) messages = [ ConversationMessage( role=m["role"], content=m["content"], timestamp=m["timestamp"], sources=m.get("sources", []), ) for m in data.get("messages", []) ] session = ConversationSession( session_id=data["session_id"], created_at=data.get("created_at", 0), updated_at=data.get("updated_at", 0), metadata=data.get("metadata", {}), ) session.messages = messages return session def _save(self, session: ConversationSession) -> None: """Persist a session to Redis and refresh its TTL.""" self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session)) # ------------------------------------------------------------------ # ConversationStore protocol # ------------------------------------------------------------------ def create_session(self, metadata: dict | None = None) -> ConversationSession: """Create a new empty session and persist it immediately.""" now = int(time.time()) session = ConversationSession( session_id=str(uuid.uuid4())[:8], created_at=now, updated_at=now, metadata=metadata or {}, ) self._save(session) return session def get_session(self, session_id: str) -> ConversationSession | None: """Return a session by ID, or None if it does not exist or has expired.""" raw = self._redis.get(self._key(session_id)) if raw is None: return None try: return self._deserialise(raw) except Exception: logger.warning("Failed to deserialise session: {}", session_id) return None def save_message( self, session_id: str, *, role: str, content: str, sources: list[dict] | None = None, ) -> ConversationSession | None: """Append a message to a session and refresh its TTL.""" session = self.get_session(session_id) if session is None: return None session.messages.append( ConversationMessage( role=role, content=content, timestamp=int(time.time()), sources=sources or [], ) ) session.updated_at = int(time.time()) self._save(session) return session def delete_session(self, session_id: str) -> bool: """Delete a session. Returns True if it existed, False otherwise.""" deleted = self._redis.delete(self._key(session_id)) return bool(deleted) def list_sessions(self) -> list[dict]: """Return summary dicts for all live sessions visible in this Redis DB. Note: KEYS is used for simplicity; replace with SCAN for large deployments. """ pattern = f"{self._PREFIX}*" keys = self._redis.keys(pattern) result = [] for key in keys: raw = self._redis.get(key) if raw is None: continue try: data = json.loads(raw) result.append( { "session_id": data["session_id"], "message_count": len(data.get("messages", [])), "created_at": data.get("created_at", 0), "updated_at": data.get("updated_at", 0), } ) except Exception: continue return result