Files
catonline_ai/vw-agentic-rag/service/memory/postgresql_memory.py
2025-09-26 17:15:54 +08:00

333 lines
13 KiB
Python

"""
PostgreSQL-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
Uses psycopg3 for better compatibility without requiring libpq-dev.
"""
import asyncio
import functools
import logging
from contextlib import contextmanager
from typing import Dict, Any, Optional
from urllib.parse import quote, quote_plus
logger = logging.getLogger(__name__)

# Optional third-party dependencies. Each probe sets a *_AVAILABLE flag and,
# on failure, binds the imported name to None so the rest of the module can
# degrade gracefully. Warnings go through the module logger: the module-level
# logging.warning() helper would call logging.basicConfig() on an
# unconfigured root logger, configuring global logging as an import side effect.
try:
    import psycopg
    from psycopg.rows import dict_row
    PSYCOPG_AVAILABLE = True
except ImportError as e:
    logger.warning("psycopg3 not available: %s", e)
    PSYCOPG_AVAILABLE = False
    psycopg = None

try:
    from langgraph.checkpoint.postgres import PostgresSaver
    LANGGRAPH_POSTGRES_AVAILABLE = True
except ImportError as e:
    logger.warning("LangGraph PostgreSQL checkpoint not available: %s", e)
    LANGGRAPH_POSTGRES_AVAILABLE = False
    PostgresSaver = None

try:
    from langgraph.checkpoint.memory import InMemorySaver
    LANGGRAPH_MEMORY_AVAILABLE = True
except ImportError as e:
    logger.warning("LangGraph memory checkpoint not available: %s", e)
    LANGGRAPH_MEMORY_AVAILABLE = False
    InMemorySaver = None

from ..config import get_config

# True only when both psycopg3 and the LangGraph Postgres saver imported.
POSTGRES_AVAILABLE = PSYCOPG_AVAILABLE and LANGGRAPH_POSTGRES_AVAILABLE
class PostgreSQLCheckpointerWrapper:
    """
    Wrapper for PostgresSaver that manages the context properly.

    Each public method opens a fresh ``PostgresSaver`` through
    ``PostgresSaver.from_conn_string`` for the duration of that single call,
    so no saver is held between calls. The schema is created once per wrapper
    instance via ``PostgresSaver.setup()`` on first use.

    Async methods try the saver's native ``a*`` coroutine first; when it
    raises ``NotImplementedError`` the matching sync method is run in the
    default thread-pool executor instead.
    """

    def __init__(self, conn_string: str):
        """
        Args:
            conn_string: PostgreSQL connection URI handed to
                ``PostgresSaver.from_conn_string``.

        Raises:
            RuntimeError: If ``PostgresSaver`` could not be imported.
        """
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        self.conn_string = conn_string
        self._initialized = False
        # Saver attributes cached on first access so later reads do not have
        # to open a database connection just to fetch them.
        self._config_specs = None
        self._serde = None

    def _ensure_setup(self):
        """Run ``PostgresSaver.setup()`` once; retried on the next call if it fails."""
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        if not self._initialized:
            with PostgresSaver.from_conn_string(self.conn_string) as saver:
                saver.setup()
                self._initialized = True
                logger.info("PostgreSQL schema initialized")

    @contextmanager
    def get_saver(self):
        """
        Yield a PostgresSaver opened from ``self.conn_string``.

        The saver's own context is exited when the caller's ``with`` block
        ends, so results must be fully consumed inside that block.

        Raises:
            RuntimeError: If ``PostgresSaver`` could not be imported.
        """
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        self._ensure_setup()
        with PostgresSaver.from_conn_string(self.conn_string) as saver:
            yield saver

    # ------------------------------------------------------------------
    # Sync API
    # ------------------------------------------------------------------

    def list(self, config, **kwargs):
        """
        List checkpoints matching ``config`` as a fully materialised list.

        Extra keyword arguments (e.g. ``filter``, ``before``, ``limit``) are
        forwarded unchanged to ``PostgresSaver.list``.
        """
        with self.get_saver() as saver:
            # Materialise before get_saver() exits and the saver is released.
            return list(saver.list(config, **kwargs))

    def get(self, config):
        """Return the checkpoint for ``config`` (delegates to ``PostgresSaver.get``)."""
        with self.get_saver() as saver:
            return saver.get(config)

    def get_tuple(self, config):
        """Return the checkpoint tuple for ``config`` (delegates to ``PostgresSaver.get_tuple``)."""
        with self.get_saver() as saver:
            return saver.get_tuple(config)

    def put(self, config, checkpoint, metadata, new_versions):
        """Store a checkpoint (delegates to ``PostgresSaver.put``)."""
        with self.get_saver() as saver:
            return saver.put(config, checkpoint, metadata, new_versions)

    def put_writes(self, config, writes, task_id):
        """Store pending writes for ``task_id`` (delegates to ``PostgresSaver.put_writes``)."""
        with self.get_saver() as saver:
            return saver.put_writes(config, writes, task_id)

    def get_next_version(self, current, channel):
        """Return the next channel version (delegates to ``PostgresSaver.get_next_version``)."""
        # NOTE(review): this opens a saver (and a database connection) only to
        # compute a version string; revisit if it shows up in profiles.
        with self.get_saver() as saver:
            return saver.get_next_version(current, channel)

    def delete_thread(self, thread_id):
        """Delete all checkpoints of ``thread_id`` (delegates to ``PostgresSaver.delete_thread``)."""
        with self.get_saver() as saver:
            return saver.delete_thread(thread_id)

    # ------------------------------------------------------------------
    # Async API
    # ------------------------------------------------------------------

    @staticmethod
    async def _await_or_run_sync(async_fn, sync_fn, *args, **kwargs):
        """
        Await ``async_fn(*args, **kwargs)``; if it raises NotImplementedError,
        run ``sync_fn(*args, **kwargs)`` in the default executor instead.
        """
        try:
            return await async_fn(*args, **kwargs)
        except NotImplementedError:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, functools.partial(sync_fn, *args, **kwargs)
            )

    @staticmethod
    async def _aiter_or_run_sync(async_gen_fn, sync_fn, *args, **kwargs):
        """
        Yield items from ``async_gen_fn(*args, **kwargs)``.

        If producing the first item raises NotImplementedError, the sync
        ``sync_fn(*args, **kwargs)`` iterable is materialised in the default
        executor and its items are yielded instead. An error raised after
        items were already produced is propagated unchanged, so no item is
        yielded twice.
        """
        try:
            iterator = async_gen_fn(*args, **kwargs).__aiter__()
            first = await iterator.__anext__()
        except StopAsyncIteration:
            return
        except NotImplementedError:
            loop = asyncio.get_running_loop()
            items = await loop.run_in_executor(
                None, lambda: list(sync_fn(*args, **kwargs))
            )
            for item in items:
                yield item
            return
        yield first
        async for item in iterator:
            yield item

    async def alist(self, config, **kwargs):
        """Async-iterate checkpoints for ``config``; kwargs are forwarded like in ``list``."""
        with self.get_saver() as saver:
            async for item in self._aiter_or_run_sync(
                saver.alist, saver.list, config, **kwargs
            ):
                yield item

    async def aget(self, config):
        """Async ``get``; uses the sync method in a thread if the saver has no async version."""
        with self.get_saver() as saver:
            return await self._await_or_run_sync(saver.aget, saver.get, config)

    async def aget_tuple(self, config):
        """Async ``get_tuple``; uses the sync method in a thread if the saver has no async version."""
        with self.get_saver() as saver:
            return await self._await_or_run_sync(
                saver.aget_tuple, saver.get_tuple, config
            )

    async def aput(self, config, checkpoint, metadata, new_versions):
        """Async ``put``; uses the sync method in a thread if the saver has no async version."""
        with self.get_saver() as saver:
            return await self._await_or_run_sync(
                saver.aput, saver.put, config, checkpoint, metadata, new_versions
            )

    async def aput_writes(self, config, writes, task_id):
        """Async ``put_writes``; uses the sync method in a thread if the saver has no async version."""
        with self.get_saver() as saver:
            return await self._await_or_run_sync(
                saver.aput_writes, saver.put_writes, config, writes, task_id
            )

    async def adelete_thread(self, thread_id):
        """Async ``delete_thread``; uses the sync method in a thread if the saver has no async version."""
        with self.get_saver() as saver:
            return await self._await_or_run_sync(
                saver.adelete_thread, saver.delete_thread, thread_id
            )

    # ------------------------------------------------------------------
    # Saver attributes
    # ------------------------------------------------------------------

    @property
    def config_specs(self):
        """Config specs reported by PostgresSaver (read once, then cached)."""
        if self._config_specs is None:
            with self.get_saver() as saver:
                self._config_specs = saver.config_specs
        return self._config_specs

    @property
    def serde(self):
        """Serializer reported by PostgresSaver (read once, then cached)."""
        if self._serde is None:
            with self.get_saver() as saver:
                self._serde = saver.serde
        return self._serde
class PostgreSQLMemoryManager:
    """
    PostgreSQL-based memory manager using LangGraph's built-in components.
    Falls back to in-memory storage if PostgreSQL is not available.

    Connection settings come from ``get_config().postgresql`` (``username``,
    ``password``, ``host``, ``port``, ``database``, ``ttl_days``).
    """

    def __init__(self):
        self.config = get_config()
        self.pg_config = self.config.postgresql
        # Created lazily by get_checkpointer(); None until then.
        self._checkpointer: Optional[Any] = None
        # True only if psycopg3 and the LangGraph Postgres saver both imported.
        self._postgres_available = POSTGRES_AVAILABLE

    def _get_connection_string(self) -> str:
        """
        Build a ``postgresql://`` URI from the configured settings.

        Returns:
            The connection URI, or ``""`` when PostgreSQL support is unavailable.
        """
        if not self._postgres_available:
            return ""
        # Percent-encode every reserved character ('@', ':', '/', spaces, ...)
        # in the user, password and database name so they cannot break the URI.
        # quote_plus() is deliberately not used: it encodes a space as '+',
        # which libpq does not decode back into a space inside a URI.
        username = quote(self.pg_config.username, safe="")
        password = quote(self.pg_config.password, safe="")
        database = quote(self.pg_config.database, safe="")
        return (
            f"postgresql://{username}:{password}@"
            f"{self.pg_config.host}:{self.pg_config.port}/{database}"
        )

    def _test_connection(self) -> bool:
        """
        Probe PostgreSQL by running ``SELECT 1``.

        Returns:
            True on success; False if PostgreSQL support is unavailable or the
            probe raised (the error is logged).
        """
        if not self._postgres_available:
            return False
        if not PSYCOPG_AVAILABLE or psycopg is None:
            return False
        try:
            conn_string = self._get_connection_string()
            with psycopg.connect(conn_string) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
                    cur.fetchone()
            logger.info("PostgreSQL connection test successful")
            return True
        except Exception as e:
            logger.error("PostgreSQL connection test failed: %s", e)
            return False

    def _setup_ttl_cleanup(self):
        """
        Create or replace the ``cleanup_old_checkpoints()`` SQL function.

        Best-effort: any failure is logged as a warning and otherwise ignored.

        NOTE: the installed function is a placeholder. It only raises a NOTICE
        and deletes nothing, so ``ttl_days`` is not actually enforced yet.
        """
        if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
            return
        # LangGraph's checkpoint tables have no created_at column, so there is
        # no simple age filter; the body below is a stub awaiting real logic.
        cleanup_sql = """
        CREATE OR REPLACE FUNCTION cleanup_old_checkpoints()
        RETURNS void AS $$
        BEGIN
            -- LangGraph tables don't have created_at columns
            -- We can clean based on checkpoint_id pattern or use a different strategy
            -- For now, just return successfully without actual cleanup
            -- You can implement custom logic based on your requirements
            RAISE NOTICE 'Cleanup function called - custom cleanup logic needed';
        END;
        $$ LANGUAGE plpgsql;
        """
        try:
            conn_string = self._get_connection_string()
            with psycopg.connect(conn_string, autocommit=True) as conn:
                with conn.cursor() as cur:
                    cur.execute(cleanup_sql)
            logger.info(
                "TTL cleanup function created with %s-day retention",
                self.pg_config.ttl_days,
            )
        except Exception as e:
            logger.warning("Failed to setup TTL cleanup (this is optional): %s", e)

    def cleanup_old_data(self):
        """
        Call the ``cleanup_old_checkpoints()`` SQL function.

        Errors are logged, not raised. See ``_setup_ttl_cleanup`` for the
        current (placeholder) behaviour of that function.
        """
        if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
            return
        try:
            conn_string = self._get_connection_string()
            with psycopg.connect(conn_string, autocommit=True) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT cleanup_old_checkpoints()")
            logger.info("Manual cleanup of old data completed")
        except Exception as e:
            logger.error("Failed to cleanup old data: %s", e)

    def _create_in_memory_checkpointer(self):
        """Return a new InMemorySaver, or None (logging an error) if it is not importable."""
        if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
            return InMemorySaver()
        logger.error("InMemorySaver not available - no checkpointer available")
        return None

    def _create_postgres_checkpointer(self):
        """
        Probe the database, install the cleanup function and build the wrapper.

        Raises:
            Exception: If the connection probe fails or the LangGraph Postgres
                saver is not importable.
        """
        if not self._test_connection():
            raise Exception("PostgreSQL connection test failed")
        self._setup_ttl_cleanup()
        conn_string = self._get_connection_string()
        if not LANGGRAPH_POSTGRES_AVAILABLE:
            raise Exception("LangGraph PostgreSQL checkpoint not available")
        checkpointer = PostgreSQLCheckpointerWrapper(conn_string)
        logger.info(
            "PostgreSQL checkpointer initialized with %s-day TTL",
            self.pg_config.ttl_days,
        )
        return checkpointer

    def get_checkpointer(self):
        """
        Return the checkpointer for conversation history, creating it on first call.

        Uses PostgreSQL when available and reachable, otherwise an in-memory
        saver. Returns None when neither backend can be created. A None result
        is not cached, so creation is retried on the next call.
        """
        if self._checkpointer is not None:
            return self._checkpointer
        if self._postgres_available:
            try:
                self._checkpointer = self._create_postgres_checkpointer()
            except Exception as e:
                logger.error(
                    "Failed to initialize PostgreSQL checkpointer, falling back to in-memory: %s",
                    e,
                )
                self._checkpointer = self._create_in_memory_checkpointer()
        else:
            logger.info("PostgreSQL not available, using in-memory checkpointer")
            self._checkpointer = self._create_in_memory_checkpointer()
        return self._checkpointer

    def test_connection(self) -> bool:
        """Test PostgreSQL connection and return True if successful."""
        return self._test_connection()
# Global memory manager instance
# Process-wide singleton, created lazily by get_memory_manager().
_memory_manager: Optional[PostgreSQLMemoryManager] = None


def get_memory_manager() -> PostgreSQLMemoryManager:
    """
    Return the shared PostgreSQLMemoryManager, constructing it on first use.

    Returns:
        The same manager instance on every call after the first.
    """
    global _memory_manager
    if _memory_manager is not None:
        return _memory_manager
    _memory_manager = PostgreSQLMemoryManager()
    return _memory_manager
def get_checkpointer():
    """
    Return the conversation-history checkpointer of the shared memory manager.

    Convenience shortcut for ``get_memory_manager().get_checkpointer()``.
    """
    manager = get_memory_manager()
    return manager.get_checkpointer()