init
1
vw-agentic-rag/service/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make packages
146
vw-agentic-rag/service/ai_sdk_adapter.py
Normal file
@@ -0,0 +1,146 @@
"""
|
||||
AI SDK Data Stream Protocol adapter
|
||||
Converts our internal SSE events to AI SDK compatible format
|
||||
Following the official Data Stream Protocol: TYPE_ID:CONTENT_JSON\n
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
|
||||
|
||||
def format_data_stream_part(type_id: str, content: Any) -> str:
|
||||
"""Format data as AI SDK Data Stream Protocol part: TYPE_ID:JSON\n"""
|
||||
content_json = json.dumps(content, ensure_ascii=False)
|
||||
return f"{type_id}:{content_json}\n"
|
||||
|
||||
|
||||
def create_text_part(text: str) -> str:
|
||||
"""Create text part (type 0)"""
|
||||
return format_data_stream_part("0", text)
|
||||
|
||||
|
||||
def create_data_part(data: list) -> str:
|
||||
"""Create data part (type 2) for additional data"""
|
||||
return format_data_stream_part("2", data)
|
||||
|
||||
|
||||
def create_error_part(error: str) -> str:
|
||||
"""Create error part (type 3)"""
|
||||
return format_data_stream_part("3", error)
|
||||
|
||||
|
||||
def create_tool_call_part(tool_call_id: str, tool_name: str, args: dict) -> str:
|
||||
"""Create tool call part (type 9)"""
|
||||
return format_data_stream_part("9", {
|
||||
"toolCallId": tool_call_id,
|
||||
"toolName": tool_name,
|
||||
"args": args
|
||||
})
|
||||
|
||||
|
||||
def create_tool_result_part(tool_call_id: str, result: Any) -> str:
|
||||
"""Create tool result part (type a)"""
|
||||
return format_data_stream_part("a", {
|
||||
"toolCallId": tool_call_id,
|
||||
"result": result
|
||||
})
|
||||
|
||||
|
||||
def create_finish_step_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None, is_continued: bool = False) -> str:
|
||||
"""Create finish step part (type e)"""
|
||||
usage = usage or {"promptTokens": 0, "completionTokens": 0}
|
||||
return format_data_stream_part("e", {
|
||||
"finishReason": finish_reason,
|
||||
"usage": usage,
|
||||
"isContinued": is_continued
|
||||
})
|
||||
|
||||
|
||||
def create_finish_message_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None) -> str:
|
||||
"""Create finish message part (type d)"""
|
||||
usage = usage or {"promptTokens": 0, "completionTokens": 0}
|
||||
return format_data_stream_part("d", {
|
||||
"finishReason": finish_reason,
|
||||
"usage": usage
|
||||
})
|
||||
|
||||
|
||||
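# --- Illustrative only (not part of the module): what the builders above emit on the wire.
# Each part is a single "TYPE_ID:JSON" line terminated by a newline; the argument values
# below are made-up examples.
#
#   create_text_part("Hello")                  -> '0:"Hello"\n'
#   create_error_part("search failed")         -> '3:"search failed"\n'
#   create_tool_call_part("call_1", "search", {"query": "brake fluid"})
#       -> '9:{"toolCallId": "call_1", "toolName": "search", "args": {"query": "brake fluid"}}\n'
#   create_finish_message_part()               -> 'd:{"finishReason": "stop", "usage": {"promptTokens": 0, "completionTokens": 0}}\n'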
class AISDKEventAdapter:
    """Adapter to convert our internal events to AI SDK Data Stream Protocol format"""

    def __init__(self):
        self.tool_calls = {}  # Track tool calls
        self.current_message_id = str(uuid.uuid4())

    def convert_event(self, event_line: str) -> str | None:
        """Convert our SSE event to AI SDK Data Stream Protocol format"""
        if not event_line.strip():
            return None

        try:
            # Handle multi-line SSE format
            lines = event_line.strip().split('\n')
            event_type = None
            data = None

            for line in lines:
                if line.startswith("event: "):
                    event_type = line.replace("event: ", "")
                elif line.startswith("data: "):
                    data_str = line[6:]  # Remove "data: "
                    if data_str:
                        data = json.loads(data_str)

            if event_type and data:
                return self._convert_by_type(event_type, data)

        except (json.JSONDecodeError, IndexError, KeyError) as e:
            # Skip malformed events
            return None

        return None

    def _convert_by_type(self, event_type: str, data: Dict[str, Any]) -> str | None:
        """Convert event by type to Data Stream Protocol format"""

        if event_type == "tokens":
            # Token streaming -> text part (type 0)
            delta = data.get("delta", "")
            if delta:
                return create_text_part(delta)

        elif event_type == "tool_start":
            # Tool start -> tool call part (type 9)
            tool_id = data.get("id", str(uuid.uuid4()))
            tool_name = data.get("name", "unknown")
            args = data.get("args", {})
            self.tool_calls[tool_id] = {"name": tool_name, "args": args}
            return create_tool_call_part(tool_id, tool_name, args)

        elif event_type == "tool_result":
            # Tool result -> tool result part (type a)
            tool_id = data.get("id", "")
            results = data.get("results", [])
            return create_tool_result_part(tool_id, results)

        elif event_type == "tool_error":
            # Tool error -> error part (type 3)
            error = data.get("error", "Tool execution failed")
            return create_error_part(error)

        elif event_type == "error":
            # Error -> error part (type 3)
            error = data.get("error", "Unknown error")
            return create_error_part(error)

        return None


async def stream_ai_sdk_compatible(internal_stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Convert our internal SSE stream to AI SDK Data Stream Protocol compatible format"""
    adapter = AISDKEventAdapter()

    async for event in internal_stream:
        converted = adapter.convert_event(event)
        if converted:
            yield converted
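# --- Usage sketch (illustrative, not part of the module). The SSE strings are made-up
# examples of the internal event format convert_event parses: an "event:" line plus a
# "data:" line carrying JSON.
#
#   adapter = AISDKEventAdapter()
#   adapter.convert_event('event: tokens\ndata: {"delta": "Hel"}')           # -> '0:"Hel"\n'
#   adapter.convert_event('event: tool_error\ndata: {"error": "timeout"}')   # -> '3:"timeout"\n'
#
# stream_ai_sdk_compatible() applies the same conversion to every event of an async stream
# and silently drops events that do not map to a Data Stream Protocol part.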
121
vw-agentic-rag/service/ai_sdk_chat.py
Normal file
@@ -0,0 +1,121 @@
"""
AI SDK compatible chat endpoint
"""
import asyncio
import logging
from typing import AsyncGenerator

from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.messages import HumanMessage

from .config import get_config
from .graph.state import TurnState, Message
from .schemas.messages import ChatRequest
from .ai_sdk_adapter import stream_ai_sdk_compatible
from .sse import create_error_event

logger = logging.getLogger(__name__)


async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse:
    """Handle chat request with AI SDK Data Stream Protocol"""

    async def ai_sdk_stream() -> AsyncGenerator[str, None]:
        try:
            app_config = get_config()
            memory_manager = app_state.memory_manager
            graph = app_state.graph

            # Prepare the new user message for LangGraph (session memory handled automatically)
            graph_config = {
                "configurable": {
                    "thread_id": request.session_id
                }
            }

            # Get the latest user message from AI SDK format
            new_user_message = None
            if request.messages:
                last_message = request.messages[-1]
                if last_message.get("role") == "user":
                    new_user_message = HumanMessage(content=last_message.get("content", ""))

            if not new_user_message:
                logger.error("No user message found in request")
                yield create_error_event("No user message provided")
                return

            # Create event queue for internal streaming
            event_queue = asyncio.Queue()

            async def stream_callback(event_str: str):
                await event_queue.put(event_str)

            async def run_workflow():
                try:
                    # Set stream callback in context for the workflow
                    from .graph.graph import stream_callback_context
                    stream_callback_context.set(stream_callback)

                    # Create TurnState with the new user message
                    # AgenticWorkflow will handle LangGraph interaction and session history
                    from .graph.state import TurnState, Message

                    turn_state = TurnState(
                        messages=[Message(
                            role="user",
                            content=str(new_user_message.content),
                            timestamp=None
                        )],
                        session_id=request.session_id,
                        tool_results=[],
                        final_answer=""
                    )

                    # Use AgenticWorkflow.astream with stream_callback parameter
                    async for final_state in graph.astream(turn_state, stream_callback=stream_callback):
                        # The workflow handles all streaming internally via stream_callback
                        pass  # final_state contains the complete result
                    await event_queue.put(None)  # Signal completion
                except Exception as e:
                    logger.error(f"Workflow execution error: {e}", exc_info=True)
                    await event_queue.put(create_error_event(f"Processing error: {str(e)}"))
                    await event_queue.put(None)

            # Start workflow task
            workflow_task = asyncio.create_task(run_workflow())

            # Convert internal events to AI SDK format
            async def internal_stream():
                try:
                    while True:
                        event = await event_queue.get()
                        if event is None:
                            break
                        yield event
                finally:
                    if not workflow_task.done():
                        workflow_task.cancel()

            # Stream converted events
            async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()):
                yield ai_sdk_event

        except Exception as e:
            logger.error(f"AI SDK chat error: {e}")
            # Send error in AI SDK format
            from .ai_sdk_adapter import create_error_part
            yield create_error_part(f"Server error: {str(e)}")

    return StreamingResponse(
        ai_sdk_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "x-vercel-ai-data-stream": "v1",  # AI SDK Data Stream Protocol header
        }
    )
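# Wiring sketch (assumption, not part of this commit): how the handler above could be
# mounted in a FastAPI app. The route path and the fields expected on app.state
# (graph, memory_manager) are illustrative guesses.
from fastapi import FastAPI

example_app = FastAPI()

@example_app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    # app.state is assumed to expose the compiled AgenticWorkflow and the memory manager
    return await handle_ai_sdk_chat(request, example_app.state)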
297
vw-agentic-rag/service/config.py
Normal file
@@ -0,0 +1,297 @@
|
||||
import yaml
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OpenAIConfig(BaseModel):
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_key: str
|
||||
model: str = "gpt-4o"
|
||||
|
||||
|
||||
class AzureConfig(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
deployment: str
|
||||
api_version: str = "2024-02-01"
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
model: str
|
||||
dimension: int
|
||||
api_version: Optional[str] = None
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
|
||||
standard_regulation_index: str
|
||||
chunk_index: str
|
||||
chunk_user_manual_index: str
|
||||
|
||||
|
||||
class RetrievalConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: str
|
||||
api_version: str
|
||||
semantic_configuration: str
|
||||
embedding: EmbeddingConfig
|
||||
index: IndexConfig
|
||||
|
||||
|
||||
class PostgreSQLConfig(BaseModel):
|
||||
host: str
|
||||
port: int = 5432
|
||||
database: str
|
||||
username: str
|
||||
password: str
|
||||
ttl_days: int = 7
|
||||
|
||||
|
||||
class RedisConfig(BaseModel):
|
||||
host: str
|
||||
port: int = 6379
|
||||
password: str
|
||||
use_ssl: bool = True
|
||||
db: int = 0
|
||||
ttl_days: int = 7
|
||||
|
||||
|
||||
class AppLoggingConfig(BaseModel):
|
||||
level: str = "INFO"
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
name: str = "agentic-rag"
|
||||
memory_ttl_days: int = 7
|
||||
max_tool_rounds: int = 3 # Maximum allowed tool calling rounds
|
||||
max_tool_rounds_user_manual: int = 3 # Maximum allowed tool calling rounds for user manual agent
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["*"])
|
||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||
# Service configuration
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
"""Search index configuration"""
|
||||
standard_regulation_index: str = ""
|
||||
chunk_index: str = ""
|
||||
chunk_user_manual_index: str = ""
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
"""Citation link configuration"""
|
||||
base_url: str = "" # Default empty string
|
||||
|
||||
|
||||
class LLMParametersConfig(BaseModel):
|
||||
"""LLM parameters configuration"""
|
||||
temperature: Optional[float] = None
|
||||
max_context_length: int = 96000 # Maximum context length for conversation history (in tokens)
|
||||
max_output_tokens: Optional[int] = None # Optional limit for LLM output tokens (None = no limit)
|
||||
|
||||
|
||||
class LLMPromptsConfig(BaseModel):
|
||||
"""LLM prompts configuration"""
|
||||
agent_system_prompt: str
|
||||
synthesis_system_prompt: Optional[str] = None
|
||||
synthesis_user_prompt: Optional[str] = None
|
||||
intent_recognition_prompt: Optional[str] = None
|
||||
user_manual_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class LLMPromptConfig(BaseModel):
|
||||
"""LLM prompt configuration from llm_prompt.yaml"""
|
||||
parameters: LLMParametersConfig = Field(default_factory=LLMParametersConfig)
|
||||
prompts: LLMPromptsConfig
|
||||
|
||||
|
||||
class LLMRagConfig(BaseModel):
|
||||
"""Legacy LLM RAG configuration for backward compatibility"""
|
||||
temperature: Optional[float] = None
|
||||
max_context_length: int = 96000 # Maximum context length for conversation history (in tokens)
|
||||
max_output_tokens: Optional[int] = None # Optional limit for LLM output tokens (None = no limit)
|
||||
# Legacy prompts for backward compatibility
|
||||
system_prompt: Optional[str] = None
|
||||
user_prompt: Optional[str] = None
|
||||
# New autonomous agent prompts
|
||||
agent_system_prompt: Optional[str] = None
|
||||
synthesis_system_prompt: Optional[str] = None
|
||||
synthesis_user_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
rag: LLMRagConfig
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
level: str = "INFO"
|
||||
format: str = "json"
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
provider: str = "openai"
|
||||
openai: Optional[OpenAIConfig] = None
|
||||
azure: Optional[AzureConfig] = None
|
||||
retrieval: RetrievalConfig
|
||||
postgresql: PostgreSQLConfig
|
||||
redis: Optional[RedisConfig] = None
|
||||
app: AppConfig = Field(default_factory=AppConfig)
|
||||
search: SearchConfig = Field(default_factory=SearchConfig)
|
||||
citation: CitationConfig = Field(default_factory=CitationConfig)
|
||||
llm: Optional[LLMConfig] = None
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig)
|
||||
|
||||
# New LLM prompt configuration
|
||||
llm_prompt: Optional[LLMPromptConfig] = None
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> "Config":
|
||||
"""Load configuration from YAML files with environment variable substitution"""
|
||||
# Load main config
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
# Substitute environment variables
|
||||
yaml_data = cls._substitute_env_vars(yaml_data)
|
||||
|
||||
# Load LLM prompt config if exists
|
||||
llm_prompt_data = None
|
||||
if os.path.exists(llm_prompt_path):
|
||||
with open(llm_prompt_path, 'r', encoding='utf-8') as f:
|
||||
llm_prompt_data = yaml.safe_load(f)
|
||||
llm_prompt_data = cls._substitute_env_vars(llm_prompt_data)
|
||||
yaml_data['llm_prompt'] = llm_prompt_data
|
||||
|
||||
return cls(**yaml_data)
|
||||
|
||||
@classmethod
|
||||
def _substitute_env_vars(cls, data: Any) -> Any:
|
||||
"""Recursively substitute ${VAR} and ${VAR:-default} patterns with environment variables"""
|
||||
if isinstance(data, dict):
|
||||
return {k: cls._substitute_env_vars(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [cls._substitute_env_vars(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Handle ${VAR:-default} pattern
|
||||
if data.startswith("${") and data.endswith("}"):
|
||||
env_spec = data[2:-1]
|
||||
if ":-" in env_spec:
|
||||
var_name, default_value = env_spec.split(":-", 1)
|
||||
return os.getenv(var_name, default_value)
|
||||
else:
|
||||
return os.getenv(env_spec, data) # Return original if env var not found
|
||||
return data
|
||||
else:
|
||||
return data
|
||||
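# --- Example only: how _substitute_env_vars resolves placeholders (variable names are
# illustrative; assume OPENAI_API_KEY is set to "sk-example" and the others are unset).
#
#   raw = {
#       "api_key": "${OPENAI_API_KEY}",                               # -> "sk-example"
#       "base_url": "${OPENAI_BASE_URL:-https://api.openai.com/v1}",  # -> default after ":-"
#       "model": "${OPENAI_MODEL}",                                   # unset, no default -> left as-is
#   }
#   Config._substitute_env_vars(raw)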
|
||||
def get_llm_config(self) -> Dict[str, Any]:
|
||||
"""Get LLM configuration based on provider"""
|
||||
base_config = {}
|
||||
|
||||
# Get temperature and max_output_tokens from llm_prompt config first, fallback to legacy llm.rag config
|
||||
if self.llm_prompt and self.llm_prompt.parameters:
|
||||
# Only add temperature if explicitly set (not None)
|
||||
if self.llm_prompt.parameters.temperature is not None:
|
||||
base_config["temperature"] = self.llm_prompt.parameters.temperature
|
||||
# Only add max_output_tokens if explicitly set (not None)
|
||||
if self.llm_prompt.parameters.max_output_tokens is not None:
|
||||
base_config["max_tokens"] = self.llm_prompt.parameters.max_output_tokens
|
||||
elif self.llm and self.llm.rag:
|
||||
# Only add temperature if explicitly set (not None)
|
||||
if hasattr(self.llm.rag, 'temperature') and self.llm.rag.temperature is not None:
|
||||
base_config["temperature"] = self.llm.rag.temperature
|
||||
# Only add max_output_tokens if explicitly set (not None)
|
||||
if self.llm.rag.max_output_tokens is not None:
|
||||
base_config["max_tokens"] = self.llm.rag.max_output_tokens
|
||||
|
||||
if self.provider == "openai" and self.openai:
|
||||
return {
|
||||
**base_config,
|
||||
"provider": "openai",
|
||||
"base_url": self.openai.base_url,
|
||||
"api_key": self.openai.api_key,
|
||||
"model": self.openai.model,
|
||||
}
|
||||
elif self.provider == "azure" and self.azure:
|
||||
return {
|
||||
**base_config,
|
||||
"provider": "azure",
|
||||
"base_url": self.azure.base_url,
|
||||
"api_key": self.azure.api_key,
|
||||
"deployment": self.azure.deployment,
|
||||
"api_version": self.azure.api_version,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid provider '{self.provider}' or missing configuration")
|
||||
|
||||
def get_rag_prompts(self) -> Dict[str, str]:
|
||||
"""Get RAG prompts configuration - prioritize llm_prompt.yaml over legacy config"""
|
||||
# Use new llm_prompt config if available
|
||||
if self.llm_prompt and self.llm_prompt.prompts:
|
||||
return {
|
||||
"system_prompt": self.llm_prompt.prompts.agent_system_prompt,
|
||||
"user_prompt": "{{user_query}}", # Default template
|
||||
"agent_system_prompt": self.llm_prompt.prompts.agent_system_prompt,
|
||||
"synthesis_system_prompt": self.llm_prompt.prompts.synthesis_system_prompt or "You are a helpful assistant.",
|
||||
"synthesis_user_prompt": self.llm_prompt.prompts.synthesis_user_prompt or "{{user_query}}",
|
||||
"intent_recognition_prompt": self.llm_prompt.prompts.intent_recognition_prompt or "",
|
||||
"user_manual_prompt": self.llm_prompt.prompts.user_manual_prompt or "",
|
||||
}
|
||||
|
||||
# Fallback to legacy llm.rag config
|
||||
if self.llm and self.llm.rag:
|
||||
return {
|
||||
"system_prompt": self.llm.rag.system_prompt or "You are a helpful assistant.",
|
||||
"user_prompt": self.llm.rag.user_prompt or "{{user_query}}",
|
||||
"agent_system_prompt": self.llm.rag.agent_system_prompt or "You are a helpful assistant.",
|
||||
"synthesis_system_prompt": self.llm.rag.synthesis_system_prompt or "You are a helpful assistant.",
|
||||
"synthesis_user_prompt": self.llm.rag.synthesis_user_prompt or "{{user_query}}",
|
||||
"intent_recognition_prompt": "",
|
||||
"user_manual_prompt": "",
|
||||
}
|
||||
|
||||
# Default fallback
|
||||
return {
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"user_prompt": "{{user_query}}",
|
||||
"agent_system_prompt": "You are a helpful assistant.",
|
||||
"synthesis_system_prompt": "You are a helpful assistant.",
|
||||
"synthesis_user_prompt": "{{user_query}}",
|
||||
"intent_recognition_prompt": "",
|
||||
"user_manual_prompt": "",
|
||||
}
|
||||
|
||||
def get_max_context_length(self) -> int:
|
||||
"""Get maximum context length for conversation history"""
|
||||
# Use new llm_prompt config if available
|
||||
if self.llm_prompt and self.llm_prompt.parameters:
|
||||
return self.llm_prompt.parameters.max_context_length
|
||||
|
||||
# Fallback to legacy llm.rag config
|
||||
if self.llm and self.llm.rag:
|
||||
return self.llm.rag.max_context_length
|
||||
|
||||
# Default fallback
|
||||
return 96000
|
||||
|
||||
|
||||
# Global config instance
|
||||
config: Optional[Config] = None
|
||||
|
||||
|
||||
def load_config(config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> Config:
|
||||
"""Load and return the global configuration"""
|
||||
global config
|
||||
config = Config.from_yaml(config_path, llm_prompt_path)
|
||||
return config
|
||||
|
||||
|
||||
def get_config() -> Config:
|
||||
"""Get the current configuration instance"""
|
||||
if config is None:
|
||||
raise RuntimeError("Configuration not loaded. Call load_config() first.")
|
||||
return config
|
||||
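# --- Typical startup usage (sketch; file names are the defaults declared above):
#
#   cfg = load_config()                      # parses config.yaml (+ llm_prompt.yaml if present)
#   llm_kwargs = cfg.get_llm_config()        # provider-specific settings ("openai" or "azure")
#   prompts = cfg.get_rag_prompts()          # prompt templates with documented fallbacks
#   limit = cfg.get_max_context_length()     # 96000 unless overridden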
1
vw-agentic-rag/service/graph/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make packages
746
vw-agentic-rag/service/graph/graph.py
Normal file
@@ -0,0 +1,746 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from contextvars import ContextVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langgraph.graph import StateGraph, END, add_messages, MessagesState
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from .state import TurnState, Message, ToolResult, AgentState
|
||||
from .message_trimmer import create_conversation_trimmer
|
||||
from .tools import get_tool_schemas, get_tools_by_name
|
||||
from .user_manual_tools import get_user_manual_tools_by_name
|
||||
from .intent_recognition import intent_recognition_node, intent_router
|
||||
from .user_manual_rag import user_manual_rag_node
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..utils.templates import render_prompt_template
|
||||
from ..memory.postgresql_memory import get_checkpointer
|
||||
from ..utils.error_handler import (
|
||||
StructuredLogger, ErrorCategory, ErrorCode,
|
||||
handle_async_errors, get_user_message
|
||||
)
|
||||
from ..sse import (
|
||||
create_tool_start_event,
|
||||
create_tool_result_event,
|
||||
create_tool_error_event,
|
||||
create_token_event,
|
||||
create_error_event
|
||||
)
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
# Cache configuration at module level to avoid repeated get_config() calls
|
||||
_cached_config = None
|
||||
|
||||
def get_cached_config():
|
||||
"""Get cached configuration, loading it if not already cached"""
|
||||
global _cached_config
|
||||
if _cached_config is None:
|
||||
_cached_config = get_config()
|
||||
return _cached_config
|
||||
|
||||
# Context variable for streaming callback (thread-safe)
|
||||
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
|
||||
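# Sketch of the intended pattern (not additional code in this module): the request handler
# sets the callback once per task, and any node can read it without the callback having to
# pass through the checkpointed graph state:
#
#   stream_callback_context.set(my_callback)   # in the HTTP handler / AgenticWorkflow.astream
#   cb = stream_callback_context.get()         # inside a node
#   if cb:
#       await cb(create_token_event("..."))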
|
||||
|
||||
# Agent node (autonomous function calling agent)
|
||||
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Agent node that autonomously uses tools and generates final answer.
|
||||
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
|
||||
1. Always start with non-streaming detection to check for tool needs
|
||||
2. If tool_calls exist → return immediately for routing to tools
|
||||
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
|
||||
"""
|
||||
app_config = get_cached_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get tool schemas and bind tools for planning phase
|
||||
tool_schemas = get_tool_schemas()
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
# Create conversation trimmer for managing context length
|
||||
trimmer = create_conversation_trimmer()
|
||||
|
||||
# Prepare messages with system prompt
|
||||
messages = state["messages"].copy()
|
||||
if not messages or not isinstance(messages[0], SystemMessage):
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
system_prompt = rag_prompts.get("agent_system_prompt", "")
|
||||
if not system_prompt:
|
||||
raise ValueError("agent_system_prompt is missing from configuration")
|
||||
|
||||
messages = [SystemMessage(content=system_prompt)] + messages
|
||||
|
||||
# Track tool rounds
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds", None)
|
||||
if max_rounds is None:
|
||||
max_rounds = app_config.app.max_tool_rounds
|
||||
|
||||
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
|
||||
# This prevents trimming current turn's tool results during multi-round tool calling
|
||||
if current_round == 0:
|
||||
# Trim conversation history to manage context length (only for previous conversation turns)
|
||||
if trimmer.should_trim(messages):
|
||||
messages = trimmer.trim_conversation_history(messages)
|
||||
logger.info("Applied conversation history trimming for context management (new conversation turn)")
|
||||
else:
|
||||
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
|
||||
|
||||
logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
|
||||
|
||||
# Check if this should be final synthesis (max rounds reached)
|
||||
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
|
||||
is_final_synthesis = has_tool_messages and current_round >= max_rounds
|
||||
|
||||
if is_final_synthesis:
|
||||
logger.info("Starting final synthesis phase - no more tool calls allowed")
|
||||
# ✅ STEP 1: Final synthesis with tools disabled from the start
|
||||
# Disable tools to prevent any tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools during final synthesis
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, generate final response without tools
|
||||
draft = await llm_client.ainvoke(list(messages))
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 3: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
else:
|
||||
logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")
|
||||
|
||||
# ✅ STEP 1: Non-streaming detection to check for tool needs
|
||||
draft = await llm_client.ainvoke_with_tools(list(messages))
|
||||
|
||||
# ✅ STEP 2: If draft has tool_calls, return immediately (let routing handle it)
|
||||
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
|
||||
# Increment tool round counter for next iteration
|
||||
new_tool_rounds = current_round + 1
|
||||
logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
|
||||
return {"messages": [draft], "tool_rounds": new_tool_rounds}
|
||||
|
||||
# ✅ STEP 3: No tool_calls needed → Enter final synthesis with streaming
|
||||
# Temporarily disable tools to prevent accidental tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, use the draft we already have
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 5: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
|
||||
# Tools routing condition (simplified for "detect-first-then-stream" strategy)
|
||||
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
|
||||
"""
|
||||
Simplified routing logic for "detect-first-then-stream" strategy:
|
||||
- has tool_calls → route to tools
|
||||
- no tool_calls → route to post_process (final synthesis already completed)
|
||||
"""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
logger.info("should_continue: No messages, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
last_message = messages[-1]
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds", None)
|
||||
if max_rounds is None:
|
||||
app_config = get_cached_config()
|
||||
max_rounds = app_config.app.max_tool_rounds
|
||||
|
||||
logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
|
||||
|
||||
# If last message is AI message with tool calls, route to tools
|
||||
if isinstance(last_message, AIMessage):
|
||||
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
|
||||
logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info("should_continue: Routing to tools")
|
||||
return "tools"
|
||||
else:
|
||||
# No tool calls = final synthesis already completed in call_model
|
||||
logger.info("should_continue: No tool calls, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
# If last message is tool message(s), continue with agent for next round or final synthesis
|
||||
if isinstance(last_message, ToolMessage):
|
||||
logger.info("should_continue: Tool message completed, continuing to agent")
|
||||
return "agent"
|
||||
|
||||
logger.info("should_continue: Routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
|
||||
# Custom tool node with streaming support and parallel execution
|
||||
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""Execute tools with streaming events - supports parallel execution"""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
|
||||
return {"messages": []}
|
||||
|
||||
tool_calls = last_message.tool_calls or []
|
||||
tool_results = []
|
||||
new_messages = []
|
||||
|
||||
# Tools mapping
|
||||
tools_map = get_tools_by_name()
|
||||
|
||||
async def execute_single_tool(tool_call):
|
||||
"""Execute a single tool call with enhanced error handling"""
|
||||
# Get stream callback from context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Apply error handling decorator
|
||||
@handle_async_errors(
|
||||
ErrorCategory.TOOL,
|
||||
ErrorCode.TOOL_ERROR,
|
||||
stream_callback,
|
||||
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
)
|
||||
async def _execute():
|
||||
# Validate tool_call format
|
||||
if not isinstance(tool_call, dict):
|
||||
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
|
||||
|
||||
tool_name = tool_call.get("name")
|
||||
tool_args = tool_call.get("args", {})
|
||||
tool_id = tool_call.get("id", "unknown")
|
||||
|
||||
if not tool_name or tool_name not in tools_map:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
logger.info(f"Executing tool: {tool_name}", extra={
|
||||
"tool_id": tool_id, "tool_name": tool_name
|
||||
})
|
||||
|
||||
# Send start event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
|
||||
|
||||
# Execute tool
|
||||
import time
|
||||
start_time = time.time()
|
||||
result = await tools_map[tool_name].ainvoke(tool_args)
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Process result
|
||||
if isinstance(result, dict):
|
||||
result["tool_call_id"] = tool_id
|
||||
if "results" in result and isinstance(result["results"], list):
|
||||
for i, search_result in enumerate(result["results"]):
|
||||
if isinstance(search_result, dict):
|
||||
search_result["@tool_call_id"] = tool_id
|
||||
search_result["@order_num"] = i
|
||||
|
||||
# Send result event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_result_event(
|
||||
tool_id, tool_name, result.get("results", []), execution_time
|
||||
))
|
||||
|
||||
# Create tool message
|
||||
tool_message = ToolMessage(
|
||||
content=json.dumps(result, ensure_ascii=False),
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return {
|
||||
"message": tool_message,
|
||||
"results": result.get("results", []) if isinstance(result, dict) else [],
|
||||
"success": True
|
||||
}
|
||||
|
||||
try:
|
||||
return await _execute()
|
||||
except Exception as e:
|
||||
# Handle any errors not caught by decorator
|
||||
tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
|
||||
error_message = ToolMessage(
|
||||
content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return {
|
||||
"message": error_message,
|
||||
"results": [],
|
||||
"success": False
|
||||
}
|
||||
|
||||
# Execute all tool calls in parallel using asyncio.gather
|
||||
if tool_calls:
|
||||
logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
|
||||
tool_execution_results = await asyncio.gather(
|
||||
*[execute_single_tool(tool_call) for tool_call in tool_calls],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Process results
|
||||
for execution_result in tool_execution_results:
|
||||
if execution_result is None:
|
||||
continue
|
||||
if isinstance(execution_result, Exception):
|
||||
logger.error(f"Tool execution exception: {execution_result}")
|
||||
continue
|
||||
if not isinstance(execution_result, dict):
|
||||
logger.error(f"Unexpected execution result type: {type(execution_result)}")
|
||||
continue
|
||||
|
||||
new_messages.append(execution_result["message"])
|
||||
if execution_result["success"] and execution_result["results"]:
|
||||
tool_results.extend(execution_result["results"])
|
||||
|
||||
logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")
|
||||
|
||||
return {
|
||||
"messages": new_messages,
|
||||
"tool_results": tool_results
|
||||
}
|
||||
|
||||
|
||||
# Helper functions for citation processing
|
||||
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
|
||||
"""Extract citations mapping CSV from agent response HTML comment"""
|
||||
try:
|
||||
# Look for citations_map comment
|
||||
pattern = r'<!-- citations_map\s*(.*?)\s*-->'
|
||||
match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
if not match:
|
||||
logger.warning("No citations_map comment found in agent response")
|
||||
return {}
|
||||
|
||||
csv_content = match.group(1).strip()
|
||||
citations_mapping = {}
|
||||
|
||||
for line in csv_content.split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
citation_num = int(parts[0])
|
||||
tool_call_id = parts[1].strip()
|
||||
order_num = int(parts[2])
|
||||
|
||||
citations_mapping[citation_num] = {
|
||||
'tool_call_id': tool_call_id,
|
||||
'order_num': order_num
|
||||
}
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.warning(f"Failed to parse citation line: {line}, error: {e}")
|
||||
continue
|
||||
|
||||
return citations_mapping
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting citations mapping: {e}")
|
||||
return {}
|
||||
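# --- Example only (IDs invented for illustration): the agent is prompted to append a
# citations_map HTML comment whose CSV rows are "citation_number,tool_call_id,order_num".
#
#   sample = ("Brake fluid must meet the referenced standard [1].\n"
#             "<!-- citations_map\n"
#             "1,call_abc123,0\n"
#             "-->")
#   _extract_citations_mapping(sample)
#   # -> {1: {"tool_call_id": "call_abc123", "order_num": 0}}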
|
||||
|
||||
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
|
||||
"""Build citation markdown based on mapping and tool results, following build_citations.py logic"""
|
||||
if not citations_mapping:
|
||||
return ""
|
||||
|
||||
# Get configuration for citation base URL
|
||||
config = get_cached_config()
|
||||
cat_base_url = config.citation.base_url
|
||||
|
||||
# Collect citation lines first; only emit header if we have at least one valid citation
|
||||
entries: List[str] = []
|
||||
|
||||
for citation_num in sorted(citations_mapping.keys()):
|
||||
mapping = citations_mapping[citation_num]
|
||||
tool_call_id = mapping['tool_call_id']
|
||||
order_num = mapping['order_num']
|
||||
|
||||
# Find the corresponding tool result
|
||||
result = _find_tool_result(tool_results, tool_call_id, order_num)
|
||||
if not result:
|
||||
logger.warning(f"No tool result found for citation [{citation_num}]")
|
||||
continue
|
||||
|
||||
# Extract citation information following build_citations.py logic
|
||||
full_headers = result.get('full_headers', '')
|
||||
lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
|
||||
header_display = f": {lowest_header}" if lowest_header else ""
|
||||
|
||||
document_code = result.get('document_code', '')
|
||||
document_category = result.get('document_category', '')
|
||||
|
||||
# Determine standard/regulation title (assuming English language)
|
||||
standard_regulation_title = ''
|
||||
if document_category == 'Standard':
|
||||
standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
|
||||
elif document_category == 'Regulation':
|
||||
standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')
|
||||
|
||||
# Build link
|
||||
func_uuid = result.get('func_uuid', '')
|
||||
uuid = result.get('x_Standard_Regulation_Id', '')
|
||||
document_code_encoded = quote(document_code, safe='') if document_code else ''
|
||||
standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
|
||||
link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
|
||||
link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'
|
||||
|
||||
# Format citation line
|
||||
title = result.get('title', '')
|
||||
entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")
|
||||
|
||||
# If no valid citations were found, do not include the header
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
# Build citations section with entries separated by a blank line (matching previous formatting)
|
||||
md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
|
||||
return md
|
||||
|
||||
|
||||
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
|
||||
"""Find tool result by tool_call_id and order_num"""
|
||||
matching_results = []
|
||||
|
||||
for result in tool_results:
|
||||
if result.get('@tool_call_id') == tool_call_id:
|
||||
matching_results.append(result)
|
||||
|
||||
# Sort by order and return the one at the specified position
|
||||
if matching_results and 0 <= order_num < len(matching_results):
|
||||
# If results have @order_num, use it; otherwise use position in list
|
||||
if '@order_num' in matching_results[0]:
|
||||
for result in matching_results:
|
||||
if result.get('@order_num') == order_num:
|
||||
return result
|
||||
else:
|
||||
return matching_results[order_num]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _remove_citations_comment(agent_response: str) -> str:
|
||||
"""Remove citations mapping HTML comment from agent response"""
|
||||
pattern = r'<!-- citations_map\s*.*?\s*-->'
|
||||
return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()
|
||||
|
||||
|
||||
# Post-processing node with citation list and link building
|
||||
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Post-processing node that builds citation list and links based on agent's citations mapping
|
||||
and tool call results, following the logic from build_citations.py
|
||||
"""
|
||||
try:
|
||||
logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get the last AI message (agent's response with citations mapping)
|
||||
agent_response = ""
|
||||
citations_mapping = {}
|
||||
|
||||
for message in reversed(state["messages"]):
|
||||
if isinstance(message, AIMessage) and message.content:
|
||||
# Ensure content is a string
|
||||
if isinstance(message.content, str):
|
||||
agent_response = message.content
|
||||
break
|
||||
|
||||
if not agent_response:
|
||||
logger.warning("POST_PROCESS_NODE: No agent response found")
|
||||
return {"messages": [], "final_answer": ""}
|
||||
|
||||
# Extract citations mapping from agent response
|
||||
citations_mapping = _extract_citations_mapping(agent_response)
|
||||
logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")
|
||||
|
||||
# Build citation markdown
|
||||
citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])
|
||||
|
||||
# Combine agent response (without HTML comment) with citations
|
||||
clean_response = _remove_citations_comment(agent_response)
|
||||
final_content = clean_response + citation_markdown
|
||||
|
||||
logger.info("POST_PROCESS_NODE: Built complete response with citations")
|
||||
|
||||
# Send citation markdown as a single block instead of streaming
|
||||
stream_callback = stream_callback_context.get()
|
||||
if stream_callback and citation_markdown:
|
||||
logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
|
||||
await stream_callback(create_token_event(citation_markdown))
|
||||
|
||||
# Create AI message with complete content
|
||||
final_ai_message = AIMessage(content=final_content)
|
||||
|
||||
return {
|
||||
"messages": [final_ai_message],
|
||||
"final_answer": final_content
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Post-processing error: {e}")
|
||||
error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."
|
||||
|
||||
# Send error message as single block
|
||||
stream_callback = stream_callback_context.get()
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(error_message))
|
||||
|
||||
error_content = agent_response + error_message if agent_response else error_message
|
||||
error_ai_message = AIMessage(content=error_content)
|
||||
return {
|
||||
"messages": [error_ai_message],
|
||||
"final_answer": error_ai_message.content
|
||||
}
|
||||
|
||||
|
||||
# Main workflow class
|
||||
class AgenticWorkflow:
|
||||
"""LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""
|
||||
|
||||
def __init__(self):
|
||||
# Build StateGraph with TypedDict state
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Add nodes following best practices
|
||||
workflow.add_node("intent_recognition", intent_recognition_node)
|
||||
workflow.add_node("agent", call_model)
|
||||
workflow.add_node("user_manual_rag", user_manual_rag_node)
|
||||
workflow.add_node("tools", run_tools_with_streaming)
|
||||
workflow.add_node("post_process", post_process_node)
|
||||
|
||||
# Set entry point to intent recognition
|
||||
workflow.set_entry_point("intent_recognition")
|
||||
|
||||
# Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
|
||||
workflow.add_conditional_edges(
|
||||
"intent_recognition",
|
||||
intent_router,
|
||||
{
|
||||
"Standard_Regulation_RAG": "agent",
|
||||
"User_Manual_RAG": "user_manual_rag"
|
||||
}
|
||||
)
|
||||
|
||||
# Standard RAG workflow (existing pattern)
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
should_continue,
|
||||
{
|
||||
"tools": "tools",
|
||||
"agent": "agent", # Allow agent to continue for multi-round
|
||||
"post_process": "post_process"
|
||||
}
|
||||
)
|
||||
|
||||
# Tools route back to should_continue for multi-round decision
|
||||
workflow.add_conditional_edges(
|
||||
"tools",
|
||||
should_continue,
|
||||
{
|
||||
"agent": "agent", # Continue to agent for next round
|
||||
"post_process": "post_process" # Or finish if max rounds reached
|
||||
}
|
||||
)
|
||||
|
||||
# User Manual RAG directly goes to END (single turn)
|
||||
workflow.add_edge("user_manual_rag", END)
|
||||
|
||||
# Post-process is terminal
|
||||
workflow.add_edge("post_process", END)
|
||||
|
||||
# Compile graph with PostgreSQL checkpointer for session memory
|
||||
try:
|
||||
checkpointer = get_checkpointer()
|
||||
self.graph = workflow.compile(checkpointer=checkpointer)
|
||||
logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
|
||||
self.graph = workflow.compile()
|
||||
|
||||
async def astream(self, state: TurnState, stream_callback: Callable | None = None):
|
||||
"""Stream agent execution using LangGraph with PostgreSQL session memory"""
|
||||
try:
|
||||
# Get configuration
|
||||
config = get_cached_config()
|
||||
|
||||
# Prepare initial messages for the graph
|
||||
messages = []
|
||||
for msg in state.messages:
|
||||
if msg.role == "user":
|
||||
messages.append(HumanMessage(content=msg.content))
|
||||
elif msg.role == "assistant":
|
||||
messages.append(AIMessage(content=msg.content))
|
||||
|
||||
# Create initial agent state (without stream_callback to avoid serialization issues)
|
||||
initial_state: AgentState = {
|
||||
"messages": messages,
|
||||
"session_id": state.session_id,
|
||||
"intent": None, # Will be determined by intent recognition node
|
||||
"tool_results": [],
|
||||
"final_answer": "",
|
||||
"tool_rounds": 0,
|
||||
"max_tool_rounds": config.app.max_tool_rounds, # Use configuration value
|
||||
"max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual # Use configuration value for user manual agent
|
||||
}
|
||||
|
||||
# Set stream callback in context variable (thread-safe)
|
||||
stream_callback_context.set(stream_callback)
|
||||
|
||||
# Create proper RunnableConfig
|
||||
runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})
|
||||
|
||||
# Stream graph execution with session memory
|
||||
async for step in self.graph.astream(initial_state, config=runnable_config):
|
||||
if "post_process" in step:
|
||||
final_state = step["post_process"]
|
||||
|
||||
# Extract the tool summary message and update state
|
||||
state.final_answer = final_state.get("final_answer", "")
|
||||
|
||||
# Add the summary as a regular assistant message
|
||||
if state.final_answer:
|
||||
state.messages.append(Message(
|
||||
role="assistant",
|
||||
content=state.final_answer,
|
||||
timestamp=datetime.now()
|
||||
))
|
||||
|
||||
yield {"final": state}
|
||||
break
|
||||
elif "user_manual_rag" in step:
|
||||
# Handle user manual RAG completion
|
||||
final_state = step["user_manual_rag"]
|
||||
|
||||
# Extract the response from user manual RAG
|
||||
state.final_answer = final_state.get("final_answer", "")
|
||||
|
||||
# Add the response as a regular assistant message
|
||||
if state.final_answer:
|
||||
state.messages.append(Message(
|
||||
role="assistant",
|
||||
content=state.final_answer,
|
||||
timestamp=datetime.now()
|
||||
))
|
||||
|
||||
yield {"final": state}
|
||||
break
|
||||
else:
|
||||
# Process regular steps (intent_recognition, agent, tools)
|
||||
yield step
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentWorkflow error: {e}")
|
||||
state.final_answer = "I apologize, but I encountered an error while processing your request."
|
||||
yield {"final": state}
|
||||
|
||||
|
||||
def build_graph() -> AgenticWorkflow:
|
||||
"""Build and return the autonomous agent workflow"""
|
||||
return AgenticWorkflow()
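# Driver sketch (not part of this commit; mirrors the call made in ai_sdk_chat.py).
# The session id and the print-based callback are placeholders.
async def example_run_once(question: str) -> str:
    workflow = build_graph()
    turn = TurnState(
        messages=[Message(role="user", content=question, timestamp=None)],
        session_id="demo-session",
        tool_results=[],
        final_answer=""
    )

    async def on_event(event: str) -> None:
        print(event, end="")  # internal SSE events: tokens, tool_start, tool_result, ...

    async for step in workflow.astream(turn, stream_callback=on_event):
        if "final" in step:
            return step["final"].final_answer
    return ""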
136
vw-agentic-rag/service/graph/intent_recognition.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Intent recognition functionality for the Agentic RAG system.
|
||||
This module contains the intent classification logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .state import AgentState
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..utils.error_handler import StructuredLogger
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
|
||||
# Intent Recognition Models
|
||||
class Intent(BaseModel):
|
||||
"""Intent classification model for routing user queries"""
|
||||
label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
|
||||
confidence: Optional[float] = None
|
||||
|
||||
|
||||
def get_last_user_message(messages) -> str:
|
||||
"""Extract the last user message from conversation history"""
|
||||
for message in reversed(messages):
|
||||
if hasattr(message, 'content'):
|
||||
content = message.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract string content from list
|
||||
return " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
return ""
|
||||
|
||||
|
||||
def render_conversation_history(messages, max_messages: int = 10) -> str:
|
||||
"""Render conversation history for context"""
|
||||
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
||||
lines = []
|
||||
for msg in recent_messages:
|
||||
if hasattr(msg, 'content'):
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
# Determine message type by class name or other attributes
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content}</ai>")
|
||||
elif isinstance(content, list):
|
||||
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content_str}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content_str}</ai>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Intent recognition node that uses LLM to classify user queries into specific domains
|
||||
"""
|
||||
try:
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
|
||||
|
||||
app_config = get_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get current user query and conversation history
|
||||
current_query = get_last_user_message(state["messages"])
|
||||
conversation_context = render_conversation_history(state["messages"])
|
||||
|
||||
# Get intent classification prompt from configuration
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
|
||||
|
||||
if not intent_prompt_template:
|
||||
logger.error("Intent recognition prompt not found in configuration")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
# Format the prompt with instruction to return only the label
|
||||
system_prompt = intent_prompt_template.format(
|
||||
current_query=current_query,
|
||||
conversation_context=conversation_context
|
||||
) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."
|
||||
|
||||
# Classify intent using regular LLM call
|
||||
intent_result = await llm_client.llm.ainvoke([
|
||||
SystemMessage(content=system_prompt)
|
||||
])
|
||||
|
||||
# Parse the response to extract the intent label
|
||||
response_text = ""
|
||||
if hasattr(intent_result, 'content') and intent_result.content:
|
||||
if isinstance(intent_result.content, str):
|
||||
response_text = intent_result.content.strip()
|
||||
elif isinstance(intent_result.content, list):
|
||||
# Handle list content by joining string elements
|
||||
response_text = " ".join([str(item) for item in intent_result.content if isinstance(item, str)]).strip()
|
||||
|
||||
# Extract intent label from response
|
||||
if "User_Manual_RAG" in response_text:
|
||||
intent_label = "User_Manual_RAG"
|
||||
elif "Standard_Regulation_RAG" in response_text:
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
else:
|
||||
# Default fallback
|
||||
logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
|
||||
|
||||
return {"intent": intent_label}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Intent recognition error: {e}")
|
||||
# Default to Standard_Regulation_RAG if classification fails
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
|
||||
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
|
||||
"""
|
||||
Route based on intent classification result
|
||||
"""
|
||||
intent = state.get("intent")
|
||||
if intent is None:
|
||||
logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
|
||||
return "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
|
||||
return intent
|
||||
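For orientation, a minimal sketch of how the intent node and router above could be wired into a LangGraph StateGraph; the builder function, the standard_regulation_agent_node name, and the edge layout are assumptions for illustration, not the project's actual graph wiring.

from langgraph.graph import StateGraph, START, END

def build_intent_routing_graph():
    # Assumes intent_recognition_node, intent_router, AgentState and the agent nodes
    # are in scope; standard_regulation_agent_node is a hypothetical name for this sketch.
    builder = StateGraph(AgentState)
    builder.add_node("intent_recognition", intent_recognition_node)
    builder.add_node("Standard_Regulation_RAG", standard_regulation_agent_node)
    builder.add_node("User_Manual_RAG", user_manual_agent_node)
    builder.add_edge(START, "intent_recognition")
    # intent_router returns the intent label, which doubles as the target node name
    builder.add_conditional_edges(
        "intent_recognition",
        intent_router,
        {"Standard_Regulation_RAG": "Standard_Regulation_RAG", "User_Manual_RAG": "User_Manual_RAG"},
    )
    builder.add_edge("Standard_Regulation_RAG", END)
    builder.add_edge("User_Manual_RAG", END)
    return builder.compile()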
270
vw-agentic-rag/service/graph/message_trimmer.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Conversation history trimming utilities for managing context length.
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
|
||||
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationTrimmer:
|
||||
"""
|
||||
Manages conversation history to prevent exceeding LLM context limits.
|
||||
"""
|
||||
|
||||
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
|
||||
"""
|
||||
Initialize the conversation trimmer.
|
||||
|
||||
Args:
|
||||
max_context_length: Maximum context length for conversation history (in tokens)
|
||||
preserve_system: Whether to always preserve system messages
|
||||
"""
|
||||
self.max_context_length = max_context_length
|
||||
self.preserve_system = preserve_system
|
||||
# Reserve tokens for response generation (use 85% for history, 15% for response)
|
||||
self.history_token_limit = int(max_context_length * 0.85)
|
||||
|
||||
def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
|
||||
"""
|
||||
Trim conversation history to fit within token limits.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Trimmed list of messages
|
||||
"""
|
||||
if not messages:
|
||||
return list(messages)
|
||||
|
||||
try:
|
||||
# Convert to list for processing
|
||||
message_list = list(messages)
|
||||
|
||||
# First, try multi-round tool call optimization
|
||||
optimized_messages = self._optimize_multi_round_tool_calls(message_list)
|
||||
|
||||
# Check if optimization is sufficient
|
||||
try:
|
||||
token_count = count_tokens_approximately(optimized_messages)
|
||||
if token_count <= self.history_token_limit:
|
||||
original_count = len(message_list)
|
||||
optimized_count = len(optimized_messages)
|
||||
if optimized_count < original_count:
|
||||
logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
|
||||
return optimized_messages
|
||||
except Exception:
|
||||
# If token counting fails, continue with LangChain trimming
|
||||
pass
|
||||
|
||||
# If still too long, use LangChain's trim_messages utility
|
||||
trimmed_messages = trim_messages(
|
||||
optimized_messages,
|
||||
strategy="last", # Keep most recent messages
|
||||
token_counter=count_tokens_approximately,
|
||||
max_tokens=self.history_token_limit,
|
||||
start_on="human", # Ensure valid conversation start
|
||||
end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages
|
||||
include_system=self.preserve_system, # Preserve system messages
|
||||
allow_partial=False # Don't split individual messages
|
||||
)
|
||||
|
||||
original_count = len(messages)
|
||||
trimmed_count = len(trimmed_messages)
|
||||
|
||||
if trimmed_count < original_count:
|
||||
logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")
|
||||
|
||||
return trimmed_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error trimming conversation history: {e}")
|
||||
# Fallback: keep last N messages
|
||||
return self._fallback_trim(list(messages))
|
||||
|
||||
def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
|
||||
"""
|
||||
Optimize conversation history by removing older tool call results in multi-round scenarios.
|
||||
This reduces token usage while preserving conversation context.
|
||||
|
||||
Strategy:
|
||||
1. Always preserve system messages
|
||||
2. Always preserve the original user query
|
||||
3. Keep the most recent AI-Tool message pairs (for context continuity)
|
||||
4. Remove older ToolMessage content which typically contains large JSON responses
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Optimized list of messages
|
||||
"""
|
||||
if len(messages) <= 4: # Too short to optimize
|
||||
return [msg for msg in messages]
|
||||
|
||||
# Identify message patterns
|
||||
tool_rounds = self._identify_tool_rounds(messages)
|
||||
|
||||
if len(tool_rounds) <= 1: # Single or no tool round, no optimization needed
|
||||
return [msg for msg in messages]
|
||||
|
||||
logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")
|
||||
|
||||
# Build optimized message list
|
||||
optimized = []
|
||||
|
||||
# Always preserve system messages
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
optimized.append(msg)
|
||||
|
||||
# Preserve initial user query (first human message after system)
|
||||
first_human_added = False
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage) and not first_human_added:
|
||||
optimized.append(msg)
|
||||
first_human_added = True
|
||||
break
|
||||
|
||||
# Keep only the most recent tool round (preserve context for next round)
|
||||
if tool_rounds:
|
||||
latest_round_start, latest_round_end = tool_rounds[-1]
|
||||
|
||||
# Add messages from the latest tool round
|
||||
for i in range(latest_round_start, min(latest_round_end + 1, len(messages))):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg, SystemMessage) and not (isinstance(msg, HumanMessage) and not first_human_added):
|
||||
optimized.append(msg)
|
||||
|
||||
logger.info(f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages (removed {len(tool_rounds)-1} older tool rounds)")
|
||||
return optimized
|
||||
|
||||
def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Identify tool calling rounds in the message sequence.
|
||||
|
||||
A tool round typically consists of:
|
||||
- AI message with tool_calls
|
||||
- One or more ToolMessage responses
|
||||
|
||||
Returns:
|
||||
List of (start_index, end_index) tuples for each tool round
|
||||
"""
|
||||
rounds = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
# Look for AI message with tool calls
|
||||
if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||
round_start = i
|
||||
round_end = i
|
||||
|
||||
# Find the end of this tool round (look for consecutive ToolMessages)
|
||||
j = i + 1
|
||||
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
||||
round_end = j
|
||||
j += 1
|
||||
|
||||
# Only consider it a tool round if we found at least one ToolMessage
|
||||
if round_end > round_start:
|
||||
rounds.append((round_start, round_end))
|
||||
i = round_end + 1
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return rounds
|
||||
|
||||
def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
|
||||
"""
|
||||
Fallback trimming based on message count.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
max_messages: Maximum number of messages to keep
|
||||
|
||||
Returns:
|
||||
Trimmed list of messages
|
||||
"""
|
||||
if len(messages) <= max_messages:
|
||||
return list(messages)  # Items are already BaseMessage instances; return a plain list copy
|
||||
|
||||
# Preserve system message if it exists
|
||||
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
|
||||
other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
|
||||
|
||||
# Keep the most recent messages
|
||||
recent_messages = other_messages[-(max_messages - len(system_messages)):]
|
||||
|
||||
result = system_messages + recent_messages
|
||||
logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
|
||||
|
||||
return list(result)  # Items are already BaseMessage instances; return a plain list
|
||||
|
||||
def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
|
||||
"""
|
||||
Check if conversation history should be trimmed.
|
||||
|
||||
Strategy:
|
||||
1. Always trim if there are multiple tool rounds from previous conversation turns
|
||||
2. Also trim if approaching token limit
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
True if trimming is needed
|
||||
"""
|
||||
try:
|
||||
# Convert to list for processing
|
||||
message_list = list(messages)
|
||||
|
||||
# Check for multiple tool rounds - if found, always trim to remove old tool results
|
||||
tool_rounds = self._identify_tool_rounds(message_list)
|
||||
if len(tool_rounds) > 1:
|
||||
logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
|
||||
return True
|
||||
|
||||
# Also check token count for traditional trimming
|
||||
token_count = count_tokens_approximately(message_list)
|
||||
return token_count > self.history_token_limit
|
||||
except Exception:
|
||||
# Fallback to message count
|
||||
return len(messages) > 30
|
||||
|
||||
|
||||
def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
|
||||
"""
|
||||
Create a conversation trimmer with config-based settings.
|
||||
|
||||
Args:
|
||||
max_context_length: Override for maximum context length
|
||||
|
||||
Returns:
|
||||
ConversationTrimmer instance
|
||||
"""
|
||||
# If max_context_length is provided, use it directly
|
||||
if max_context_length is not None:
|
||||
return ConversationTrimmer(
|
||||
max_context_length=max_context_length,
|
||||
preserve_system=True
|
||||
)
|
||||
|
||||
# Try to get from config, fallback to default if config not available
|
||||
try:
|
||||
from ..config import get_config
|
||||
config = get_config()
|
||||
effective_max_context_length = config.get_max_context_length()
|
||||
except (RuntimeError, AttributeError):
|
||||
effective_max_context_length = 96000
|
||||
|
||||
return ConversationTrimmer(
|
||||
max_context_length=effective_max_context_length,
|
||||
preserve_system=True
|
||||
)
|
||||
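A short usage sketch of the trimmer above; the message contents and the 8000-token limit are illustrative values.

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

trimmer = create_conversation_trimmer(max_context_length=8000)
history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What does GB/T 18384 cover?"),
    AIMessage(content="It covers safety requirements for electric vehicles."),
]
if trimmer.should_trim(history):
    history = trimmer.trim_conversation_history(history)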
66
vw-agentic-rag/service/graph/state.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
from typing_extensions import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base message class for conversation history"""
|
||||
role: str # "user", "assistant", "tool"
|
||||
content: str
|
||||
timestamp: Optional[datetime] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
"""Citation mapping between numbers and result IDs"""
|
||||
number: int
|
||||
result_id: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Normalized tool result schema"""
|
||||
id: str
|
||||
title: str
|
||||
url: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
content: Optional[str] = None # For chunk results
|
||||
# Standard/regulation specific fields
|
||||
publisher: Optional[str] = None
|
||||
publish_date: Optional[str] = None
|
||||
document_code: Optional[str] = None
|
||||
document_category: Optional[str] = None
|
||||
|
||||
|
||||
class TurnState(BaseModel):
|
||||
"""State container for LangGraph workflow"""
|
||||
session_id: str
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
tool_results: List[ToolResult] = Field(default_factory=list)
|
||||
citations: List[Citation] = Field(default_factory=list)
|
||||
meta: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Additional fields for tracking
|
||||
current_step: int = 0
|
||||
max_steps: int = 5
|
||||
final_answer: Optional[str] = None
|
||||
|
||||
|
||||
# LangGraph-native AgentState (MessagesState is already a TypedDict under the hood)
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
class AgentState(MessagesState):
|
||||
"""LangGraph state with intent recognition support"""
|
||||
session_id: str
|
||||
intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
|
||||
tool_results: Annotated[List[Dict[str, Any]], lambda x, y: (x or []) + (y or [])]
|
||||
final_answer: str
|
||||
tool_rounds: int
|
||||
max_tool_rounds: int
|
||||
max_tool_rounds_user_manual: int
|
||||
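The Annotated reducer on tool_results appends each node's update instead of overwriting the previous value; a small standalone sketch of that merge, independent of LangGraph, with made-up entries:

merge_tool_results = lambda x, y: (x or []) + (y or [])

existing = [{"tool_name": "retrieve_standard_regulation", "results_count": 3}]
update = [{"tool_name": "retrieve_doc_chunk_standard_regulation", "results_count": 5}]
print(merge_tool_results(existing, update))  # both entries are kept, in order
print(merge_tool_results(None, update))      # None on either side behaves like an empty list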
98
vw-agentic-rag/service/graph/tools.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Tool definitions and schemas for the Agentic RAG system.
|
||||
This module contains all tool implementations and their corresponding schemas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from ..retrieval.retrieval import AgenticRetrieval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool Definitions using @tool decorator (following LangGraph best practices)
|
||||
@tool
|
||||
async def retrieve_standard_regulation(query: str) -> Dict[str, Any]:
|
||||
"""Search for attributes/metadata of China standards and regulations in automobile/manufacturing industry"""
|
||||
async with AgenticRetrieval() as retrieval:
|
||||
try:
|
||||
result = await retrieval.retrieve_standard_regulation(
|
||||
query=query
|
||||
)
|
||||
return {
|
||||
"tool_name": "retrieve_standard_regulation",
|
||||
"results_count": len(result.results),
|
||||
"results": result.results, # Already dict objects, no need for model_dump()
|
||||
"took_ms": result.took_ms
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Retrieval error: {e}")
|
||||
return {"error": str(e), "results_count": 0, "results": []}
|
||||
|
||||
|
||||
@tool
|
||||
async def retrieve_doc_chunk_standard_regulation(query: str) -> Dict[str, Any]:
|
||||
"""Search for detailed document content chunks of China standards and regulations in automobile/manufacturing industry"""
|
||||
async with AgenticRetrieval() as retrieval:
|
||||
try:
|
||||
result = await retrieval.retrieve_doc_chunk_standard_regulation(
|
||||
query=query
|
||||
)
|
||||
return {
|
||||
"tool_name": "retrieve_doc_chunk_standard_regulation",
|
||||
"results_count": len(result.results),
|
||||
"results": result.results, # Already dict objects, no need for model_dump()
|
||||
"took_ms": result.took_ms
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Doc chunk retrieval error: {e}")
|
||||
return {"error": str(e), "results_count": 0, "results": []}
|
||||
|
||||
|
||||
# Available tools list
|
||||
tools = [retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation]
|
||||
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate tool schemas for LLM function calling.
|
||||
|
||||
Returns:
|
||||
List of tool schemas in OpenAI function calling format
|
||||
"""
|
||||
|
||||
tool_schemas = []
|
||||
for tool in tools:
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for retrieving relevant information"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_schemas.append(schema)
|
||||
|
||||
return tool_schemas
|
||||
|
||||
|
||||
def get_tools_by_name() -> Dict[str, Any]:
|
||||
"""
|
||||
Create a mapping of tool names to tool functions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to tool functions
|
||||
"""
|
||||
return {tool.name: tool for tool in tools}
|
||||
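A hedged sketch of dispatching a tool call by name with the helpers above; the tool_call dict mirrors the shape of LangChain's AIMessage.tool_calls entries, and the query text is made up.

import asyncio

async def dispatch_example():
    tools_by_name = get_tools_by_name()
    tool_call = {
        "name": "retrieve_standard_regulation",
        "args": {"query": "EV battery safety standard"},
        "id": "call_1",
    }
    tool_func = tools_by_name[tool_call["name"]]
    result = await tool_func.ainvoke(tool_call["args"])  # requires the retrieval backend to be reachable
    print(result.get("results_count", 0))

if __name__ == "__main__":
    asyncio.run(dispatch_example())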
464
vw-agentic-rag/service/graph/user_manual_rag.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
User Manual Agent node for the Agentic RAG system.
|
||||
This module contains the autonomous user manual agent that can use tools and generate responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Callable, Literal
|
||||
from contextvars import ContextVar
|
||||
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from .state import AgentState
|
||||
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||||
from .message_trimmer import create_conversation_trimmer
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..sse import (
|
||||
create_tool_start_event,
|
||||
create_tool_result_event,
|
||||
create_tool_error_event,
|
||||
create_token_event,
|
||||
create_error_event
|
||||
)
|
||||
from ..utils.error_handler import (
|
||||
StructuredLogger, ErrorCategory, ErrorCode,
|
||||
handle_async_errors, get_user_message
|
||||
)
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
# Cache configuration at module level to avoid repeated get_config() calls
|
||||
_cached_config = None
|
||||
|
||||
def get_cached_config():
|
||||
"""Get cached configuration, loading it if not already cached"""
|
||||
global _cached_config
|
||||
if _cached_config is None:
|
||||
_cached_config = get_config()
|
||||
return _cached_config
|
||||
|
||||
|
||||
# User Manual Agent node (autonomous function calling agent)
|
||||
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
User Manual Agent node that autonomously uses user manual tools and generates final answer.
|
||||
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
|
||||
1. Always start with non-streaming detection to check for tool needs
|
||||
2. If tool_calls exist → return immediately for routing to tools
|
||||
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
|
||||
"""
|
||||
app_config = get_cached_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get stream callback from context variable
|
||||
from .graph import stream_callback_context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get user manual tool schemas and bind tools for planning phase
|
||||
tool_schemas = get_user_manual_tool_schemas()
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
# Create conversation trimmer for managing context length
|
||||
trimmer = create_conversation_trimmer()
|
||||
|
||||
# Prepare messages with user manual system prompt
|
||||
messages = state["messages"].copy()
|
||||
if not messages or not isinstance(messages[0], SystemMessage):
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
|
||||
if not user_manual_prompt:
|
||||
raise ValueError("user_manual_prompt is null")
|
||||
|
||||
# For user manual agent, we need to format the prompt with placeholders
|
||||
# Extract current query and conversation history
|
||||
current_query = ""
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
current_query = message.content
|
||||
break
|
||||
|
||||
conversation_history = ""
|
||||
if len(messages) > 1:
|
||||
conversation_history = render_conversation_history(messages[:-1]) # Exclude current query
|
||||
|
||||
# Format system prompt (initially with empty context, tools will provide it)
|
||||
formatted_system_prompt = user_manual_prompt.format(
|
||||
conversation_history=conversation_history,
|
||||
context_content="", # Will be filled by tools
|
||||
current_query=current_query
|
||||
)
|
||||
|
||||
messages = [SystemMessage(content=formatted_system_prompt)] + messages
|
||||
|
||||
# Track tool rounds
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds_user_manual from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds_user_manual", None)
|
||||
if max_rounds is None:
|
||||
max_rounds = app_config.app.max_tool_rounds_user_manual
|
||||
|
||||
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
|
||||
# This prevents trimming current turn's tool results during multi-round tool calling
|
||||
if current_round == 0:
|
||||
# Trim conversation history to manage context length (only for previous conversation turns)
|
||||
if trimmer.should_trim(messages):
|
||||
messages = trimmer.trim_conversation_history(messages)
|
||||
logger.info("Applied conversation history trimming for context management (new conversation turn)")
|
||||
else:
|
||||
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
|
||||
|
||||
logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
|
||||
|
||||
# Check if this should be final synthesis (max rounds reached)
|
||||
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
|
||||
is_final_synthesis = has_tool_messages and current_round >= max_rounds
|
||||
|
||||
if is_final_synthesis:
|
||||
logger.info("Starting final synthesis phase - no more tool calls allowed")
|
||||
# ✅ STEP 1: Final synthesis with tools disabled from the start
|
||||
# Disable tools to prevent any tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools; the original schemas are re-bound in the finally block
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, generate final response without tools
|
||||
draft = await llm_client.ainvoke(list(messages))
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 3: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
else:
|
||||
logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")
|
||||
|
||||
# ✅ STEP 1: Non-streaming detection to check for tool needs
|
||||
draft = await llm_client.ainvoke_with_tools(list(messages))
|
||||
|
||||
# ✅ STEP 2: If draft has tool_calls, execute them within this node
|
||||
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
|
||||
logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")
|
||||
|
||||
# Create a new state with the tool call message added
|
||||
tool_call_state = state.copy()
|
||||
updated_messages = state["messages"].copy()
|
||||
updated_messages.append(draft)
|
||||
tool_call_state["messages"] = updated_messages
|
||||
|
||||
# Execute the tools using the existing streaming tool execution function
|
||||
tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
|
||||
tool_messages = tool_results.get("messages", [])
|
||||
|
||||
# Increment tool round counter for next iteration
|
||||
new_tool_rounds = current_round + 1
|
||||
logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")
|
||||
|
||||
# Continue with another round if under max rounds
|
||||
if new_tool_rounds < max_rounds:
|
||||
# Recursive call for next round with all messages
|
||||
final_messages = updated_messages + tool_messages
|
||||
recursive_state = state.copy()
|
||||
recursive_state["messages"] = final_messages
|
||||
recursive_state["tool_rounds"] = new_tool_rounds
|
||||
return await user_manual_agent_node(recursive_state)
|
||||
else:
|
||||
# Max rounds reached, force final synthesis
|
||||
logger.info("Max tool rounds reached, forcing final synthesis")
|
||||
# Update messages for final synthesis
|
||||
messages = updated_messages + tool_messages
|
||||
# Continue to final synthesis below
|
||||
|
||||
# ✅ STEP 3: No tool_calls needed or max rounds reached → Enter final synthesis with streaming
|
||||
# Temporarily disable tools to prevent accidental tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, use the draft we already have
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 5: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
|
||||
def render_conversation_history(messages, max_messages: int = 10) -> str:
|
||||
"""Render conversation history for context"""
|
||||
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
||||
lines = []
|
||||
for msg in recent_messages:
|
||||
if hasattr(msg, 'content'):
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
# Determine message type by class name or other attributes
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content}</ai>")
|
||||
elif isinstance(content, list):
|
||||
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content_str}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content_str}</ai>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# User Manual Tools routing condition
|
||||
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
|
||||
"""
|
||||
Routing logic for user manual agent:
|
||||
- has tool_calls → route to user_manual_tools
|
||||
- no tool_calls → route to post_process (final synthesis already completed)
|
||||
- last message is a ToolMessage → route back to user_manual_agent for the next round
|
||||
"""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
logger.info("user_manual_should_continue: No messages, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
last_message = messages[-1]
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds_user_manual from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds_user_manual", None)
|
||||
if max_rounds is None:
|
||||
app_config = get_cached_config()
|
||||
max_rounds = app_config.app.max_tool_rounds_user_manual
|
||||
|
||||
logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
|
||||
|
||||
# If last message is AI message with tool calls, route to tools
|
||||
if isinstance(last_message, AIMessage):
|
||||
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
|
||||
logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info("user_manual_should_continue: Routing to user_manual_tools")
|
||||
return "user_manual_tools"
|
||||
else:
|
||||
# No tool calls = final synthesis already completed in user_manual_agent_node
|
||||
logger.info("user_manual_should_continue: No tool calls, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
# If last message is tool message(s), continue with agent for next round or final synthesis
|
||||
if isinstance(last_message, ToolMessage):
|
||||
logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
|
||||
return "user_manual_agent"
|
||||
|
||||
logger.info("user_manual_should_continue: Routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
|
||||
# User Manual Tools node with streaming support
|
||||
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""Execute user manual tools with streaming events - supports parallel execution"""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
|
||||
# Get stream callback from context variable
|
||||
from .graph import stream_callback_context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
|
||||
return {"messages": []}
|
||||
|
||||
tool_calls = last_message.tool_calls or []
|
||||
tool_results = []
|
||||
new_messages = []
|
||||
|
||||
# User manual tools mapping
|
||||
tools_map = get_user_manual_tools_by_name()
|
||||
|
||||
async def execute_single_tool(tool_call):
|
||||
"""Execute a single user manual tool call with enhanced error handling"""
|
||||
# Get stream callback from context
|
||||
from .graph import stream_callback_context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Apply error handling decorator
|
||||
@handle_async_errors(
|
||||
ErrorCategory.TOOL,
|
||||
ErrorCode.TOOL_ERROR,
|
||||
stream_callback,
|
||||
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
)
|
||||
async def _execute():
|
||||
# Validate tool_call format
|
||||
if not isinstance(tool_call, dict):
|
||||
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
|
||||
|
||||
tool_name = tool_call.get("name")
|
||||
tool_args = tool_call.get("args", {})
|
||||
tool_id = tool_call.get("id", "unknown")
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Tool call missing 'name' field")
|
||||
|
||||
if tool_name not in tools_map:
|
||||
available_tools = list(tools_map.keys())
|
||||
raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")
|
||||
|
||||
tool_func = tools_map[tool_name]
|
||||
|
||||
# Stream tool start event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Execute the user manual tool
|
||||
result = await tool_func.ainvoke(tool_args)
|
||||
|
||||
# Calculate execution time
|
||||
took_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Stream tool result event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))
|
||||
|
||||
# Create tool message
|
||||
tool_message = ToolMessage(
|
||||
content=str(result),
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}
|
||||
|
||||
except Exception as e:
|
||||
took_ms = int((time.time() - start_time) * 1000)
|
||||
error_msg = get_user_message(ErrorCategory.TOOL)
|
||||
|
||||
# Stream tool error event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
|
||||
|
||||
# Create error tool message
|
||||
tool_message = ToolMessage(
|
||||
content=f"Error executing {tool_name}: {error_msg}",
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}
|
||||
|
||||
return await _execute()
|
||||
|
||||
# Execute user manual tools (typically just one for user manual retrieval)
|
||||
import asyncio
|
||||
tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
# Handle execution exception
|
||||
tool_call = tool_calls[i]
|
||||
tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
|
||||
tool_name = tool_call.get("name", "unknown")
|
||||
error_msg = get_user_message(ErrorCategory.TOOL)
|
||||
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
|
||||
|
||||
error_message = ToolMessage(
|
||||
content=f"Error executing {tool_name}: {error_msg}",
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
elif isinstance(result, tuple) and len(result) == 2:
|
||||
# result is a tuple: (tool_message, tool_result)
|
||||
tool_message, tool_result = result
|
||||
new_messages.append(tool_message)
|
||||
tool_results.append(tool_result)
|
||||
else:
|
||||
# Unexpected result format
|
||||
logger.error(f"Unexpected tool execution result format: {type(result)}")
|
||||
continue
|
||||
|
||||
return {"messages": new_messages, "tool_results": tool_results}
|
||||
|
||||
|
||||
# Legacy function for backward compatibility
|
||||
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Legacy user manual RAG node - redirects to new agent-based implementation
|
||||
"""
|
||||
logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
|
||||
return await user_manual_agent_node(state, config)
|
||||
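The streaming loops above strip <!-- ... --> blocks from the token stream before forwarding content; a standalone sketch of that filtering logic on a made-up token sequence:

def filter_html_comments(tokens):
    emitted, buffer = [], ""
    for token in tokens:
        buffer += token
        # Drop every complete <!-- ... --> block currently in the buffer
        while "<!--" in buffer and "-->" in buffer:
            start = buffer.find("<!--")
            end = buffer.find("-->", start)
            if start >= 0 and end >= 0:
                if buffer[:start]:
                    emitted.append(buffer[:start])
                buffer = buffer[end + 3:]
            else:
                break
        # Flush the buffer only when no partial comment is pending
        if "<!--" not in buffer and buffer:
            emitted.append(buffer)
            buffer = ""
    if buffer and "<!--" not in buffer:
        emitted.append(buffer)
    return "".join(emitted)

print(filter_html_comments(["Answer: 42 ", "<!-- citation", " map -->", " done"]))  # "Answer: 42  done"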
77
vw-agentic-rag/service/graph/user_manual_tools.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
User manual specific tools for the Agentic RAG system.
|
||||
This module contains tools specifically for user manual retrieval and processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from ..retrieval.retrieval import AgenticRetrieval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# User Manual Tools
|
||||
@tool
|
||||
async def retrieve_system_usermanual(query: str) -> Dict[str, Any]:
|
||||
"""Search for document content chunks of user manual of this system(CATOnline)"""
|
||||
async with AgenticRetrieval() as retrieval:
|
||||
try:
|
||||
result = await retrieval.retrieve_doc_chunk_user_manual(
|
||||
query=query
|
||||
)
|
||||
return {
|
||||
"tool_name": "retrieve_system_usermanual",
|
||||
"results_count": len(result.results),
|
||||
"results": result.results, # Already dict objects, no need for model_dump()
|
||||
"took_ms": result.took_ms
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"User manual retrieval error: {e}")
|
||||
return {"error": str(e), "results_count": 0, "results": []}
|
||||
|
||||
|
||||
# User manual tools list
|
||||
user_manual_tools = [retrieve_system_usermanual]
|
||||
|
||||
|
||||
def get_user_manual_tool_schemas() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate tool schemas for user manual tools.
|
||||
|
||||
Returns:
|
||||
List of tool schemas in OpenAI function calling format
|
||||
"""
|
||||
tool_schemas = []
|
||||
for tool in user_manual_tools:
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for retrieving relevant information"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_schemas.append(schema)
|
||||
|
||||
return tool_schemas
|
||||
|
||||
|
||||
def get_user_manual_tools_by_name() -> Dict[str, Any]:
|
||||
"""
|
||||
Create a mapping of user manual tool names to tool functions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to tool functions
|
||||
"""
|
||||
return {tool.name: tool for tool in user_manual_tools}
|
||||
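A quick smoke-test sketch for the user manual tool above; the query is made up and the retrieval backend must be reachable for it to return results.

import asyncio

async def smoke_test():
    result = await retrieve_system_usermanual.ainvoke({"query": "How do I export a report?"})
    print(result.get("results_count", 0), "chunks returned")

if __name__ == "__main__":
    asyncio.run(smoke_test())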
103
vw-agentic-rag/service/llm_client.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import AsyncIterator, Dict, Any, List, Optional
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
import logging
|
||||
|
||||
from .config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Wrapper for OpenAI/Azure OpenAI clients with streaming and function calling support"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.llm = self._create_llm()
|
||||
self.llm_with_tools = None
|
||||
|
||||
def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
|
||||
"""Create LLM client based on configuration"""
|
||||
llm_config = self.config.get_llm_config()
|
||||
|
||||
if llm_config["provider"] == "openai":
|
||||
# Create base parameters
|
||||
params = {
|
||||
"base_url": llm_config["base_url"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"model": llm_config["model"],
|
||||
"streaming": True,
|
||||
}
|
||||
# Only add temperature if explicitly set
|
||||
if "temperature" in llm_config:
|
||||
params["temperature"] = llm_config["temperature"]
|
||||
return ChatOpenAI(**params)
|
||||
elif llm_config["provider"] == "azure":
|
||||
# Create base parameters
|
||||
params = {
|
||||
"azure_endpoint": llm_config["base_url"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"azure_deployment": llm_config["deployment"],
|
||||
"api_version": llm_config["api_version"],
|
||||
"streaming": True,
|
||||
}
|
||||
# Only add temperature if explicitly set
|
||||
if "temperature" in llm_config:
|
||||
params["temperature"] = llm_config["temperature"]
|
||||
return AzureChatOpenAI(**params)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {llm_config['provider']}")
|
||||
|
||||
def bind_tools(self, tools: List[Dict[str, Any]], force_tool_choice: bool = False):
|
||||
"""Bind tools to LLM for function calling"""
|
||||
if force_tool_choice:
|
||||
# Use tool_choice="required" to force tool calling for DeepSeek
|
||||
self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
|
||||
else:
|
||||
self.llm_with_tools = self.llm.bind_tools(tools)
|
||||
|
||||
async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
|
||||
"""Stream LLM response tokens"""
|
||||
try:
|
||||
async for chunk in self.llm.astream(messages):
|
||||
if chunk.content and isinstance(chunk.content, str):
|
||||
yield chunk.content
|
||||
except Exception as e:
|
||||
logger.error(f"LLM streaming error: {e}")
|
||||
raise
|
||||
|
||||
async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
|
||||
"""Get complete LLM response"""
|
||||
try:
|
||||
response = await self.llm.ainvoke(messages)
|
||||
if isinstance(response, AIMessage):
|
||||
return response
|
||||
else:
|
||||
# Convert to AIMessage if needed
|
||||
return AIMessage(content=str(response.content) if response.content else "")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM invoke error: {e}")
|
||||
raise
|
||||
|
||||
async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
|
||||
"""Get LLM response with tool calling capability"""
|
||||
try:
|
||||
if not self.llm_with_tools:
|
||||
raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
|
||||
response = await self.llm_with_tools.ainvoke(messages)
|
||||
if isinstance(response, AIMessage):
|
||||
return response
|
||||
else:
|
||||
return AIMessage(content=str(response.content) if response.content else "")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM with tools invoke error: {e}")
|
||||
raise
|
||||
|
||||
def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
|
||||
"""Create message list for LLM"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append(SystemMessage(content=system_prompt))
|
||||
messages.append(HumanMessage(content=user_prompt))
|
||||
return messages
|
||||
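A minimal sketch of driving LLMClient directly, assuming the service package is importable and its configuration files are in place; the import paths and prompts are illustrative.

import asyncio
from service.config import load_config        # assumed import path for this sketch
from service.llm_client import LLMClient      # assumed import path for this sketch

async def demo():
    load_config()  # LLMClient reads provider/model settings via get_config()
    client = LLMClient()
    messages = client.create_messages(
        system_prompt="You answer questions about manufacturing standards.",
        user_prompt="In one sentence, what is a GB standard?",
    )
    async for token in client.astream(messages):
        print(token, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(demo())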
187
vw-agentic-rag/service/main.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
from .config import load_config, get_config
|
||||
from .schemas.messages import ChatRequest
|
||||
from .memory.postgresql_memory import get_memory_manager
|
||||
from .graph.state import TurnState, Message
|
||||
from .graph.graph import build_graph
|
||||
from .sse import create_error_event
|
||||
from .utils.error_handler import StructuredLogger, ErrorCategory, ErrorCode, handle_async_errors
|
||||
from .utils.middleware import ErrorMiddleware
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager"""
|
||||
# Startup
|
||||
try:
|
||||
logger.info("Starting application initialization...")
|
||||
|
||||
# Initialize PostgreSQL memory manager
|
||||
memory_manager = get_memory_manager()
|
||||
connection_ok = memory_manager.test_connection()
|
||||
logger.info(f"PostgreSQL memory manager initialized (connected: {connection_ok})")
|
||||
|
||||
# Initialize global components
|
||||
app.state.memory_manager = memory_manager
|
||||
app.state.graph = build_graph()
|
||||
|
||||
logger.info("Application startup complete")
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start application: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Shutdown
|
||||
logger.info("Application shutdown")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Application factory"""
|
||||
# Load configuration first
|
||||
config = load_config()
|
||||
logger.info(f"Loaded configuration for provider: {config.provider}")
|
||||
|
||||
app = FastAPI(
|
||||
title="Agentic RAG API",
|
||||
description="Agentic RAG application for manufacturing standards and regulations",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add error handling middleware
|
||||
app.add_middleware(ErrorMiddleware)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.app.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Define routes
|
||||
@app.post("/api/chat")
|
||||
async def chat_endpoint(request: ChatRequest):
|
||||
"""Main chat endpoint with SSE streaming"""
|
||||
try:
|
||||
return StreamingResponse(
|
||||
stream_chat_response(request),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": "*",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Chat endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/ai-sdk/chat")
|
||||
async def ai_sdk_chat_endpoint(request: ChatRequest):
|
||||
"""AI SDK compatible chat endpoint"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from .ai_sdk_chat import handle_ai_sdk_chat
|
||||
return await handle_ai_sdk_chat(request, app.state)
|
||||
except Exception as e:
|
||||
logger.error(f"AI SDK chat endpoint error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "service": "agentic-rag"}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {"message": "Agentic RAG API for Manufacturing Standards & Regulations"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Create the global app instance for uvicorn
|
||||
app = create_app()
|
||||
|
||||
|
||||
@handle_async_errors(ErrorCategory.LLM, ErrorCode.LLM_ERROR)
|
||||
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat response with enhanced error handling"""
|
||||
config = get_config()
|
||||
memory_manager = app.state.memory_manager
|
||||
graph = app.state.graph
|
||||
|
||||
# Create conversation state
|
||||
state = TurnState(session_id=request.session_id)
|
||||
|
||||
# Add user message
|
||||
if request.messages:
|
||||
last_message = request.messages[-1]
|
||||
if last_message.get("role") == "user":
|
||||
user_message = Message(
|
||||
role="user",
|
||||
content=last_message.get("content", "")
|
||||
)
|
||||
state.messages.append(user_message)
|
||||
|
||||
# Create event queue for streaming
|
||||
event_queue = asyncio.Queue()
|
||||
|
||||
async def stream_callback(event_str: str):
|
||||
await event_queue.put(event_str)
|
||||
|
||||
# Execute workflow in background task
|
||||
async def run_workflow():
|
||||
try:
|
||||
async for _ in graph.astream(state, stream_callback):
|
||||
pass
|
||||
await event_queue.put(None) # Signal completion
|
||||
except Exception as e:
|
||||
logger.error("Workflow execution failed", error=e,
|
||||
category=ErrorCategory.LLM, error_code=ErrorCode.LLM_ERROR)
|
||||
await event_queue.put(create_error_event("Processing error: AI service is temporarily unavailable"))
|
||||
await event_queue.put(None)
|
||||
|
||||
# Start workflow task
|
||||
workflow_task = asyncio.create_task(run_workflow())
|
||||
|
||||
# Stream events as they come
|
||||
try:
|
||||
while True:
|
||||
event = await event_queue.get()
|
||||
if event is None: # Completion signal
|
||||
break
|
||||
yield event
|
||||
finally:
|
||||
if not workflow_task.done():
|
||||
workflow_task.cancel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = load_config() # Load configuration first
|
||||
uvicorn.run(
|
||||
"service.main:app",
|
||||
host=config.app.host,
|
||||
port=config.app.port,
|
||||
reload=True,
|
||||
log_level="info"
|
||||
)
|
||||
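A hypothetical client for the /api/chat SSE endpoint; the host, port, session id, and message shape are assumptions (the exact request schema lives in schemas/messages).

import asyncio
import httpx

async def consume_chat():
    payload = {
        "session_id": "demo-session",
        "messages": [{"role": "user", "content": "Which GB standard covers EV charging connectors?"}],
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/api/chat", json=payload) as response:
            async for line in response.aiter_lines():
                if line:  # each SSE event arrives as "event: ..." / "data: ..." lines
                    print(line)

if __name__ == "__main__":
    asyncio.run(consume_chat())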
1
vw-agentic-rag/service/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty __init__.py files to make packages
|
||||
332
vw-agentic-rag/service/memory/postgresql_memory.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
PostgreSQL-based memory implementation using LangGraph built-in components.
|
||||
Provides session-level chat history with 7-day TTL.
|
||||
Uses psycopg3 for better compatibility without requiring libpq-dev.
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import quote_plus
|
||||
from contextlib import contextmanager
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row
|
||||
PSYCOPG_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logging.warning(f"psycopg3 not available: {e}")
|
||||
PSYCOPG_AVAILABLE = False
|
||||
psycopg = None
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
LANGGRAPH_POSTGRES_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logging.warning(f"LangGraph PostgreSQL checkpoint not available: {e}")
|
||||
LANGGRAPH_POSTGRES_AVAILABLE = False
|
||||
PostgresSaver = None
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
LANGGRAPH_MEMORY_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logging.warning(f"LangGraph memory checkpoint not available: {e}")
|
||||
LANGGRAPH_MEMORY_AVAILABLE = False
|
||||
InMemorySaver = None
|
||||
|
||||
from ..config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
POSTGRES_AVAILABLE = PSYCOPG_AVAILABLE and LANGGRAPH_POSTGRES_AVAILABLE
|
||||
|
||||
|
||||
class PostgreSQLCheckpointerWrapper:
|
||||
"""
|
||||
Wrapper for PostgresSaver that manages the context properly.
|
||||
"""
|
||||
|
||||
def __init__(self, conn_string: str):
|
||||
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
|
||||
raise RuntimeError("PostgresSaver not available")
|
||||
self.conn_string = conn_string
|
||||
self._initialized = False
|
||||
|
||||
def _ensure_setup(self):
|
||||
"""Ensure the database schema is set up."""
|
||||
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
|
||||
raise RuntimeError("PostgresSaver not available")
|
||||
if not self._initialized:
|
||||
with PostgresSaver.from_conn_string(self.conn_string) as saver:
|
||||
saver.setup()
|
||||
self._initialized = True
|
||||
logger.info("PostgreSQL schema initialized")
|
||||
|
||||
@contextmanager
|
||||
def get_saver(self):
|
||||
"""Get a PostgresSaver instance as context manager."""
|
||||
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
|
||||
raise RuntimeError("PostgresSaver not available")
|
||||
self._ensure_setup()
|
||||
with PostgresSaver.from_conn_string(self.conn_string) as saver:
|
||||
yield saver
|
||||
|
||||
def list(self, config):
|
||||
"""List checkpoints."""
|
||||
with self.get_saver() as saver:
|
||||
return list(saver.list(config))
|
||||
|
||||
def get(self, config):
|
||||
"""Get a checkpoint."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.get(config)
|
||||
|
||||
def get_tuple(self, config):
|
||||
"""Get a checkpoint tuple."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.get_tuple(config)
|
||||
|
||||
def put(self, config, checkpoint, metadata, new_versions):
|
||||
"""Put a checkpoint."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
def put_writes(self, config, writes, task_id):
|
||||
"""Put writes."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.put_writes(config, writes, task_id)
|
||||
|
||||
def get_next_version(self, current, channel):
|
||||
"""Get next version."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.get_next_version(current, channel)
|
||||
|
||||
def delete_thread(self, thread_id):
|
||||
"""Delete thread."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.delete_thread(thread_id)
|
||||
|
||||
# Async methods
|
||||
async def alist(self, config):
|
||||
"""Async list checkpoints."""
|
||||
with self.get_saver() as saver:
|
||||
async for item in saver.alist(config):
|
||||
yield item
|
||||
|
||||
async def aget(self, config):
|
||||
"""Async get a checkpoint."""
|
||||
with self.get_saver() as saver:
|
||||
# PostgresSaver may not implement the async API; try it first, then fall back to the sync call in a thread
|
||||
try:
|
||||
return await saver.aget(config)
|
||||
except NotImplementedError:
|
||||
# Fall back to sync version in a thread
|
||||
import asyncio
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, saver.get, config
|
||||
)
|
||||
|
||||
async def aget_tuple(self, config):
|
||||
"""Async get a checkpoint tuple."""
|
||||
with self.get_saver() as saver:
|
||||
# PostgresSaver may not implement the async API; try it first, then fall back to the sync call in a thread
|
||||
try:
|
||||
return await saver.aget_tuple(config)
|
||||
except NotImplementedError:
|
||||
# Fall back to sync version in a thread
|
||||
import asyncio
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, saver.get_tuple, config
|
||||
)
|
||||
|
||||
async def aput(self, config, checkpoint, metadata, new_versions):
|
||||
"""Async put a checkpoint."""
|
||||
with self.get_saver() as saver:
|
||||
# PostgresSaver may not implement the async API; try it first, then fall back to the sync call in a thread
|
||||
try:
|
||||
return await saver.aput(config, checkpoint, metadata, new_versions)
|
||||
except NotImplementedError:
|
||||
# Fall back to sync version in a thread
|
||||
import asyncio
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, saver.put, config, checkpoint, metadata, new_versions
|
||||
)
|
||||
|
||||
async def aput_writes(self, config, writes, task_id):
|
||||
"""Async put writes."""
|
||||
with self.get_saver() as saver:
|
||||
# PostgresSaver may not implement the async API; try it first, then fall back to the sync call in a thread
|
||||
try:
|
||||
return await saver.aput_writes(config, writes, task_id)
|
||||
except NotImplementedError:
|
||||
# Fall back to sync version in a thread
|
||||
import asyncio
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, saver.put_writes, config, writes, task_id
|
||||
)
|
||||
|
||||
async def adelete_thread(self, thread_id):
|
||||
"""Async delete thread."""
|
||||
with self.get_saver() as saver:
|
||||
return await saver.adelete_thread(thread_id)
|
||||
|
||||
@property
|
||||
def config_specs(self):
|
||||
"""Get config specs."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.config_specs
|
||||
|
||||
@property
|
||||
def serde(self):
|
||||
"""Get serde."""
|
||||
with self.get_saver() as saver:
|
||||
return saver.serde
|
||||
|
||||
|
||||
class PostgreSQLMemoryManager:
|
||||
"""
|
||||
PostgreSQL-based memory manager using LangGraph's built-in components.
|
||||
Falls back to in-memory storage if PostgreSQL is not available.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.pg_config = self.config.postgresql
|
||||
self._checkpointer: Optional[Any] = None
|
||||
self._postgres_available = POSTGRES_AVAILABLE
|
||||
|
||||
def _get_connection_string(self) -> str:
|
||||
"""Get PostgreSQL connection string."""
|
||||
if not self._postgres_available:
|
||||
return ""
|
||||
|
||||
# URL encode password to handle special characters
|
||||
encoded_password = quote_plus(self.pg_config.password)
|
||||
|
||||
return (
|
||||
f"postgresql://{self.pg_config.username}:{encoded_password}@"
|
||||
f"{self.pg_config.host}:{self.pg_config.port}/{self.pg_config.database}"
|
||||
)
|
||||
|
||||
def _test_connection(self) -> bool:
|
||||
"""Test PostgreSQL connection."""
|
||||
if not self._postgres_available:
|
||||
return False
|
||||
|
||||
if not PSYCOPG_AVAILABLE or psycopg is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
conn_string = self._get_connection_string()
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT 1")
|
||||
result = cur.fetchone()
|
||||
logger.info("PostgreSQL connection test successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def _setup_ttl_cleanup(self):
|
||||
"""Setup TTL cleanup for old records."""
|
||||
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
|
||||
return
|
||||
|
||||
try:
|
||||
conn_string = self._get_connection_string()
|
||||
with psycopg.connect(conn_string, autocommit=True) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# Create a function to clean up old records for LangGraph tables
|
||||
# Note: LangGraph tables don't have created_at, so we'll use a different approach
|
||||
cleanup_sql = f"""
|
||||
CREATE OR REPLACE FUNCTION cleanup_old_checkpoints()
|
||||
RETURNS void AS $$
|
||||
BEGIN
|
||||
-- LangGraph tables don't have created_at columns
|
||||
-- We can clean based on checkpoint_id pattern or use a different strategy
|
||||
-- For now, just return successfully without actual cleanup
|
||||
-- You can implement custom logic based on your requirements
|
||||
RAISE NOTICE 'Cleanup function called - custom cleanup logic needed';
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
cur.execute(cleanup_sql)
|
||||
|
||||
logger.info(f"TTL cleanup function created with {self.pg_config.ttl_days}-day retention")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup TTL cleanup (this is optional): {e}")
|
||||
|
||||
def cleanup_old_data(self):
|
||||
"""Manually trigger cleanup of old data."""
|
||||
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
|
||||
return
|
||||
|
||||
try:
|
||||
conn_string = self._get_connection_string()
|
||||
with psycopg.connect(conn_string, autocommit=True) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT cleanup_old_checkpoints()")
|
||||
logger.info("Manual cleanup of old data completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup old data: {e}")
|
||||
|
||||
def get_checkpointer(self):
|
||||
"""Get checkpointer for conversation history (PostgreSQL if available, else in-memory)."""
|
||||
if self._checkpointer is None:
|
||||
if self._postgres_available:
|
||||
try:
|
||||
# Test connection first
|
||||
if not self._test_connection():
|
||||
raise Exception("PostgreSQL connection test failed")
|
||||
|
||||
# Setup TTL cleanup function
|
||||
self._setup_ttl_cleanup()
|
||||
|
||||
# Create checkpointer wrapper
|
||||
conn_string = self._get_connection_string()
|
||||
if LANGGRAPH_POSTGRES_AVAILABLE:
|
||||
self._checkpointer = PostgreSQLCheckpointerWrapper(conn_string)
|
||||
else:
|
||||
raise Exception("LangGraph PostgreSQL checkpoint not available")
|
||||
|
||||
logger.info(f"PostgreSQL checkpointer initialized with {self.pg_config.ttl_days}-day TTL")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer, falling back to in-memory: {e}")
|
||||
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
|
||||
self._checkpointer = InMemorySaver()
|
||||
else:
|
||||
logger.error("InMemorySaver not available - no checkpointer available")
|
||||
self._checkpointer = None
|
||||
else:
|
||||
logger.info("PostgreSQL not available, using in-memory checkpointer")
|
||||
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
|
||||
self._checkpointer = InMemorySaver()
|
||||
else:
|
||||
logger.error("InMemorySaver not available - no checkpointer available")
|
||||
self._checkpointer = None
|
||||
|
||||
return self._checkpointer
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Test PostgreSQL connection and return True if successful."""
|
||||
return self._test_connection()
|
||||
|
||||
|
||||
# Global memory manager instance
|
||||
_memory_manager: Optional[PostgreSQLMemoryManager] = None
|
||||
|
||||
|
||||
def get_memory_manager() -> PostgreSQLMemoryManager:
|
||||
"""Get global PostgreSQL memory manager instance."""
|
||||
global _memory_manager
|
||||
if _memory_manager is None:
|
||||
_memory_manager = PostgreSQLMemoryManager()
|
||||
return _memory_manager
|
||||
|
||||
|
||||
def get_checkpointer():
|
||||
"""Get checkpointer for conversation history."""
|
||||
return get_memory_manager().get_checkpointer()
|
||||
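A minimal wiring sketch for the checkpointer above, assuming the package is importable as `service` and that a graph factory such as `build_graph` exists elsewhere in the service (it is not part of this module); the `thread_id` key follows LangGraph's standard configurable convention:

# Hypothetical usage sketch -- `build_graph` is assumed, not defined in this commit.
from service.memory.postgresql_memory import get_checkpointer

checkpointer = get_checkpointer()  # PostgresSaver-backed wrapper, or InMemorySaver fallback
app = build_graph().compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "session-123"}}  # one thread per chat session
result = app.invoke({"messages": []}, config=config)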
137
vw-agentic-rag/service/memory/redis_memory.py
Normal file
137
vw-agentic-rag/service/memory/redis_memory.py
Normal file
@@ -0,0 +1,137 @@
"""
Redis-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
"""
import logging
import ssl
from typing import Dict, Any, Optional

try:
    import redis
    from redis.exceptions import ConnectionError, TimeoutError
    from langgraph.checkpoint.redis import RedisSaver
    REDIS_AVAILABLE = True
except ImportError as e:
    logging.warning(f"Redis packages not available: {e}")
    REDIS_AVAILABLE = False
    redis = None
    RedisSaver = None

from langgraph.checkpoint.memory import InMemorySaver
from ..config import get_config

logger = logging.getLogger(__name__)


class RedisMemoryManager:
    """
    Redis-based memory manager using LangGraph's built-in components.
    Falls back to in-memory storage if Redis is not available.
    """

    def __init__(self):
        self.config = get_config()
        self.redis_config = self.config.redis
        self._checkpointer: Optional[Any] = None
        self._redis_available = REDIS_AVAILABLE

    def _get_redis_client_kwargs(self) -> Dict[str, Any]:
        """Get Redis client configuration for Azure Redis compatibility."""
        if not self._redis_available:
            return {}

        kwargs = {
            "host": self.redis_config.host,
            "port": self.redis_config.port,
            "password": self.redis_config.password,
            "db": self.redis_config.db,
            "decode_responses": False,  # Required for RedisSaver
            "socket_timeout": 30,
            "socket_connect_timeout": 10,
            "retry_on_timeout": True,
            "health_check_interval": 30,
        }

        if self.redis_config.use_ssl:
            kwargs.update({
                "ssl": True,
                "ssl_cert_reqs": ssl.CERT_REQUIRED,
                "ssl_check_hostname": True,
            })

        return kwargs

    def _get_ttl_config(self) -> Dict[str, Any]:
        """Get TTL configuration for automatic cleanup."""
        ttl_days = self.redis_config.ttl_days
        ttl_minutes = ttl_days * 24 * 60  # Convert days to minutes

        return {
            "default_ttl": ttl_minutes,
            "refresh_on_read": True,  # Refresh TTL when accessed
        }

    def get_checkpointer(self):
        """Get checkpointer for conversation history (Redis if available, else in-memory)."""
        if self._checkpointer is None:
            if self._redis_available:
                try:
                    ttl_config = self._get_ttl_config()

                    # Create Redis client with proper configuration for Azure Redis
                    redis_client = redis.Redis(**self._get_redis_client_kwargs())

                    # Test connection
                    redis_client.ping()
                    logger.info("Redis connection established successfully")

                    # Create checkpointer with TTL support
                    self._checkpointer = RedisSaver(
                        redis_client=redis_client,
                        ttl=ttl_config
                    )

                    # Initialize indices (required for first-time setup)
                    self._checkpointer.setup()
                    logger.info(f"Redis checkpointer initialized with {self.redis_config.ttl_days}-day TTL")

                except Exception as e:
                    logger.error(f"Failed to initialize Redis checkpointer, falling back to in-memory: {e}")
                    self._checkpointer = InMemorySaver()
            else:
                logger.info("Redis not available, using in-memory checkpointer")
                self._checkpointer = InMemorySaver()

        return self._checkpointer

    def test_connection(self) -> bool:
        """Test Redis connection and return True if successful."""
        if not self._redis_available:
            logger.warning("Redis packages not available")
            return False

        try:
            redis_client = redis.Redis(**self._get_redis_client_kwargs())
            redis_client.ping()
            logger.info("Redis connection test successful")
            return True
        except Exception as e:
            logger.error(f"Redis connection test failed: {e}")
            return False


# Global memory manager instance
_memory_manager: Optional[RedisMemoryManager] = None


def get_memory_manager() -> RedisMemoryManager:
    """Get global Redis memory manager instance."""
    global _memory_manager
    if _memory_manager is None:
        _memory_manager = RedisMemoryManager()
    return _memory_manager


def get_checkpointer():
    """Get checkpointer for conversation history."""
    return get_memory_manager().get_checkpointer()
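A short sketch of the fallback behaviour of this manager, using only functions defined in this file; the import path assumes the package is importable as `service`, and the prints are illustrative:

from service.memory.redis_memory import get_memory_manager

manager = get_memory_manager()
if manager.test_connection():
    print("Redis reachable; RedisSaver with TTL will back the conversation history")
else:
    print("Redis unreachable; get_checkpointer() silently falls back to InMemorySaver")

checkpointer = manager.get_checkpointer()  # cached on the manager after the first call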
113
vw-agentic-rag/service/memory/store.py
Normal file
113
vw-agentic-rag/service/memory/store.py
Normal file
@@ -0,0 +1,113 @@
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import logging

from .postgresql_memory import get_memory_manager, get_checkpointer
from ..graph.state import TurnState, Message

logger = logging.getLogger(__name__)


class InMemoryStore:
    """Simple in-memory store with TTL for conversation history"""

    def __init__(self, ttl_days: float = 7.0):
        self.ttl_days = ttl_days
        self.store: Dict[str, Dict[str, Any]] = {}

    def _is_expired(self, timestamp: datetime) -> bool:
        """Check if a record has expired"""
        return datetime.now() - timestamp > timedelta(days=self.ttl_days)

    def _cleanup_expired(self) -> None:
        """Remove expired records"""
        expired_keys = []
        for session_id, data in self.store.items():
            if self._is_expired(data.get("last_updated", datetime.min)):
                expired_keys.append(session_id)

        for key in expired_keys:
            del self.store[key]
            logger.info(f"Cleaned up expired session: {key}")

    def get(self, session_id: str) -> Optional[TurnState]:
        """Get conversation state for a session"""
        self._cleanup_expired()

        if session_id not in self.store:
            return None

        data = self.store[session_id]
        if self._is_expired(data.get("last_updated", datetime.min)):
            del self.store[session_id]
            return None

        try:
            # Reconstruct TurnState from stored data
            state_data = data["state"]
            return TurnState(**state_data)
        except Exception as e:
            logger.error(f"Failed to deserialize state for session {session_id}: {e}")
            return None

    def put(self, session_id: str, state: TurnState) -> None:
        """Store conversation state for a session"""
        try:
            self.store[session_id] = {
                "state": state.model_dump(),
                "last_updated": datetime.now()
            }
            logger.debug(f"Stored state for session: {session_id}")
        except Exception as e:
            logger.error(f"Failed to store state for session {session_id}: {e}")

    def trim(self, session_id: str, max_messages: int = 20) -> None:
        """Trim old messages to stay within token limits"""
        state = self.get(session_id)
        if not state:
            return

        if len(state.messages) > max_messages:
            # Keep system message (if any) and recent user/assistant pairs
            trimmed_messages = state.messages[-max_messages:]

            # Try to preserve complete conversation turns
            if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
                trimmed_messages = trimmed_messages[1:]

            state.messages = trimmed_messages
            self.put(session_id, state)
            logger.info(f"Trimmed messages for session {session_id} to {len(trimmed_messages)}")

    def create_new_session(self, session_id: str) -> TurnState:
        """Create a new conversation session"""
        state = TurnState(session_id=session_id)
        self.put(session_id, state)
        return state

    def add_message(self, session_id: str, message: Message) -> None:
        """Add a message to the conversation history"""
        state = self.get(session_id)
        if not state:
            state = self.create_new_session(session_id)

        state.messages.append(message)
        self.put(session_id, state)

    def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
        """Get formatted conversation history for prompts"""
        state = self.get(session_id)
        if not state or not state.messages:
            return ""

        # Get recent messages, keeping complete turns
        recent_messages = state.messages[-(max_turns * 2):]

        history_parts = []
        for msg in recent_messages:
            if msg.role == "user":
                history_parts.append(f"User: {msg.content}")
            elif msg.role == "assistant" and not msg.tool_call_id:
                history_parts.append(f"Assistant: {msg.content}")

        return "\n".join(history_parts)
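A minimal round-trip sketch for the store above, assuming `Message` accepts `role` and `content` fields (its definition lives in graph/state.py, which is not part of this section); the sample question is illustrative:

from service.memory.store import InMemoryStore
from service.graph.state import Message

store = InMemoryStore(ttl_days=7.0)
store.add_message("session-123", Message(role="user", content="What does GB 7258 cover?"))
store.add_message("session-123", Message(role="assistant", content="It covers ..."))

print(store.get_conversation_history("session-123"))
# User: What does GB 7258 cover?
# Assistant: It covers ...
store.trim("session-123", max_messages=20)  # no-op until the history exceeds the cap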
1
vw-agentic-rag/service/retrieval/__init__.py
Normal file
1
vw-agentic-rag/service/retrieval/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make packages
181
vw-agentic-rag/service/retrieval/clients.py
Normal file
181
vw-agentic-rag/service/retrieval/clients.py
Normal file
@@ -0,0 +1,181 @@
"""
Azure AI Search client utilities for retrieval operations.
Contains shared functionality for interacting with Azure AI Search and embedding services.
"""

import httpx
import logging
from typing import Dict, Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from ..config import get_config

logger = logging.getLogger(__name__)


class RetrievalAPIError(Exception):
    """Custom exception for retrieval API errors"""
    pass


class AzureSearchClient:
    """Shared Azure AI Search client for embedding and search operations"""

    def __init__(self):
        self.config = get_config()
        self.search_endpoint = self.config.retrieval.endpoint
        self.api_key = self.config.retrieval.api_key
        self.api_version = self.config.retrieval.api_version
        self.semantic_configuration = self.config.retrieval.semantic_configuration
        self.embedding_client = httpx.AsyncClient(timeout=30.0)
        self.search_client = httpx.AsyncClient(timeout=30.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.embedding_client.aclose()
        await self.search_client.aclose()

    async def get_embedding(self, text: str) -> List[float]:
        """Get embedding vector for text using the configured embedding service"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.retrieval.embedding.api_key}"
        }

        payload = {
            "input": text,
            "model": self.config.retrieval.embedding.model
        }

        try:
            req_url = f"{self.config.retrieval.embedding.base_url}/embeddings"
            if self.config.retrieval.embedding.api_version:
                req_url += f"?api-version={self.config.retrieval.embedding.api_version}"

            response = await self.embedding_client.post(req_url, json=payload, headers=headers)
            response.raise_for_status()
            result = response.json()
            return result["data"][0]["embedding"]
        except Exception as e:
            logger.error(f"Failed to get embedding: {e}")
            raise RetrievalAPIError(f"Embedding generation failed: {str(e)}")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TimeoutException))
    )
    async def search_azure_ai(
        self,
        index_name: str,
        search_text: str,
        vector_fields: str,
        select_fields: str,
        search_fields: str,
        filter_query: Optional[str] = None,
        top_k: int = 10,
        score_threshold: float = 1.5
    ) -> Dict[str, Any]:
        """Make hybrid search request to Azure AI Search with semantic ranking"""

        # Get embedding vector for the query
        query_vector = await self.get_embedding(search_text)

        # Build vector queries based on the vector fields
        vector_queries = []
        for field in vector_fields.split(","):
            field = field.strip()
            vector_queries.append({
                "kind": "vector",
                "vector": query_vector,
                "fields": field,
                "k": top_k
            })

        # Build the search request payload
        search_payload = {
            "search": search_text,
            "select": select_fields,
            "searchFields": search_fields,
            "top": top_k,
            "queryType": "semantic",
            "semanticConfiguration": self.semantic_configuration,
            "vectorQueries": vector_queries
        }

        if filter_query:
            search_payload["filter"] = filter_query

        headers = {
            "Content-Type": "application/json",
            "api-key": self.api_key
        }

        search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search"

        try:
            response = await self.search_client.post(
                search_url,
                json=search_payload,
                headers=headers,
                params={"api-version": self.api_version}
            )
            response.raise_for_status()
            result = response.json()

            # Filter results by reranker score and add order numbers
            filtered_results = []
            for i, item in enumerate(result.get("value", [])):
                reranker_score = item.get("@search.rerankerScore", 0)
                if reranker_score >= score_threshold:
                    # Add order number
                    item["@order_num"] = i + 1
                    # Normalize the result (removes unwanted fields and empty values)
                    normalized_item = normalize_search_result(item)
                    filtered_results.append(normalized_item)

            return {"value": filtered_results}

        except httpx.HTTPStatusError as e:
            logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}")
            raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}")
        except httpx.TimeoutException:
            logger.error("Azure AI Search request timeout")
            raise RetrievalAPIError("Azure AI Search request timeout")
        except Exception as e:
            logger.error(f"Azure AI Search unexpected error: {e}")
            raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}")


def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize raw Azure AI Search result to clean dynamic structure

    Args:
        raw_result: Raw result from Azure AI Search

    Returns:
        Cleaned and normalized result dictionary
    """
    # Fields to remove if they exist (belt and suspenders approach)
    fields_to_remove = {
        "@search.score",
        "@search.rerankerScore",
        "@search.captions",
        "@subquery_id"
    }

    # Create a copy and remove unwanted fields
    result = raw_result.copy()
    for field in fields_to_remove:
        result.pop(field, None)

    # Remove empty fields (None, empty string, empty list, empty dict)
    result = {
        key: value for key, value in result.items()
        if value is not None and value != "" and value != [] and value != {}
    }

    return result
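A hedged usage sketch for the client above; the index and field names are placeholders, the import path assumes the package is importable as `service`, and the call runs inside the async context manager so both httpx clients are closed afterwards:

import asyncio
from service.retrieval.clients import AzureSearchClient

async def demo() -> None:
    async with AzureSearchClient() as client:
        data = await client.search_azure_ai(
            index_name="my-chunk-index",          # placeholder index name
            search_text="brake system requirements",
            vector_fields="contentVector",
            select_fields="content, title",
            search_fields="content, title",
            top_k=5,
            score_threshold=1.5,
        )
    for doc in data["value"]:
        print(doc.get("@order_num"), doc.get("title"))

asyncio.run(demo())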
58
vw-agentic-rag/service/retrieval/generic_chunk_retrieval.py
Normal file
58
vw-agentic-rag/service/retrieval/generic_chunk_retrieval.py
Normal file
@@ -0,0 +1,58 @@
import logging
import time

from ..config import get_config
from service.retrieval.clients import AzureSearchClient
from service.retrieval.model import RetrievalResponse

logger = logging.getLogger(__name__)


class GenericChunkRetrieval:
    def __init__(self) -> None:
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def retrieve_doc_chunk(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks"""
        start_time = time.time()

        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_user_manual_index
        vector_fields = "contentVector"
        select_fields = "content, title, full_headers"
        search_fields = "content, title, full_headers"

        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)

        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk_user_manual failed: {e}")
            raise
11
vw-agentic-rag/service/retrieval/model.py
Normal file
11
vw-agentic-rag/service/retrieval/model.py
Normal file
@@ -0,0 +1,11 @@
from typing import Any, Optional

from pydantic import BaseModel


class RetrievalResponse(BaseModel):
    """Simple response container for tool results"""
    results: list[dict[str, Any]]
    took_ms: Optional[int] = None
    total_count: Optional[int] = None
158
vw-agentic-rag/service/retrieval/retrieval.py
Normal file
158
vw-agentic-rag/service/retrieval/retrieval.py
Normal file
@@ -0,0 +1,158 @@
import httpx
import time
import json
from typing import Dict, Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import logging

from .model import RetrievalResponse
from ..config import get_config
from .clients import AzureSearchClient, RetrievalAPIError

logger = logging.getLogger(__name__)


class AgenticRetrieval:
    """Azure AI Search client for retrieval tools"""

    def __init__(self):
        self.config = get_config()
        self.search_client = AzureSearchClient()

    async def __aenter__(self):
        await self.search_client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.search_client.__aexit__(exc_type, exc_val, exc_tb)

    async def retrieve_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation attributes"""
        start_time = time.time()

        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.standard_regulation_index
        vector_fields = "full_metadata_vector"
        select_fields = "id, func_uuid, title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"
        search_fields = "title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"

        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)

        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_standard_regulation failed: {e}")
            raise

    async def retrieve_doc_chunk_standard_regulation(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search standard/regulation document chunks"""
        start_time = time.time()

        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_index
        vector_fields = "contentVector, full_metadata_vector"
        select_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type, id, metadata, func_uuid, filepath, x_Standard_Regulation_Id"
        search_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type"
        filter_query = "(document_category eq 'Standard' or document_category eq 'Regulation') and (status eq '已发布') and (x_Standard_Published_State_EN eq 'Effective' or x_Standard_Published_State_EN eq 'Publication' or x_Standard_Published_State_EN eq 'Implementation' or x_Regulation_Status_EN eq 'Publication' or x_Regulation_Status_EN eq 'Implementation') and (x_Attachment_Type eq '标准附件(PUBLISHED_STANDARDS)' or x_Attachment_Type eq '已发布法规附件(ISSUED_REGULATION)')"

        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)

        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                filter_query=filter_query,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk_standard_regulation failed: {e}")
            raise

    async def retrieve_doc_chunk_user_manual(
        self,
        query: str,
        conversation_history: str = "",
        **kwargs
    ) -> RetrievalResponse:
        """Search CATOnline system user manual document chunks"""
        start_time = time.time()

        # Use the new Azure AI Search approach
        index_name = self.config.retrieval.index.chunk_user_manual_index
        vector_fields = "contentVector"
        select_fields = "content, title, full_headers"
        search_fields = "content, title, full_headers"

        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 1.5)

        try:
            response_data = await self.search_client.search_azure_ai(
                index_name=index_name,
                search_text=query,
                vector_fields=vector_fields,
                select_fields=select_fields,
                search_fields=search_fields,
                top_k=top_k,
                score_threshold=score_threshold
            )

            results = response_data.get("value", [])

            took_ms = int((time.time() - start_time) * 1000)
            return RetrievalResponse(
                results=results,
                took_ms=took_ms,
                total_count=len(results)
            )
        except Exception as e:
            logger.error(f"retrieve_doc_chunk_user_manual failed: {e}")
            raise
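A brief sketch of how one of the retrieval methods above would typically be awaited; the query string is illustrative and the import path assumes the package is importable as `service`:

import asyncio
from service.retrieval.retrieval import AgenticRetrieval

async def demo() -> None:
    async with AgenticRetrieval() as retriever:
        resp = await retriever.retrieve_doc_chunk_standard_regulation(
            query="electric vehicle battery safety requirements",
            top_k=5,
        )
    print(f"{resp.total_count} chunks in {resp.took_ms} ms")

asyncio.run(demo())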
1
vw-agentic-rag/service/schemas/__init__.py
Normal file
1
vw-agentic-rag/service/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make packages
34
vw-agentic-rag/service/schemas/messages.py
Normal file
34
vw-agentic-rag/service/schemas/messages.py
Normal file
@@ -0,0 +1,34 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel


class UserMessage(BaseModel):
    content: str
    timestamp: Optional[str] = None


class AssistantMessage(BaseModel):
    content: str
    citations_mapping_csv: Optional[str] = None
    timestamp: Optional[str] = None


class ToolMessage(BaseModel):
    tool_name: str
    tool_call_id: str
    content: str  # Usually JSON string of results
    timestamp: Optional[str] = None


class ChatRequest(BaseModel):
    session_id: str
    messages: list[Dict[str, Any]]
    client_hints: Optional[Dict[str, Any]] = None


class ChatResponse(BaseModel):
    """Base response for non-streaming endpoints"""
    answer: str
    citations_mapping_csv: str
    tool_results: list[Dict[str, Any]]
    session_id: str
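An illustrative request body that would validate against ChatRequest; the message dictionaries are free-form here because the schema only requires a list of dicts, and the import path assumes the package is importable as `service`:

from service.schemas.messages import ChatRequest

request = ChatRequest(
    session_id="session-123",
    messages=[
        {"role": "user", "content": "Which regulations apply to headlamps?"},
    ],
    client_hints={"locale": "en-US"},  # optional
)
print(request.model_dump_json(indent=2))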
72
vw-agentic-rag/service/sse.py
Normal file
72
vw-agentic-rag/service/sse.py
Normal file
@@ -0,0 +1,72 @@
import json
from typing import AsyncGenerator, Dict, Any


def format_sse_event(event: str, data: Dict[str, Any]) -> str:
    """Format data as Server-Sent Events"""
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"


async def send_heartbeat() -> AsyncGenerator[str, None]:
    """Send periodic heartbeat to keep connection alive"""
    while True:
        yield format_sse_event("heartbeat", {"timestamp": "now"})
        # In practice, you'd use asyncio.sleep but this is for demo
        break


def create_token_event(delta: str, tool_call_id: str | None = None) -> str:
    """Create a token streaming event"""
    return format_sse_event("tokens", {
        "delta": delta,
        "tool_call_id": tool_call_id
    })


def create_tool_start_event(tool_id: str, name: str, args: Dict[str, Any]) -> str:
    """Create a tool start event"""
    return format_sse_event("tool_start", {
        "id": tool_id,
        "name": name,
        "args": args
    })


def create_tool_progress_event(tool_id: str, message: str) -> str:
    """Create a tool progress event"""
    return format_sse_event("tool_progress", {
        "id": tool_id,
        "message": message
    })


def create_tool_result_event(tool_id: str, name: str, results: list, took_ms: int) -> str:
    """Create a tool result event"""
    return format_sse_event("tool_result", {
        "id": tool_id,
        "name": name,
        "results": results,
        "took_ms": took_ms
    })


def create_tool_error_event(tool_id: str, name: str, error: str) -> str:
    """Create a tool error event"""
    return format_sse_event("tool_error", {
        "id": tool_id,
        "name": name,
        "error": error
    })


# def create_agent_done_event() -> str:
#     """Create agent completion event"""
#     return format_sse_event("agent_done", {"answer_done": True})


def create_error_event(error: str, details: Dict[str, Any] | None = None) -> str:
    """Create an error event"""
    event_data: Dict[str, Any] = {"error": error}
    if details:
        event_data["details"] = details
    return format_sse_event("error", event_data)
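A quick illustration of the wire format these helpers produce; the values are made up and the import path assumes the package is importable as `service`:

from service.sse import create_tool_start_event

print(create_tool_start_event("call_1", "retrieve_doc_chunk", {"query": "ISO 26262"}))
# event: tool_start
# data: {"id": "call_1", "name": "retrieve_doc_chunk", "args": {"query": "ISO 26262"}}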
1
vw-agentic-rag/service/utils/__init__.py
Normal file
1
vw-agentic-rag/service/utils/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py to make this a package
165
vw-agentic-rag/service/utils/error_handler.py
Normal file
165
vw-agentic-rag/service/utils/error_handler.py
Normal file
@@ -0,0 +1,165 @@
"""
DRY Error Handling and Logging Utilities
"""

import json
import logging
import traceback
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional, Callable
from functools import wraps

from ..sse import create_error_event, create_tool_error_event


class ErrorCode(Enum):
    """Error codes for different types of failures"""
    # Client errors (4xxx)
    INVALID_REQUEST = 4001
    MISSING_PARAMETERS = 4002
    INVALID_SESSION = 4003

    # Server errors (5xxx)
    LLM_ERROR = 5001
    TOOL_ERROR = 5002
    DATABASE_ERROR = 5003
    MEMORY_ERROR = 5004
    EXTERNAL_API_ERROR = 5005
    INTERNAL_ERROR = 5000


class ErrorCategory(Enum):
    """Error categories for better organization"""
    VALIDATION = "validation"
    LLM = "llm"
    TOOL = "tool"
    DATABASE = "database"
    MEMORY = "memory"
    EXTERNAL_API = "external_api"
    INTERNAL = "internal"


class StructuredLogger:
    """DRY structured logging with automatic error handling"""

    def __init__(self, name: str):
        self.logger = logging.getLogger(name)

    def error(self, msg: str, error: Optional[Exception] = None, category: ErrorCategory = ErrorCategory.INTERNAL,
              error_code: ErrorCode = ErrorCode.INTERNAL_ERROR, extra: Optional[Dict[str, Any]] = None):
        """Log structured error with stack trace"""
        data: Dict[str, Any] = {
            "message": msg,
            "category": category.value,
            "error_code": error_code.value,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

        if error:
            data.update({
                "error_type": type(error).__name__,
                "error_message": str(error),
                "stack_trace": traceback.format_exc()
            })

        if extra:
            data["extra"] = extra

        self.logger.error(json.dumps(data))

    def info(self, msg: str, extra: Optional[Dict[str, Any]] = None):
        """Log structured info"""
        data: Dict[str, Any] = {"message": msg, "timestamp": datetime.now(timezone.utc).isoformat()}
        if extra:
            data["extra"] = extra
        self.logger.info(json.dumps(data))

    def warning(self, msg: str, extra: Optional[Dict[str, Any]] = None):
        """Log structured warning"""
        data: Dict[str, Any] = {"message": msg, "timestamp": datetime.now(timezone.utc).isoformat()}
        if extra:
            data["extra"] = extra
        self.logger.warning(json.dumps(data))


def get_user_message(category: ErrorCategory) -> str:
    """Get user-friendly error messages in English"""
    messages = {
        ErrorCategory.VALIDATION: "Invalid request parameters. Please check your input.",
        ErrorCategory.LLM: "AI service is temporarily unavailable. Please try again later.",
        ErrorCategory.TOOL: "Tool execution failed. Please retry your request.",
        ErrorCategory.DATABASE: "Database service is temporarily unavailable.",
        ErrorCategory.MEMORY: "Session storage issue occurred. Please refresh the page.",
        ErrorCategory.EXTERNAL_API: "External service connection failed.",
        ErrorCategory.INTERNAL: "Internal server error. We are working to resolve this."
    }
    return messages.get(category, "Unknown error occurred. Please contact technical support.")


def handle_async_errors(category: ErrorCategory, error_code: ErrorCode,
                        stream_callback: Optional[Callable] = None, tool_id: Optional[str] = None):
    """DRY decorator for async error handling with streaming support"""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            logger = StructuredLogger(func.__module__)

            try:
                return await func(*args, **kwargs)
            except Exception as e:
                user_msg = get_user_message(category)

                logger.error(
                    f"Error in {func.__name__}: {str(e)}",
                    error=e,
                    category=category,
                    error_code=error_code,
                    extra={"function": func.__name__, "args_count": len(args)}
                )

                # Send error event if streaming
                if stream_callback:
                    if tool_id:
                        await stream_callback(create_tool_error_event(tool_id, func.__name__, user_msg))
                    else:
                        await stream_callback(create_error_event(user_msg))

                # Re-raise with user-friendly message for API responses
                raise Exception(user_msg) from e
        return wrapper
    return decorator


def handle_sync_errors(category: ErrorCategory, error_code: ErrorCode):
    """DRY decorator for sync error handling"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger = StructuredLogger(func.__module__)

            try:
                return func(*args, **kwargs)
            except Exception as e:
                logger.error(
                    f"Error in {func.__name__}: {str(e)}",
                    error=e,
                    category=category,
                    error_code=error_code,
                    extra={"function": func.__name__}
                )
                raise Exception(get_user_message(category)) from e
        return wrapper
    return decorator


def create_error_response(category: ErrorCategory, error_code: ErrorCode,
                          technical_msg: Optional[str] = None) -> Dict[str, Any]:
    """Create consistent error response format"""
    return {
        "user_message": get_user_message(category),
        "error_code": error_code.value,
        "category": category.value,
        "technical_message": technical_msg,
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
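A sketch of how the async decorator above would typically be applied; `fetch_chunks` is a made-up coroutine, not part of the service, and the import path assumes the package is importable as `service`:

from service.utils.error_handler import handle_async_errors, ErrorCategory, ErrorCode

@handle_async_errors(ErrorCategory.TOOL, ErrorCode.TOOL_ERROR)
async def fetch_chunks(query: str) -> list:
    # Hypothetical tool body; any exception raised here is logged with a stack
    # trace and re-raised as the user-friendly TOOL message.
    raise RuntimeError("search backend unreachable")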
94
vw-agentic-rag/service/utils/logging.py
Normal file
94
vw-agentic-rag/service/utils/logging.py
Normal file
@@ -0,0 +1,94 @@
import logging
import json
import time
from typing import Dict, Any, Optional
from datetime import datetime


def setup_logging(level: str = "INFO", format_type: str = "json") -> None:
    """Setup structured logging"""
    if format_type == "json":
        formatter = JsonFormatter()
    else:
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, level.upper()))
    root_logger.addHandler(handler)


class JsonFormatter(logging.Formatter):
    """JSON log formatter"""

    def format(self, record: logging.LogRecord) -> str:
        log_data = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Add extra fields
        if hasattr(record, "request_id"):
            log_data["request_id"] = getattr(record, "request_id")
        if hasattr(record, "session_id"):
            log_data["session_id"] = getattr(record, "session_id")
        if hasattr(record, "duration_ms"):
            log_data["duration_ms"] = getattr(record, "duration_ms")

        return json.dumps(log_data)


class Timer:
    """Simple timer context manager"""

    def __init__(self):
        self.start_time = None
        self.end_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()

    @property
    def elapsed_ms(self) -> int:
        if self.start_time and self.end_time:
            return int((self.end_time - self.start_time) * 1000)
        return 0


def redact_secrets(data: Dict[str, Any], secret_keys: list[str] | None = None) -> Dict[str, Any]:
    """Redact sensitive information from logs"""
    if secret_keys is None:
        secret_keys = ["api_key", "password", "token", "secret", "key"]

    redacted = {}
    for key, value in data.items():
        if any(secret in key.lower() for secret in secret_keys):
            redacted[key] = "***REDACTED***"
        elif isinstance(value, dict):
            redacted[key] = redact_secrets(value, secret_keys)
        else:
            redacted[key] = value

    return redacted


def generate_request_id() -> str:
    """Generate unique request ID"""
    return f"req_{int(time.time() * 1000)}_{hash(time.time()) % 10000:04d}"


def truncate_text(text: str, max_length: int = 1000, suffix: str = "...") -> str:
    """Truncate text to maximum length"""
    if len(text) <= max_length:
        return text
    return text[:max_length - len(suffix)] + suffix
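A small sketch combining the helpers above; the payload is illustrative and the import path assumes the package is importable as `service`:

from service.utils.logging import Timer, redact_secrets, setup_logging

setup_logging(level="DEBUG", format_type="json")

with Timer() as t:
    payload = redact_secrets({"api_key": "sk-123", "query": "GB 7258"})
print(payload)       # {'api_key': '***REDACTED***', 'query': 'GB 7258'}
print(t.elapsed_ms)  # elapsed wall-clock time in milliseconds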
51
vw-agentic-rag/service/utils/middleware.py
Normal file
51
vw-agentic-rag/service/utils/middleware.py
Normal file
@@ -0,0 +1,51 @@
"""
Lightweight Error Handling Middleware
"""

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from .error_handler import StructuredLogger, ErrorCategory, ErrorCode, create_error_response


class ErrorMiddleware(BaseHTTPMiddleware):
    """Concise error handling middleware following DRY principles"""

    def __init__(self, app):
        super().__init__(app)
        self.logger = StructuredLogger(__name__)

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except HTTPException as e:
            # HTTP exceptions - map to appropriate categories
            category = ErrorCategory.VALIDATION if e.status_code < 500 else ErrorCategory.INTERNAL
            error_code = ErrorCode.INVALID_REQUEST if e.status_code < 500 else ErrorCode.INTERNAL_ERROR

            self.logger.error(
                f"HTTP {e.status_code}: {e.detail}",
                category=category,
                error_code=error_code,
                extra={"path": str(request.url), "method": request.method}
            )

            return JSONResponse(
                status_code=e.status_code,
                content=create_error_response(category, error_code, e.detail)
            )
        except Exception as e:
            # Unexpected errors
            self.logger.error(
                f"Unhandled error: {str(e)}",
                error=e,
                category=ErrorCategory.INTERNAL,
                error_code=ErrorCode.INTERNAL_ERROR,
                extra={"path": str(request.url), "method": request.method}
            )

            return JSONResponse(
                status_code=500,
                content=create_error_response(ErrorCategory.INTERNAL, ErrorCode.INTERNAL_ERROR)
            )
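How the middleware would typically be registered on the FastAPI app; the app object below is a stand-in, since the service's real application factory is defined elsewhere, and the import path assumes the package is importable as `service`:

from fastapi import FastAPI
from service.utils.middleware import ErrorMiddleware

app = FastAPI()                      # stand-in; the real app is created elsewhere
app.add_middleware(ErrorMiddleware)  # unhandled errors now return the structured JSON shape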
103
vw-agentic-rag/service/utils/templates.py
Normal file
103
vw-agentic-rag/service/utils/templates.py
Normal file
@@ -0,0 +1,103 @@
"""
Template utilities for Jinja2 template rendering with LangChain
"""
import logging
from typing import Dict, Any
from jinja2 import Environment, BaseLoader, TemplateError

logger = logging.getLogger(__name__)


class TemplateRenderer:
    """Jinja2 template renderer for LLM prompts"""

    def __init__(self):
        self.env = Environment(
            loader=BaseLoader(),
            # No HTML auto-escaping; templates render plain prompt text
            autoescape=False,
            # Delimiters spelled out explicitly (Jinja2 defaults)
            variable_start_string='{{',
            variable_end_string='}}',
            block_start_string='{%',
            block_end_string='%}',
            comment_start_string='{#',
            comment_end_string='#}',
            # Keep linebreaks
            keep_trailing_newline=True,
            # Remove unnecessary whitespace
            trim_blocks=True,
            lstrip_blocks=True
        )

    def render_template(self, template_string: str, variables: Dict[str, Any]) -> str:
        """
        Render a Jinja2 template string with provided variables

        Args:
            template_string: The template string with Jinja2 syntax
            variables: Dictionary of variables to substitute

        Returns:
            Rendered template string

        Raises:
            TemplateError: If template rendering fails
        """
        try:
            template = self.env.from_string(template_string)
            rendered = template.render(**variables)
            logger.debug(f"Template rendered successfully with variables: {list(variables.keys())}")
            return rendered
        except TemplateError as e:
            logger.error(f"Template rendering failed: {e}")
            logger.error(f"Template: {template_string[:200]}...")
            logger.error(f"Variables: {variables}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error during template rendering: {e}")
            raise TemplateError(f"Template rendering failed: {e}")

    def render_system_prompt(self, template_string: str, variables: Dict[str, Any]) -> str:
        """
        Render system prompt template

        Args:
            template_string: System prompt template
            variables: Variables for substitution

        Returns:
            Rendered system prompt
        """
        return self.render_template(template_string, variables)

    def render_user_prompt(self, template_string: str, variables: Dict[str, Any]) -> str:
        """
        Render user prompt template

        Args:
            template_string: User prompt template
            variables: Variables for substitution

        Returns:
            Rendered user prompt
        """
        return self.render_template(template_string, variables)


# Global template renderer instance
template_renderer = TemplateRenderer()


def render_prompt_template(template_string: str, variables: Dict[str, Any]) -> str:
    """
    Convenience function to render prompt templates

    Args:
        template_string: Template string with Jinja2 syntax
        variables: Dictionary of variables to substitute

    Returns:
        Rendered template string
    """
    return template_renderer.render_template(template_string, variables)
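A minimal rendering example for the convenience function above; the template text and values are made up, and the import path assumes the package is importable as `service`:

from service.utils.templates import render_prompt_template

prompt = render_prompt_template(
    "You are a {{ role }}. Today is {{ date }}.",
    {"role": "regulation assistant", "date": "2024-01-01"},
)
print(prompt)  # You are a regulation assistant. Today is 2024-01-01.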