170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
|
|
"""Redis-backed conversation store for persistent chat sessions.

Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write, so active sessions stay alive.
When a session has expired (or its stored data cannot be decoded),
`get_session` returns None — callers should create a new session.
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
import uuid
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
|
||
|
|
|
||
|
|
|
||
|
|
class RedisConversationStore(ConversationStore):
    """Store conversation sessions in Redis with automatic TTL expiry.

    Each session is serialised as a JSON object at key ``session:{session_id}``.
    The TTL is reset on every write so sessions stay alive as long as they are
    active; once it lapses Redis drops the key and the session reads as missing.

    NOTE(review): ``save_message`` is a non-atomic read-modify-write (GET, then
    SETEX). Two concurrent appends to the same session can lose a message; use
    WATCH/MULTI or a Lua script if callers may write concurrently.
    """

    # Prefix for all session keys to avoid collisions with other Redis consumers.
    _PREFIX = "session:"

    # Generated session IDs are the first 8 hex digits of a UUID4 (32 bits).
    # That is small enough to collide among many live sessions, so
    # create_session refuses to reuse an ID that is already taken.
    _ID_LENGTH = 8

    # Number of fresh IDs create_session tries before giving up.
    _MAX_ID_ATTEMPTS = 5

    def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
        """Initialise the store with an existing Redis client and a TTL.

        Args:
            redis_client: A synchronous Redis client exposing ``get``, ``set``,
                ``setex``, ``delete`` and ``scan_iter`` (e.g. ``redis.Redis``).
            timeout_seconds: Session time-to-live in seconds; must be positive.

        Raises:
            ValueError: If ``timeout_seconds`` is not positive. Redis rejects
                non-positive expiry times, so failing here surfaces the
                misconfiguration at construction instead of on every write.
        """
        if timeout_seconds <= 0:
            raise ValueError(f"timeout_seconds must be positive, got {timeout_seconds}")
        self._redis = redis_client
        self._ttl = timeout_seconds

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a session."""
        return f"{self._PREFIX}{session_id}"

    def _serialise(self, session: ConversationSession) -> str:
        """Serialise a ConversationSession to a JSON string.

        Non-ASCII text is kept as-is (``ensure_ascii=False``).
        """
        return json.dumps(
            {
                "session_id": session.session_id,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
                "metadata": session.metadata,
                "messages": [
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp,
                        "sources": msg.sources,
                    }
                    for msg in session.messages
                ],
            },
            ensure_ascii=False,
        )

    def _deserialise(self, raw: bytes | str) -> ConversationSession:
        """Deserialise a JSON payload (bytes or str) back into a ConversationSession.

        Missing optional fields fall back to defaults (timestamps ``0``, empty
        metadata/sources/messages). ``session_id`` is required; its absence
        raises ``KeyError``.
        """
        data = json.loads(raw)
        messages = [
            ConversationMessage(
                role=m["role"],
                content=m["content"],
                timestamp=m["timestamp"],
                sources=m.get("sources", []),
            )
            for m in data.get("messages", [])
        ]
        session = ConversationSession(
            session_id=data["session_id"],
            created_at=data.get("created_at", 0),
            updated_at=data.get("updated_at", 0),
            metadata=data.get("metadata", {}),
        )
        session.messages = messages
        return session

    def _save(self, session: ConversationSession) -> None:
        """Persist a session to Redis and refresh its TTL."""
        self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))

    # ------------------------------------------------------------------
    # ConversationStore protocol
    # ------------------------------------------------------------------

    def create_session(self, metadata: dict | None = None) -> ConversationSession:
        """Create a new empty session and persist it immediately.

        The session is written with ``SET ... NX EX`` so that, should a freshly
        generated ID already belong to a live session, a new ID is drawn
        instead of silently overwriting someone else's conversation.

        Args:
            metadata: Optional metadata stored with the session (``{}`` if None).

        Returns:
            The newly created, already-persisted session.

        Raises:
            RuntimeError: If no unused ID is found after ``_MAX_ID_ATTEMPTS``.
        """
        now = int(time.time())
        session_metadata = metadata or {}
        for _ in range(self._MAX_ID_ATTEMPTS):
            session = ConversationSession(
                session_id=uuid.uuid4().hex[: self._ID_LENGTH],
                created_at=now,
                updated_at=now,
                metadata=session_metadata,
            )
            # nx=True: only write if the key does not exist; returns a falsy
            # value when the ID is already taken.
            created = self._redis.set(
                self._key(session.session_id),
                self._serialise(session),
                ex=self._ttl,
                nx=True,
            )
            if created:
                return session
        raise RuntimeError("Could not allocate a unique session ID")

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return a session by ID, or None if it is missing, expired, or corrupt."""
        raw = self._redis.get(self._key(session_id))
        if raw is None:
            return None
        try:
            return self._deserialise(raw)
        except Exception:
            # Corrupt or schema-incompatible data is treated like a missing
            # session; keep the traceback so the cause can be diagnosed.
            logger.opt(exception=True).warning("Failed to deserialise session: {}", session_id)
            return None

    def save_message(
        self,
        session_id: str,
        *,
        role: str,
        content: str,
        sources: list[dict] | None = None,
    ) -> ConversationSession | None:
        """Append a message to a session and refresh its TTL.

        Returns:
            The updated session, or None if the session does not exist, has
            expired, or could not be decoded.
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        # One clock read so the message timestamp and updated_at agree.
        now = int(time.time())
        session.messages.append(
            ConversationMessage(
                role=role,
                content=content,
                timestamp=now,
                sources=sources or [],
            )
        )
        session.updated_at = now
        self._save(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed, False otherwise."""
        deleted = self._redis.delete(self._key(session_id))
        return bool(deleted)

    def list_sessions(self) -> list[dict]:
        """Return summary dicts for all live sessions visible in this Redis DB.

        Keys are enumerated incrementally with SCAN (``scan_iter``) rather than
        KEYS, so the server is not blocked while walking a large keyspace.
        SCAN may yield the same key more than once, so keys are de-duplicated.
        Sessions that expire mid-scan are skipped; undecodable entries are
        skipped with a warning.

        Returns:
            One dict per session with ``session_id``, ``message_count``,
            ``created_at`` and ``updated_at``. Order is unspecified.
        """
        result: list[dict] = []
        seen: set = set()
        for key in self._redis.scan_iter(match=f"{self._PREFIX}*"):
            if key in seen:
                continue
            seen.add(key)
            raw = self._redis.get(key)
            if raw is None:
                # Expired between SCAN and GET.
                continue
            try:
                data = json.loads(raw)
                summary = {
                    "session_id": data["session_id"],
                    "message_count": len(data.get("messages", [])),
                    "created_at": data.get("created_at", 0),
                    "updated_at": data.get("updated_at", 0),
                }
            except Exception:
                logger.warning("Failed to read session summary: {}", key)
                continue
            result.append(summary)
        return result
|