333 lines
13 KiB
Python
333 lines
13 KiB
Python
"""
|
|
PostgreSQL-based memory implementation using LangGraph built-in components.
|
|
Provides session-level chat history with 7-day TTL.
|
|
Uses psycopg3 for better compatibility without requiring libpq-dev.
|
|
"""
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
from urllib.parse import quote_plus
|
|
from contextlib import contextmanager
|
|
|
|
try:
|
|
import psycopg
|
|
from psycopg.rows import dict_row
|
|
PSYCOPG_AVAILABLE = True
|
|
except ImportError as e:
|
|
logging.warning(f"psycopg3 not available: {e}")
|
|
PSYCOPG_AVAILABLE = False
|
|
psycopg = None
|
|
|
|
try:
|
|
from langgraph.checkpoint.postgres import PostgresSaver
|
|
LANGGRAPH_POSTGRES_AVAILABLE = True
|
|
except ImportError as e:
|
|
logging.warning(f"LangGraph PostgreSQL checkpoint not available: {e}")
|
|
LANGGRAPH_POSTGRES_AVAILABLE = False
|
|
PostgresSaver = None
|
|
|
|
try:
|
|
from langgraph.checkpoint.memory import InMemorySaver
|
|
LANGGRAPH_MEMORY_AVAILABLE = True
|
|
except ImportError as e:
|
|
logging.warning(f"LangGraph memory checkpoint not available: {e}")
|
|
LANGGRAPH_MEMORY_AVAILABLE = False
|
|
InMemorySaver = None
|
|
|
|
from ..config import get_config
|
|
|
|
logger = logging.getLogger(__name__)

# PostgreSQL persistence needs both the psycopg3 driver and LangGraph's
# PostgresSaver; when either import above failed, the manager below falls
# back to an in-memory checkpointer.
POSTGRES_AVAILABLE = all((PSYCOPG_AVAILABLE, LANGGRAPH_POSTGRES_AVAILABLE))
|
|
|
|
|
|
class PostgreSQLCheckpointerWrapper:
    """
    Wrapper for PostgresSaver that manages the context properly.

    Every public method opens a fresh ``PostgresSaver`` through
    ``PostgresSaver.from_conn_string`` and leaves that ``with`` block before
    returning, so no saver is kept alive between calls. The schema is created
    with ``saver.setup()`` the first time a saver is requested.

    Async methods first await the saver's own coroutine. If that raises
    ``NotImplementedError``, they run the matching synchronous method in the
    running event loop's default executor instead.
    """

    def __init__(self, conn_string: str):
        """
        Args:
            conn_string: PostgreSQL connection string handed to
                ``PostgresSaver.from_conn_string``.

        Raises:
            RuntimeError: if LangGraph's PostgresSaver could not be imported.
        """
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        self.conn_string = conn_string
        # Becomes True once saver.setup() has succeeded; until then each
        # saver request retries the setup.
        self._initialized = False

    def _ensure_setup(self):
        """Ensure the database schema is set up.

        Runs ``saver.setup()`` on a dedicated saver the first time it is called.
        Later calls do nothing. If setup raises, ``_initialized`` stays False.
        """
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        if not self._initialized:
            with PostgresSaver.from_conn_string(self.conn_string) as saver:
                saver.setup()
                self._initialized = True
                logger.info("PostgreSQL schema initialized")

    @contextmanager
    def get_saver(self):
        """Yield a PostgresSaver, making sure the schema exists first.

        Raises:
            RuntimeError: if LangGraph's PostgresSaver could not be imported.
        """
        if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
            raise RuntimeError("PostgresSaver not available")
        self._ensure_setup()
        with PostgresSaver.from_conn_string(self.conn_string) as saver:
            yield saver

    # ------------------------------------------------------------------ sync

    def list(self, config, *, filter=None, before=None, limit=None):
        """List checkpoints.

        The keyword arguments mirror ``BaseCheckpointSaver.list`` and are
        forwarded unchanged. Results are collected before the saver closes,
        so a ``list`` is returned rather than a lazy iterator.
        """
        with self.get_saver() as saver:
            return list(saver.list(config, filter=filter, before=before, limit=limit))

    def get(self, config):
        """Get a checkpoint."""
        with self.get_saver() as saver:
            return saver.get(config)

    def get_tuple(self, config):
        """Get a checkpoint tuple."""
        with self.get_saver() as saver:
            return saver.get_tuple(config)

    def put(self, config, checkpoint, metadata, new_versions):
        """Put a checkpoint."""
        with self.get_saver() as saver:
            return saver.put(config, checkpoint, metadata, new_versions)

    def put_writes(self, config, writes, task_id):
        """Put writes."""
        with self.get_saver() as saver:
            return saver.put_writes(config, writes, task_id)

    def get_next_version(self, current, channel):
        """Get next version by delegating to ``saver.get_next_version``."""
        with self.get_saver() as saver:
            return saver.get_next_version(current, channel)

    def delete_thread(self, thread_id):
        """Delete thread."""
        with self.get_saver() as saver:
            return saver.delete_thread(thread_id)

    # ----------------------------------------------------------------- async

    async def _call_async(self, async_name, sync_name, *args, **kwargs):
        """Await ``saver.<async_name>(*args, **kwargs)``, with a sync fallback.

        If the async method raises ``NotImplementedError``, this runs
        ``saver.<sync_name>(*args, **kwargs)`` in the running loop's default
        executor. The saver stays open until whichever call was used finishes.
        """
        import asyncio

        with self.get_saver() as saver:
            try:
                return await getattr(saver, async_name)(*args, **kwargs)
            except NotImplementedError:
                sync_method = getattr(saver, sync_name)
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(
                    None, lambda: sync_method(*args, **kwargs)
                )

    async def alist(self, config, *, filter=None, before=None, limit=None):
        """Async list checkpoints.

        Tries ``saver.alist``. If it raises ``NotImplementedError``, the sync
        ``saver.list`` runs in a worker thread instead. All items are collected
        while the saver is open and only then yielded, like :meth:`list`.
        """
        import asyncio

        with self.get_saver() as saver:
            try:
                items = [
                    item
                    async for item in saver.alist(
                        config, filter=filter, before=before, limit=limit
                    )
                ]
            except NotImplementedError:
                loop = asyncio.get_running_loop()
                items = await loop.run_in_executor(
                    None,
                    lambda: list(
                        saver.list(config, filter=filter, before=before, limit=limit)
                    ),
                )
        for item in items:
            yield item

    async def aget(self, config):
        """Async get a checkpoint (falls back to ``saver.get``)."""
        return await self._call_async("aget", "get", config)

    async def aget_tuple(self, config):
        """Async get a checkpoint tuple (falls back to ``saver.get_tuple``)."""
        return await self._call_async("aget_tuple", "get_tuple", config)

    async def aput(self, config, checkpoint, metadata, new_versions):
        """Async put a checkpoint (falls back to ``saver.put``)."""
        return await self._call_async(
            "aput", "put", config, checkpoint, metadata, new_versions
        )

    async def aput_writes(self, config, writes, task_id):
        """Async put writes (falls back to ``saver.put_writes``)."""
        return await self._call_async(
            "aput_writes", "put_writes", config, writes, task_id
        )

    async def adelete_thread(self, thread_id):
        """Async delete thread (falls back to ``saver.delete_thread``)."""
        return await self._call_async("adelete_thread", "delete_thread", thread_id)

    # ------------------------------------------------------------ properties

    @property
    def config_specs(self):
        """Get config specs from a freshly opened saver."""
        with self.get_saver() as saver:
            return saver.config_specs

    @property
    def serde(self):
        """Get serde from a freshly opened saver."""
        with self.get_saver() as saver:
            return saver.serde
|
|
|
|
|
|
class PostgreSQLMemoryManager:
    """
    PostgreSQL-based memory manager using LangGraph's built-in components.

    Falls back to in-memory storage if PostgreSQL is not available.

    :meth:`get_checkpointer` builds the checkpointer lazily and caches it. If
    PostgreSQL cannot be used at that point, an ``InMemorySaver`` (when
    importable) is cached instead.
    """

    def __init__(self):
        self.config = get_config()
        # Read by this class: username, password, host, port, database, ttl_days.
        self.pg_config = self.config.postgresql
        # Cached checkpointer: PostgreSQLCheckpointerWrapper, InMemorySaver or None.
        self._checkpointer: Optional[Any] = None
        self._postgres_available = POSTGRES_AVAILABLE

    def _get_connection_string(self) -> str:
        """Get PostgreSQL connection string.

        Builds a ``postgresql://`` URI from ``self.pg_config``. The user name,
        password and database name are percent-encoded with
        ``urllib.parse.quote(..., safe="")``.

        ``quote_plus`` must not be used here: it encodes a space as ``+``, and
        libpq URIs are only percent-decoded, so the password would change.
        The ``:password`` part is omitted when no password is configured, so
        libpq can use its own password sources (e.g. PGPASSWORD or ~/.pgpass).

        Returns:
            The URI, or ``""`` when PostgreSQL support is unavailable.
        """
        from urllib.parse import quote

        if not self._postgres_available:
            return ""

        cfg = self.pg_config
        userinfo = quote(str(cfg.username), safe="")
        if cfg.password:
            userinfo += ":" + quote(str(cfg.password), safe="")
        database = quote(str(cfg.database), safe="")

        return f"postgresql://{userinfo}@{cfg.host}:{cfg.port}/{database}"

    def _test_connection(self) -> bool:
        """Test PostgreSQL connection by running ``SELECT 1``.

        Returns:
            True on success. False if PostgreSQL support is unavailable or
            any error occurs; the error is logged.
        """
        if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
            return False

        try:
            with psycopg.connect(self._get_connection_string()) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
                    cur.fetchone()
            logger.info("PostgreSQL connection test successful")
            return True
        except Exception as e:
            logger.error(f"PostgreSQL connection test failed: {e}")
            return False

    def _setup_ttl_cleanup(self):
        """Create the ``cleanup_old_checkpoints()`` SQL function (best effort).

        NOTE: the function is currently a placeholder. LangGraph's checkpoint
        tables have no ``created_at`` column, so it only raises a NOTICE and
        deletes nothing. ``pg_config.ttl_days`` is therefore not enforced.
        Failures are logged as warnings and otherwise ignored.
        """
        if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
            return

        cleanup_sql = """
        CREATE OR REPLACE FUNCTION cleanup_old_checkpoints()
        RETURNS void AS $$
        BEGIN
            -- LangGraph tables don't have created_at columns
            -- We can clean based on checkpoint_id pattern or use a different strategy
            -- For now, just return successfully without actual cleanup
            -- You can implement custom logic based on your requirements
            RAISE NOTICE 'Cleanup function called - custom cleanup logic needed';
        END;
        $$ LANGUAGE plpgsql;
        """

        try:
            conn_string = self._get_connection_string()
            with psycopg.connect(conn_string, autocommit=True) as conn:
                with conn.cursor() as cur:
                    cur.execute(cleanup_sql)
            logger.info(
                "TTL cleanup placeholder function created; configured "
                "%s-day retention is not enforced yet",
                self.pg_config.ttl_days,
            )
        except Exception as e:
            logger.warning(f"Failed to setup TTL cleanup (this is optional): {e}")

    def cleanup_old_data(self):
        """Manually trigger cleanup of old data.

        Calls ``cleanup_old_checkpoints()`` in the database. With the definition
        installed by ``_setup_ttl_cleanup``, this deletes nothing. Errors are
        logged, not raised.
        """
        if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
            return

        try:
            conn_string = self._get_connection_string()
            with psycopg.connect(conn_string, autocommit=True) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT cleanup_old_checkpoints()")
            logger.info("Manual cleanup of old data completed")
        except Exception as e:
            logger.error(f"Failed to cleanup old data: {e}")

    def _create_in_memory_checkpointer(self):
        """Return a new InMemorySaver, or None (logged) if it cannot be imported."""
        if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
            return InMemorySaver()
        logger.error("InMemorySaver not available - no checkpointer available")
        return None

    def _create_postgres_checkpointer(self):
        """Build the PostgreSQL checkpointer.

        Raises:
            RuntimeError: if the connection test fails or LangGraph's Postgres
                checkpoint support is missing.
        """
        if not self._test_connection():
            raise RuntimeError("PostgreSQL connection test failed")

        # Optional: installs the (currently placeholder) cleanup function.
        self._setup_ttl_cleanup()

        conn_string = self._get_connection_string()
        if not LANGGRAPH_POSTGRES_AVAILABLE:
            raise RuntimeError("LangGraph PostgreSQL checkpoint not available")
        checkpointer = PostgreSQLCheckpointerWrapper(conn_string)

        logger.info(
            "PostgreSQL checkpointer initialized (configured TTL: %s days, "
            "not enforced yet)",
            self.pg_config.ttl_days,
        )
        return checkpointer

    def get_checkpointer(self):
        """Get checkpointer for conversation history (PostgreSQL if available, else in-memory).

        The result is cached on first success. Returns None only when neither
        backend can be created; in that case the next call tries again.
        """
        if self._checkpointer is None:
            if self._postgres_available:
                try:
                    self._checkpointer = self._create_postgres_checkpointer()
                except Exception as e:
                    logger.error(
                        f"Failed to initialize PostgreSQL checkpointer, falling back to in-memory: {e}"
                    )
                    self._checkpointer = self._create_in_memory_checkpointer()
            else:
                logger.info("PostgreSQL not available, using in-memory checkpointer")
                self._checkpointer = self._create_in_memory_checkpointer()

        return self._checkpointer

    def test_connection(self) -> bool:
        """Test PostgreSQL connection and return True if successful."""
        return self._test_connection()
|
|
|
|
|
|
# Process-wide memory manager; created lazily by get_memory_manager().
_memory_manager: Optional[PostgreSQLMemoryManager] = None


def get_memory_manager() -> PostgreSQLMemoryManager:
    """Return the shared PostgreSQLMemoryManager, constructing it on the first call."""
    global _memory_manager
    if _memory_manager is not None:
        return _memory_manager
    _memory_manager = PostgreSQLMemoryManager()
    return _memory_manager
|
|
|
|
|
|
def get_checkpointer():
    """Return the conversation-history checkpointer of the shared memory manager.

    Delegates to ``PostgreSQLMemoryManager.get_checkpointer``. That returns a
    PostgreSQL-backed checkpointer when available, otherwise an in-memory
    one (or None).
    """
    manager = get_memory_manager()
    return manager.get_checkpointer()
|