This commit is contained in:
2025-09-26 17:15:54 +08:00
commit db0e5965ec
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,332 @@
"""
PostgreSQL-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
Uses psycopg3 for better compatibility without requiring libpq-dev.
"""
import logging
from typing import Dict, Any, Optional
from urllib.parse import quote_plus
from contextlib import contextmanager
try:
import psycopg
from psycopg.rows import dict_row
PSYCOPG_AVAILABLE = True
except ImportError as e:
logging.warning(f"psycopg3 not available: {e}")
PSYCOPG_AVAILABLE = False
psycopg = None
try:
from langgraph.checkpoint.postgres import PostgresSaver
LANGGRAPH_POSTGRES_AVAILABLE = True
except ImportError as e:
logging.warning(f"LangGraph PostgreSQL checkpoint not available: {e}")
LANGGRAPH_POSTGRES_AVAILABLE = False
PostgresSaver = None
try:
from langgraph.checkpoint.memory import InMemorySaver
LANGGRAPH_MEMORY_AVAILABLE = True
except ImportError as e:
logging.warning(f"LangGraph memory checkpoint not available: {e}")
LANGGRAPH_MEMORY_AVAILABLE = False
InMemorySaver = None
from ..config import get_config
logger = logging.getLogger(__name__)
POSTGRES_AVAILABLE = PSYCOPG_AVAILABLE and LANGGRAPH_POSTGRES_AVAILABLE
class PostgreSQLCheckpointerWrapper:
"""
Wrapper for PostgresSaver that manages the context properly.
"""
def __init__(self, conn_string: str):
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
self.conn_string = conn_string
self._initialized = False
def _ensure_setup(self):
"""Ensure the database schema is set up."""
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
if not self._initialized:
with PostgresSaver.from_conn_string(self.conn_string) as saver:
saver.setup()
self._initialized = True
logger.info("PostgreSQL schema initialized")
@contextmanager
def get_saver(self):
"""Get a PostgresSaver instance as context manager."""
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
self._ensure_setup()
with PostgresSaver.from_conn_string(self.conn_string) as saver:
yield saver
def list(self, config):
"""List checkpoints."""
with self.get_saver() as saver:
return list(saver.list(config))
def get(self, config):
"""Get a checkpoint."""
with self.get_saver() as saver:
return saver.get(config)
def get_tuple(self, config):
"""Get a checkpoint tuple."""
with self.get_saver() as saver:
return saver.get_tuple(config)
def put(self, config, checkpoint, metadata, new_versions):
"""Put a checkpoint."""
with self.get_saver() as saver:
return saver.put(config, checkpoint, metadata, new_versions)
def put_writes(self, config, writes, task_id):
"""Put writes."""
with self.get_saver() as saver:
return saver.put_writes(config, writes, task_id)
def get_next_version(self, current, channel):
"""Get next version."""
with self.get_saver() as saver:
return saver.get_next_version(current, channel)
def delete_thread(self, thread_id):
"""Delete thread."""
with self.get_saver() as saver:
return saver.delete_thread(thread_id)
# Async methods
async def alist(self, config):
"""Async list checkpoints."""
with self.get_saver() as saver:
async for item in saver.alist(config):
yield item
async def aget(self, config):
"""Async get a checkpoint."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aget(config)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.get, config
)
async def aget_tuple(self, config):
"""Async get a checkpoint tuple."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aget_tuple(config)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.get_tuple, config
)
async def aput(self, config, checkpoint, metadata, new_versions):
"""Async put a checkpoint."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aput(config, checkpoint, metadata, new_versions)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.put, config, checkpoint, metadata, new_versions
)
async def aput_writes(self, config, writes, task_id):
"""Async put writes."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aput_writes(config, writes, task_id)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.put_writes, config, writes, task_id
)
async def adelete_thread(self, thread_id):
"""Async delete thread."""
with self.get_saver() as saver:
return await saver.adelete_thread(thread_id)
@property
def config_specs(self):
"""Get config specs."""
with self.get_saver() as saver:
return saver.config_specs
@property
def serde(self):
"""Get serde."""
with self.get_saver() as saver:
return saver.serde
class PostgreSQLMemoryManager:
"""
PostgreSQL-based memory manager using LangGraph's built-in components.
Falls back to in-memory storage if PostgreSQL is not available.
"""
def __init__(self):
self.config = get_config()
self.pg_config = self.config.postgresql
self._checkpointer: Optional[Any] = None
self._postgres_available = POSTGRES_AVAILABLE
def _get_connection_string(self) -> str:
"""Get PostgreSQL connection string."""
if not self._postgres_available:
return ""
# URL encode password to handle special characters
encoded_password = quote_plus(self.pg_config.password)
return (
f"postgresql://{self.pg_config.username}:{encoded_password}@"
f"{self.pg_config.host}:{self.pg_config.port}/{self.pg_config.database}"
)
def _test_connection(self) -> bool:
"""Test PostgreSQL connection."""
if not self._postgres_available:
return False
if not PSYCOPG_AVAILABLE or psycopg is None:
return False
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string) as conn:
with conn.cursor() as cur:
cur.execute("SELECT 1")
result = cur.fetchone()
logger.info("PostgreSQL connection test successful")
return True
except Exception as e:
logger.error(f"PostgreSQL connection test failed: {e}")
return False
def _setup_ttl_cleanup(self):
"""Setup TTL cleanup for old records."""
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
return
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string, autocommit=True) as conn:
with conn.cursor() as cur:
# Create a function to clean up old records for LangGraph tables
# Note: LangGraph tables don't have created_at, so we'll use a different approach
cleanup_sql = f"""
CREATE OR REPLACE FUNCTION cleanup_old_checkpoints()
RETURNS void AS $$
BEGIN
-- LangGraph tables don't have created_at columns
-- We can clean based on checkpoint_id pattern or use a different strategy
-- For now, just return successfully without actual cleanup
-- You can implement custom logic based on your requirements
RAISE NOTICE 'Cleanup function called - custom cleanup logic needed';
END;
$$ LANGUAGE plpgsql;
"""
cur.execute(cleanup_sql)
logger.info(f"TTL cleanup function created with {self.pg_config.ttl_days}-day retention")
except Exception as e:
logger.warning(f"Failed to setup TTL cleanup (this is optional): {e}")
def cleanup_old_data(self):
"""Manually trigger cleanup of old data."""
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
return
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string, autocommit=True) as conn:
with conn.cursor() as cur:
cur.execute("SELECT cleanup_old_checkpoints()")
logger.info("Manual cleanup of old data completed")
except Exception as e:
logger.error(f"Failed to cleanup old data: {e}")
def get_checkpointer(self):
"""Get checkpointer for conversation history (PostgreSQL if available, else in-memory)."""
if self._checkpointer is None:
if self._postgres_available:
try:
# Test connection first
if not self._test_connection():
raise Exception("PostgreSQL connection test failed")
# Setup TTL cleanup function
self._setup_ttl_cleanup()
# Create checkpointer wrapper
conn_string = self._get_connection_string()
if LANGGRAPH_POSTGRES_AVAILABLE:
self._checkpointer = PostgreSQLCheckpointerWrapper(conn_string)
else:
raise Exception("LangGraph PostgreSQL checkpoint not available")
logger.info(f"PostgreSQL checkpointer initialized with {self.pg_config.ttl_days}-day TTL")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL checkpointer, falling back to in-memory: {e}")
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
self._checkpointer = InMemorySaver()
else:
logger.error("InMemorySaver not available - no checkpointer available")
self._checkpointer = None
else:
logger.info("PostgreSQL not available, using in-memory checkpointer")
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
self._checkpointer = InMemorySaver()
else:
logger.error("InMemorySaver not available - no checkpointer available")
self._checkpointer = None
return self._checkpointer
def test_connection(self) -> bool:
"""Test PostgreSQL connection and return True if successful."""
return self._test_connection()
# Global memory manager instance
_memory_manager: Optional[PostgreSQLMemoryManager] = None
def get_memory_manager() -> PostgreSQLMemoryManager:
"""Get global PostgreSQL memory manager instance."""
global _memory_manager
if _memory_manager is None:
_memory_manager = PostgreSQLMemoryManager()
return _memory_manager
def get_checkpointer():
"""Get checkpointer for conversation history."""
return get_memory_manager().get_checkpointer()

View File

@@ -0,0 +1,137 @@
"""
Redis-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
"""
import logging
import ssl
from typing import Dict, Any, Optional
try:
import redis
from redis.exceptions import ConnectionError, TimeoutError
from langgraph.checkpoint.redis import RedisSaver
REDIS_AVAILABLE = True
except ImportError as e:
logging.warning(f"Redis packages not available: {e}")
REDIS_AVAILABLE = False
redis = None
RedisSaver = None
from langgraph.checkpoint.memory import InMemorySaver
from ..config import get_config
logger = logging.getLogger(__name__)
class RedisMemoryManager:
"""
Redis-based memory manager using LangGraph's built-in components.
Falls back to in-memory storage if Redis is not available.
"""
def __init__(self):
self.config = get_config()
self.redis_config = self.config.redis
self._checkpointer: Optional[Any] = None
self._redis_available = REDIS_AVAILABLE
def _get_redis_client_kwargs(self) -> Dict[str, Any]:
"""Get Redis client configuration for Azure Redis compatibility."""
if not self._redis_available:
return {}
kwargs = {
"host": self.redis_config.host,
"port": self.redis_config.port,
"password": self.redis_config.password,
"db": self.redis_config.db,
"decode_responses": False, # Required for RedisSaver
"socket_timeout": 30,
"socket_connect_timeout": 10,
"retry_on_timeout": True,
"health_check_interval": 30,
}
if self.redis_config.use_ssl:
kwargs.update({
"ssl": True,
"ssl_cert_reqs": ssl.CERT_REQUIRED,
"ssl_check_hostname": True,
})
return kwargs
def _get_ttl_config(self) -> Dict[str, Any]:
"""Get TTL configuration for automatic cleanup."""
ttl_days = self.redis_config.ttl_days
ttl_minutes = ttl_days * 24 * 60 # Convert days to minutes
return {
"default_ttl": ttl_minutes,
"refresh_on_read": True, # Refresh TTL when accessed
}
def get_checkpointer(self):
"""Get checkpointer for conversation history (Redis if available, else in-memory)."""
if self._checkpointer is None:
if self._redis_available:
try:
ttl_config = self._get_ttl_config()
# Create Redis client with proper configuration for Azure Redis
redis_client = redis.Redis(**self._get_redis_client_kwargs())
# Test connection
redis_client.ping()
logger.info("Redis connection established successfully")
# Create checkpointer with TTL support
self._checkpointer = RedisSaver(
redis_client=redis_client,
ttl=ttl_config
)
# Initialize indices (required for first-time setup)
self._checkpointer.setup()
logger.info(f"Redis checkpointer initialized with {self.redis_config.ttl_days}-day TTL")
except Exception as e:
logger.error(f"Failed to initialize Redis checkpointer, falling back to in-memory: {e}")
self._checkpointer = InMemorySaver()
else:
logger.info("Redis not available, using in-memory checkpointer")
self._checkpointer = InMemorySaver()
return self._checkpointer
def test_connection(self) -> bool:
"""Test Redis connection and return True if successful."""
if not self._redis_available:
logger.warning("Redis packages not available")
return False
try:
redis_client = redis.Redis(**self._get_redis_client_kwargs())
redis_client.ping()
logger.info("Redis connection test successful")
return True
except Exception as e:
logger.error(f"Redis connection test failed: {e}")
return False
# Global memory manager instance
_memory_manager: Optional[RedisMemoryManager] = None
def get_memory_manager() -> RedisMemoryManager:
"""Get global Redis memory manager instance."""
global _memory_manager
if _memory_manager is None:
_memory_manager = RedisMemoryManager()
return _memory_manager
def get_checkpointer():
"""Get checkpointer for conversation history."""
return get_memory_manager().get_checkpointer()

View File

@@ -0,0 +1,113 @@
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import logging
from .postgresql_memory import get_memory_manager, get_checkpointer
from ..graph.state import TurnState, Message
logger = logging.getLogger(__name__)
class InMemoryStore:
"""Simple in-memory store with TTL for conversation history"""
def __init__(self, ttl_days: float = 7.0):
self.ttl_days = ttl_days
self.store: Dict[str, Dict[str, Any]] = {}
def _is_expired(self, timestamp: datetime) -> bool:
"""Check if a record has expired"""
return datetime.now() - timestamp > timedelta(days=self.ttl_days)
def _cleanup_expired(self) -> None:
"""Remove expired records"""
expired_keys = []
for session_id, data in self.store.items():
if self._is_expired(data.get("last_updated", datetime.min)):
expired_keys.append(session_id)
for key in expired_keys:
del self.store[key]
logger.info(f"Cleaned up expired session: {key}")
def get(self, session_id: str) -> Optional[TurnState]:
"""Get conversation state for a session"""
self._cleanup_expired()
if session_id not in self.store:
return None
data = self.store[session_id]
if self._is_expired(data.get("last_updated", datetime.min)):
del self.store[session_id]
return None
try:
# Reconstruct TurnState from stored data
state_data = data["state"]
return TurnState(**state_data)
except Exception as e:
logger.error(f"Failed to deserialize state for session {session_id}: {e}")
return None
def put(self, session_id: str, state: TurnState) -> None:
"""Store conversation state for a session"""
try:
self.store[session_id] = {
"state": state.model_dump(),
"last_updated": datetime.now()
}
logger.debug(f"Stored state for session: {session_id}")
except Exception as e:
logger.error(f"Failed to store state for session {session_id}: {e}")
def trim(self, session_id: str, max_messages: int = 20) -> None:
"""Trim old messages to stay within token limits"""
state = self.get(session_id)
if not state:
return
if len(state.messages) > max_messages:
# Keep system message (if any) and recent user/assistant pairs
trimmed_messages = state.messages[-max_messages:]
# Try to preserve complete conversation turns
if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
trimmed_messages = trimmed_messages[1:]
state.messages = trimmed_messages
self.put(session_id, state)
logger.info(f"Trimmed messages for session {session_id} to {len(trimmed_messages)}")
def create_new_session(self, session_id: str) -> TurnState:
"""Create a new conversation session"""
state = TurnState(session_id=session_id)
self.put(session_id, state)
return state
def add_message(self, session_id: str, message: Message) -> None:
"""Add a message to the conversation history"""
state = self.get(session_id)
if not state:
state = self.create_new_session(session_id)
state.messages.append(message)
self.put(session_id, state)
def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
"""Get formatted conversation history for prompts"""
state = self.get(session_id)
if not state or not state.messages:
return ""
# Get recent messages, keeping complete turns
recent_messages = state.messages[-(max_turns * 2):]
history_parts = []
for msg in recent_messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content}")
elif msg.role == "assistant" and not msg.tool_call_id:
history_parts.append(f"Assistant: {msg.content}")
return "\n".join(history_parts)