138 lines
4.8 KiB
Python
138 lines
4.8 KiB
Python
"""
|
|
Redis-based memory implementation using LangGraph built-in components.
|
|
Provides session-level chat history with 7-day TTL.
|
|
"""
|
|
import logging
|
|
import ssl
|
|
from typing import Dict, Any, Optional
|
|
|
|
try:
|
|
import redis
|
|
from redis.exceptions import ConnectionError, TimeoutError
|
|
from langgraph.checkpoint.redis import RedisSaver
|
|
REDIS_AVAILABLE = True
|
|
except ImportError as e:
|
|
logging.warning(f"Redis packages not available: {e}")
|
|
REDIS_AVAILABLE = False
|
|
redis = None
|
|
RedisSaver = None
|
|
|
|
from langgraph.checkpoint.memory import InMemorySaver
|
|
from ..config import get_config
|
|
|
|
logger = logging.getLogger(__name__)


class RedisMemoryManager:
    """
    Redis-based memory manager using LangGraph's built-in components.

    Builds a ``RedisSaver`` checkpointer on first use and caches it. Falls back
    to an ``InMemorySaver`` when the Redis packages are not installed or when
    connecting to / initializing Redis fails.
    """

    def __init__(self):
        self.config = get_config()
        # Redis section of the app config (host, port, password, db, use_ssl, ttl_days).
        self.redis_config = self.config.redis
        # Created lazily by get_checkpointer() and reused on later calls.
        self._checkpointer: Optional[Any] = None
        # False when the optional redis / langgraph Redis imports failed.
        self._redis_available = REDIS_AVAILABLE

    def _get_redis_client_kwargs(self) -> Dict[str, Any]:
        """
        Build keyword arguments for ``redis.Redis`` (Azure Redis compatible).

        Returns:
            Connection settings from the Redis config, plus TLS options with
            certificate and hostname verification when ``use_ssl`` is set.
            An empty dict when the Redis packages are not installed.
        """
        if not self._redis_available:
            return {}

        kwargs: Dict[str, Any] = {
            "host": self.redis_config.host,
            "port": self.redis_config.port,
            "password": self.redis_config.password,
            "db": self.redis_config.db,
            "decode_responses": False,  # Required for RedisSaver
            "socket_timeout": 30,
            "socket_connect_timeout": 10,
            "retry_on_timeout": True,
            "health_check_interval": 30,
        }

        if self.redis_config.use_ssl:
            kwargs.update({
                "ssl": True,
                "ssl_cert_reqs": ssl.CERT_REQUIRED,
                "ssl_check_hostname": True,
            })

        return kwargs

    def _get_ttl_config(self) -> Dict[str, Any]:
        """
        Build the TTL configuration passed to ``RedisSaver``.

        Returns:
            ``default_ttl`` expressed in minutes (converted from the configured
            ``ttl_days``) and ``refresh_on_read`` enabled so that reading a
            conversation extends its lifetime.
        """
        ttl_minutes = self.redis_config.ttl_days * 24 * 60  # days -> minutes
        return {
            "default_ttl": ttl_minutes,
            "refresh_on_read": True,  # Refresh TTL when accessed
        }

    @staticmethod
    def _close_client(client: Any) -> None:
        """Close a Redis client's connections; errors are logged and ignored (best-effort cleanup)."""
        try:
            client.close()
        except Exception as exc:  # cleanup must never mask the caller's outcome
            logger.debug("Ignoring error while closing Redis client: %s", exc)

    def _create_redis_checkpointer(self):
        """
        Connect to Redis and return a ``RedisSaver`` with its indices set up.

        The Redis client is closed again if any step after its creation fails,
        so a failed attempt does not leak its connection pool.

        Raises:
            Exception: whatever ``ping()``, ``RedisSaver(...)`` or ``setup()`` raise.
        """
        ttl_config = self._get_ttl_config()

        # Create Redis client with proper configuration for Azure Redis
        redis_client = redis.Redis(**self._get_redis_client_kwargs())
        try:
            redis_client.ping()  # fail fast if the server is unreachable
            logger.info("Redis connection established successfully")

            checkpointer = RedisSaver(redis_client=redis_client, ttl=ttl_config)
            # Initialize indices (required for first-time setup)
            checkpointer.setup()
        except Exception:
            self._close_client(redis_client)
            raise

        logger.info(
            "Redis checkpointer initialized with %s-day TTL", self.redis_config.ttl_days
        )
        return checkpointer

    def get_checkpointer(self):
        """
        Return the cached checkpointer, creating it on first call.

        Returns:
            A ``RedisSaver`` when Redis is installed and reachable, otherwise an
            ``InMemorySaver``. The result (including the in-memory fallback) is
            cached, so Redis is not retried on later calls.
        """
        if self._checkpointer is None:
            if self._redis_available:
                try:
                    self._checkpointer = self._create_redis_checkpointer()
                except Exception as e:
                    logger.error(
                        "Failed to initialize Redis checkpointer, falling back to in-memory: %s", e
                    )
                    self._checkpointer = InMemorySaver()
            else:
                logger.info("Redis not available, using in-memory checkpointer")
                self._checkpointer = InMemorySaver()

        return self._checkpointer

    def test_connection(self) -> bool:
        """
        Check whether a Redis server is reachable with the configured settings.

        A throwaway client is created, pinged, and always closed afterwards.

        Returns:
            True if ``ping()`` succeeded; False if the Redis packages are missing
            or the connection attempt raised.
        """
        if not self._redis_available:
            logger.warning("Redis packages not available")
            return False

        redis_client = None
        try:
            redis_client = redis.Redis(**self._get_redis_client_kwargs())
            redis_client.ping()
            logger.info("Redis connection test successful")
            return True
        except Exception as e:
            logger.error("Redis connection test failed: %s", e)
            return False
        finally:
            # The probe client is not reused; release its connection pool.
            if redis_client is not None:
                self._close_client(redis_client)
|
|
|
|
|
|
# Lazily-created, process-wide memory manager instance.
_memory_manager: Optional[RedisMemoryManager] = None


def get_memory_manager() -> RedisMemoryManager:
    """
    Return the module-level RedisMemoryManager.

    The manager is constructed on the first call; every later call returns
    that same instance.
    """
    global _memory_manager
    if _memory_manager is not None:
        return _memory_manager
    _memory_manager = RedisMemoryManager()
    return _memory_manager
|
|
|
|
|
|
def get_checkpointer():
    """
    Return the conversation-history checkpointer of the shared memory manager.

    Delegates to ``get_checkpointer()`` on the instance returned by
    ``get_memory_manager()``.
    """
    manager = get_memory_manager()
    return manager.get_checkpointer()
|