commit db0e5965ec
2025-09-26 17:15:54 +08:00
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,146 @@
"""
AI SDK Data Stream Protocol adapter
Converts our internal SSE events to AI SDK compatible format
Following the official Data Stream Protocol: TYPE_ID:CONTENT_JSON\n
"""
import json
import uuid
from typing import Dict, Any, AsyncGenerator
def format_data_stream_part(type_id: str, content: Any) -> str:
"""Format data as AI SDK Data Stream Protocol part: TYPE_ID:JSON\n"""
content_json = json.dumps(content, ensure_ascii=False)
return f"{type_id}:{content_json}\n"
def create_text_part(text: str) -> str:
"""Create text part (type 0)"""
return format_data_stream_part("0", text)
def create_data_part(data: list) -> str:
"""Create data part (type 2) for additional data"""
return format_data_stream_part("2", data)
def create_error_part(error: str) -> str:
"""Create error part (type 3)"""
return format_data_stream_part("3", error)
def create_tool_call_part(tool_call_id: str, tool_name: str, args: dict) -> str:
"""Create tool call part (type 9)"""
return format_data_stream_part("9", {
"toolCallId": tool_call_id,
"toolName": tool_name,
"args": args
})
def create_tool_result_part(tool_call_id: str, result: Any) -> str:
"""Create tool result part (type a)"""
return format_data_stream_part("a", {
"toolCallId": tool_call_id,
"result": result
})
def create_finish_step_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None, is_continued: bool = False) -> str:
"""Create finish step part (type e)"""
usage = usage or {"promptTokens": 0, "completionTokens": 0}
return format_data_stream_part("e", {
"finishReason": finish_reason,
"usage": usage,
"isContinued": is_continued
})
def create_finish_message_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None) -> str:
"""Create finish message part (type d)"""
usage = usage or {"promptTokens": 0, "completionTokens": 0}
return format_data_stream_part("d", {
"finishReason": finish_reason,
"usage": usage
})
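# Illustrative wire format (a sketch, not part of the module): each helper above emits
# one "TYPE_ID:JSON\n" line of the Data Stream Protocol. For example:
#   create_text_part("Hello")   -> '0:"Hello"\n'
#   create_error_part("boom")   -> '3:"boom"\n'
#   create_tool_call_part("call_1", "search", {"q": "GB 1589"})
#       -> '9:{"toolCallId": "call_1", "toolName": "search", "args": {"q": "GB 1589"}}\n'
# (the tool name "search" and the ids shown are hypothetical values.)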
class AISDKEventAdapter:
"""Adapter to convert our internal events to AI SDK Data Stream Protocol format"""
def __init__(self):
self.tool_calls = {} # Track tool calls
self.current_message_id = str(uuid.uuid4())
def convert_event(self, event_line: str) -> str | None:
"""Convert our SSE event to AI SDK Data Stream Protocol format"""
if not event_line.strip():
return None
try:
# Handle multi-line SSE format
lines = event_line.strip().split('\n')
event_type = None
data = None
for line in lines:
if line.startswith("event: "):
event_type = line.replace("event: ", "")
elif line.startswith("data: "):
data_str = line[6:] # Remove "data: "
if data_str:
data = json.loads(data_str)
if event_type and data:
return self._convert_by_type(event_type, data)
except (json.JSONDecodeError, IndexError, KeyError) as e:
# Skip malformed events
return None
return None
def _convert_by_type(self, event_type: str, data: Dict[str, Any]) -> str | None:
"""Convert event by type to Data Stream Protocol format"""
if event_type == "tokens":
# Token streaming -> text part (type 0)
delta = data.get("delta", "")
if delta:
return create_text_part(delta)
elif event_type == "tool_start":
# Tool start -> tool call part (type 9)
tool_id = data.get("id", str(uuid.uuid4()))
tool_name = data.get("name", "unknown")
args = data.get("args", {})
self.tool_calls[tool_id] = {"name": tool_name, "args": args}
return create_tool_call_part(tool_id, tool_name, args)
elif event_type == "tool_result":
# Tool result -> tool result part (type a)
tool_id = data.get("id", "")
results = data.get("results", [])
return create_tool_result_part(tool_id, results)
elif event_type == "tool_error":
# Tool error -> error part (type 3)
error = data.get("error", "Tool execution failed")
return create_error_part(error)
elif event_type == "error":
# Error -> error part (type 3)
error = data.get("error", "Unknown error")
return create_error_part(error)
return None
async def stream_ai_sdk_compatible(internal_stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
"""Convert our internal SSE stream to AI SDK Data Stream Protocol compatible format"""
adapter = AISDKEventAdapter()
async for event in internal_stream:
converted = adapter.convert_event(event)
if converted:
yield converted
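# Illustrative usage (a sketch; the "tokens" payload below is hypothetical): converting
# one internal SSE event into a Data Stream Protocol part.
#   adapter = AISDKEventAdapter()
#   part = adapter.convert_event('event: tokens\ndata: {"delta": "Hello"}\n')
#   # part == '0:"Hello"\n'
# Events that cannot be parsed, or that carry no recognised type, yield None and are
# silently dropped by stream_ai_sdk_compatible().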

View File

@@ -0,0 +1,121 @@
"""
AI SDK compatible chat endpoint
"""
import asyncio
import logging
from typing import AsyncGenerator
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.messages import HumanMessage
from .config import get_config
from .graph.state import TurnState, Message
from .schemas.messages import ChatRequest
from .ai_sdk_adapter import stream_ai_sdk_compatible
from .sse import create_error_event
logger = logging.getLogger(__name__)
async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse:
"""Handle chat request with AI SDK Data Stream Protocol"""
async def ai_sdk_stream() -> AsyncGenerator[str, None]:
try:
app_config = get_config()
memory_manager = app_state.memory_manager
graph = app_state.graph
# Prepare the new user message for LangGraph (session memory handled automatically)
graph_config = {
"configurable": {
"thread_id": request.session_id
}
}
# Get the latest user message from AI SDK format
new_user_message = None
if request.messages:
last_message = request.messages[-1]
if last_message.get("role") == "user":
new_user_message = HumanMessage(content=last_message.get("content", ""))
if not new_user_message:
logger.error("No user message found in request")
yield create_error_event("No user message provided")
return
# Create event queue for internal streaming
event_queue = asyncio.Queue()
async def stream_callback(event_str: str):
await event_queue.put(event_str)
async def run_workflow():
try:
# Set stream callback in context for the workflow
from .graph.graph import stream_callback_context
stream_callback_context.set(stream_callback)
# Create TurnState with the new user message
# AgenticWorkflow will handle LangGraph interaction and session history
from .graph.state import TurnState, Message
turn_state = TurnState(
messages=[Message(
role="user",
content=str(new_user_message.content),
timestamp=None
)],
session_id=request.session_id,
tool_results=[],
final_answer=""
)
# Use AgenticWorkflow.astream with stream_callback parameter
async for final_state in graph.astream(turn_state, stream_callback=stream_callback):
# The workflow handles all streaming internally via stream_callback
pass # final_state contains the complete result
await event_queue.put(None) # Signal completion
except Exception as e:
logger.error(f"Workflow execution error: {e}", exc_info=True)
await event_queue.put(create_error_event(f"Processing error: {str(e)}"))
await event_queue.put(None)
# Start workflow task
workflow_task = asyncio.create_task(run_workflow())
# Convert internal events to AI SDK format
async def internal_stream():
try:
while True:
event = await event_queue.get()
if event is None:
break
yield event
finally:
if not workflow_task.done():
workflow_task.cancel()
# Stream converted events
async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()):
yield ai_sdk_event
except Exception as e:
logger.error(f"AI SDK chat error: {e}")
# Send error in AI SDK format
from .ai_sdk_adapter import create_error_part
yield create_error_part(f"Server error: {str(e)}")
return StreamingResponse(
ai_sdk_stream(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
"x-vercel-ai-data-stream": "v1", # AI SDK Data Stream Protocol header
}
)
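# Minimal wiring sketch (assumptions: a FastAPI app object and a /api/chat route, neither
# of which is defined in this file; app_state is taken from app.state):
#   from fastapi import FastAPI, Request
#   app = FastAPI()
#
#   @app.post("/api/chat")
#   async def chat(body: ChatRequest, http_request: Request):
#       return await handle_ai_sdk_chat(body, http_request.app.state)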

View File

@@ -0,0 +1,297 @@
import yaml
import os
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class OpenAIConfig(BaseModel):
base_url: str = "https://api.openai.com/v1"
api_key: str
model: str = "gpt-4o"
class AzureConfig(BaseModel):
base_url: str
api_key: str
deployment: str
api_version: str = "2024-02-01"
class EmbeddingConfig(BaseModel):
base_url: str
api_key: str
model: str
dimension: int
api_version: Optional[str] = None
class IndexConfig(BaseModel):
standard_regulation_index: str
chunk_index: str
chunk_user_manual_index: str
class RetrievalConfig(BaseModel):
endpoint: str
api_key: str
api_version: str
semantic_configuration: str
embedding: EmbeddingConfig
index: IndexConfig
class PostgreSQLConfig(BaseModel):
host: str
port: int = 5432
database: str
username: str
password: str
ttl_days: int = 7
class RedisConfig(BaseModel):
host: str
port: int = 6379
password: str
use_ssl: bool = True
db: int = 0
ttl_days: int = 7
class AppLoggingConfig(BaseModel):
level: str = "INFO"
class AppConfig(BaseModel):
name: str = "agentic-rag"
memory_ttl_days: int = 7
max_tool_rounds: int = 3 # Maximum allowed tool calling rounds
max_tool_rounds_user_manual: int = 3 # Maximum allowed tool calling rounds for user manual agent
cors_origins: list[str] = Field(default_factory=lambda: ["*"])
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
# Service configuration
host: str = "0.0.0.0"
port: int = 8000
class SearchConfig(BaseModel):
"""Search index configuration"""
standard_regulation_index: str = ""
chunk_index: str = ""
chunk_user_manual_index: str = ""
class CitationConfig(BaseModel):
"""Citation link configuration"""
base_url: str = "" # Default empty string
class LLMParametersConfig(BaseModel):
"""LLM parameters configuration"""
temperature: Optional[float] = None
max_context_length: int = 96000 # Maximum context length for conversation history (in tokens)
max_output_tokens: Optional[int] = None # Optional limit for LLM output tokens (None = no limit)
class LLMPromptsConfig(BaseModel):
"""LLM prompts configuration"""
agent_system_prompt: str
synthesis_system_prompt: Optional[str] = None
synthesis_user_prompt: Optional[str] = None
intent_recognition_prompt: Optional[str] = None
user_manual_prompt: Optional[str] = None
class LLMPromptConfig(BaseModel):
"""LLM prompt configuration from llm_prompt.yaml"""
parameters: LLMParametersConfig = Field(default_factory=LLMParametersConfig)
prompts: LLMPromptsConfig
class LLMRagConfig(BaseModel):
"""Legacy LLM RAG configuration for backward compatibility"""
temperature: Optional[float] = None
max_context_length: int = 96000 # Maximum context length for conversation history (in tokens)
max_output_tokens: Optional[int] = None # Optional limit for LLM output tokens (None = no limit)
# Legacy prompts for backward compatibility
system_prompt: Optional[str] = None
user_prompt: Optional[str] = None
# New autonomous agent prompts
agent_system_prompt: Optional[str] = None
synthesis_system_prompt: Optional[str] = None
synthesis_user_prompt: Optional[str] = None
class LLMConfig(BaseModel):
rag: LLMRagConfig
class LoggingConfig(BaseModel):
level: str = "INFO"
format: str = "json"
class Config(BaseSettings):
provider: str = "openai"
openai: Optional[OpenAIConfig] = None
azure: Optional[AzureConfig] = None
retrieval: RetrievalConfig
postgresql: PostgreSQLConfig
redis: Optional[RedisConfig] = None
app: AppConfig = Field(default_factory=AppConfig)
search: SearchConfig = Field(default_factory=SearchConfig)
citation: CitationConfig = Field(default_factory=CitationConfig)
llm: Optional[LLMConfig] = None
logging: LoggingConfig = Field(default_factory=LoggingConfig)
# New LLM prompt configuration
llm_prompt: Optional[LLMPromptConfig] = None
@classmethod
def from_yaml(cls, config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> "Config":
"""Load configuration from YAML files with environment variable substitution"""
# Load main config
with open(config_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
# Substitute environment variables
yaml_data = cls._substitute_env_vars(yaml_data)
# Load LLM prompt config if exists
llm_prompt_data = None
if os.path.exists(llm_prompt_path):
with open(llm_prompt_path, 'r', encoding='utf-8') as f:
llm_prompt_data = yaml.safe_load(f)
llm_prompt_data = cls._substitute_env_vars(llm_prompt_data)
yaml_data['llm_prompt'] = llm_prompt_data
return cls(**yaml_data)
@classmethod
def _substitute_env_vars(cls, data: Any) -> Any:
"""Recursively substitute ${VAR} and ${VAR:-default} patterns with environment variables"""
if isinstance(data, dict):
return {k: cls._substitute_env_vars(v) for k, v in data.items()}
elif isinstance(data, list):
return [cls._substitute_env_vars(item) for item in data]
elif isinstance(data, str):
# Handle ${VAR:-default} pattern
if data.startswith("${") and data.endswith("}"):
env_spec = data[2:-1]
if ":-" in env_spec:
var_name, default_value = env_spec.split(":-", 1)
return os.getenv(var_name, default_value)
else:
return os.getenv(env_spec, data) # Return original if env var not found
return data
else:
return data
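# Illustrative substitution behaviour (a sketch; the env var names are hypothetical):
#   os.environ["OPENAI_API_KEY"] = "sk-test"          # OPENAI_MODEL left unset
#   Config._substitute_env_vars({"api_key": "${OPENAI_API_KEY}",
#                                "model": "${OPENAI_MODEL:-gpt-4o}"})
#   -> {"api_key": "sk-test", "model": "gpt-4o"}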
def get_llm_config(self) -> Dict[str, Any]:
"""Get LLM configuration based on provider"""
base_config = {}
# Get temperature and max_output_tokens from llm_prompt config first, fallback to legacy llm.rag config
if self.llm_prompt and self.llm_prompt.parameters:
# Only add temperature if explicitly set (not None)
if self.llm_prompt.parameters.temperature is not None:
base_config["temperature"] = self.llm_prompt.parameters.temperature
# Only add max_output_tokens if explicitly set (not None)
if self.llm_prompt.parameters.max_output_tokens is not None:
base_config["max_tokens"] = self.llm_prompt.parameters.max_output_tokens
elif self.llm and self.llm.rag:
# Only add temperature if explicitly set (not None)
if hasattr(self.llm.rag, 'temperature') and self.llm.rag.temperature is not None:
base_config["temperature"] = self.llm.rag.temperature
# Only add max_output_tokens if explicitly set (not None)
if self.llm.rag.max_output_tokens is not None:
base_config["max_tokens"] = self.llm.rag.max_output_tokens
if self.provider == "openai" and self.openai:
return {
**base_config,
"provider": "openai",
"base_url": self.openai.base_url,
"api_key": self.openai.api_key,
"model": self.openai.model,
}
elif self.provider == "azure" and self.azure:
return {
**base_config,
"provider": "azure",
"base_url": self.azure.base_url,
"api_key": self.azure.api_key,
"deployment": self.azure.deployment,
"api_version": self.azure.api_version,
}
else:
raise ValueError(f"Invalid provider '{self.provider}' or missing configuration")
def get_rag_prompts(self) -> Dict[str, str]:
"""Get RAG prompts configuration - prioritize llm_prompt.yaml over legacy config"""
# Use new llm_prompt config if available
if self.llm_prompt and self.llm_prompt.prompts:
return {
"system_prompt": self.llm_prompt.prompts.agent_system_prompt,
"user_prompt": "{{user_query}}", # Default template
"agent_system_prompt": self.llm_prompt.prompts.agent_system_prompt,
"synthesis_system_prompt": self.llm_prompt.prompts.synthesis_system_prompt or "You are a helpful assistant.",
"synthesis_user_prompt": self.llm_prompt.prompts.synthesis_user_prompt or "{{user_query}}",
"intent_recognition_prompt": self.llm_prompt.prompts.intent_recognition_prompt or "",
"user_manual_prompt": self.llm_prompt.prompts.user_manual_prompt or "",
}
# Fallback to legacy llm.rag config
if self.llm and self.llm.rag:
return {
"system_prompt": self.llm.rag.system_prompt or "You are a helpful assistant.",
"user_prompt": self.llm.rag.user_prompt or "{{user_query}}",
"agent_system_prompt": self.llm.rag.agent_system_prompt or "You are a helpful assistant.",
"synthesis_system_prompt": self.llm.rag.synthesis_system_prompt or "You are a helpful assistant.",
"synthesis_user_prompt": self.llm.rag.synthesis_user_prompt or "{{user_query}}",
"intent_recognition_prompt": "",
"user_manual_prompt": "",
}
# Default fallback
return {
"system_prompt": "You are a helpful assistant.",
"user_prompt": "{{user_query}}",
"agent_system_prompt": "You are a helpful assistant.",
"synthesis_system_prompt": "You are a helpful assistant.",
"synthesis_user_prompt": "{{user_query}}",
"intent_recognition_prompt": "",
"user_manual_prompt": "",
}
def get_max_context_length(self) -> int:
"""Get maximum context length for conversation history"""
# Use new llm_prompt config if available
if self.llm_prompt and self.llm_prompt.parameters:
return self.llm_prompt.parameters.max_context_length
# Fallback to legacy llm.rag config
if self.llm and self.llm.rag:
return self.llm.rag.max_context_length
# Default fallback
return 96000
# Global config instance
config: Optional[Config] = None
def load_config(config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> Config:
"""Load and return the global configuration"""
global config
config = Config.from_yaml(config_path, llm_prompt_path)
return config
def get_config() -> Config:
"""Get the current configuration instance"""
if config is None:
raise RuntimeError("Configuration not loaded. Call load_config() first.")
return config
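# Typical startup usage (a sketch; called once from the application entrypoint, whose
# module is not shown here and whose import path is an assumption):
#   from app.config import load_config, get_config
#   load_config("config.yaml", "llm_prompt.yaml")
#   llm_settings = get_config().get_llm_config()      # provider-specific dict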

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,746 @@
import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar
from pydantic import BaseModel
from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig
from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
StructuredLogger, ErrorCategory, ErrorCode,
handle_async_errors, get_user_message
)
from ..sse import (
create_tool_start_event,
create_tool_result_event,
create_tool_error_event,
create_token_event,
create_error_event
)
logger = StructuredLogger(__name__)
# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None
def get_cached_config():
"""Get cached configuration, loading it if not already cached"""
global _cached_config
if _cached_config is None:
_cached_config = get_config()
return _cached_config
# Context variable for streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
# Agent node (autonomous function calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Agent node that autonomously uses tools and generates final answer.
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
1. Always start with non-streaming detection to check for tool needs
2. If tool_calls exist → return immediately for routing to tools
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
"""
app_config = get_cached_config()
llm_client = LLMClient()
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
# Get tool schemas and bind tools for planning phase
tool_schemas = get_tool_schemas()
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
# Create conversation trimmer for managing context length
trimmer = create_conversation_trimmer()
# Prepare messages with system prompt
messages = state["messages"].copy()
if not messages or not isinstance(messages[0], SystemMessage):
rag_prompts = app_config.get_rag_prompts()
system_prompt = rag_prompts.get("agent_system_prompt", "")
if not system_prompt:
raise ValueError("system_prompt is null")
messages = [SystemMessage(content=system_prompt)] + messages
# Track tool rounds
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds", None)
if max_rounds is None:
max_rounds = app_config.app.max_tool_rounds
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
# This prevents trimming current turn's tool results during multi-round tool calling
if current_round == 0:
# Trim conversation history to manage context length (only for previous conversation turns)
if trimmer.should_trim(messages):
messages = trimmer.trim_conversation_history(messages)
logger.info("Applied conversation history trimming for context management (new conversation turn)")
else:
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
# Check if this should be final synthesis (max rounds reached)
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
is_final_synthesis = has_tool_messages and current_round >= max_rounds
if is_final_synthesis:
logger.info("Starting final synthesis phase - no more tool calls allowed")
# ✅ STEP 1: Final synthesis with tools disabled from the start
# Disable tools to prevent any tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools during synthesis
if not stream_callback:
# No streaming callback, generate final response without tools
draft = await llm_client.ainvoke(list(messages))
return {"messages": [draft]}
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 3: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
else:
logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")
# ✅ STEP 1: Non-streaming detection to check for tool needs
draft = await llm_client.ainvoke_with_tools(list(messages))
# ✅ STEP 2: If draft has tool_calls, return immediately (let routing handle it)
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
# Increment tool round counter for next iteration
new_tool_rounds = current_round + 1
logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
return {"messages": [draft], "tool_rounds": new_tool_rounds}
# ✅ STEP 3: No tool_calls needed → Enter final synthesis with streaming
# Temporarily disable tools to prevent accidental tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
if not stream_callback:
# No streaming callback, use the draft we already have
return {"messages": [draft]}
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 5: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
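# Illustrative behaviour of the HTML-comment filtering above (a sketch with a
# hypothetical token sequence): only content outside <!-- ... --> reaches the client.
#   tokens arrive as: "Hello ", "<!-- citations_map 1,call_1,0 --", ">world"
#   forwarded via stream_callback: "Hello ", then "world"
#   (the citations_map comment itself is withheld and later consumed by post_process_node)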
# Tools routing condition (simplified for "detect-first-then-stream" strategy)
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
"""
Simplified routing logic for "detect-first-then-stream" strategy:
- has tool_calls → route to tools
- no tool_calls → route to post_process (final synthesis already completed)
"""
messages = state["messages"]
if not messages:
logger.info("should_continue: No messages, routing to post_process")
return "post_process"
last_message = messages[-1]
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds", None)
if max_rounds is None:
app_config = get_cached_config()
max_rounds = app_config.app.max_tool_rounds
logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
# If last message is AI message with tool calls, route to tools
if isinstance(last_message, AIMessage):
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")
if has_tool_calls:
logger.info("should_continue: Routing to tools")
return "tools"
else:
# No tool calls = final synthesis already completed in call_model
logger.info("should_continue: No tool calls, routing to post_process")
return "post_process"
# If last message is tool message(s), continue with agent for next round or final synthesis
if isinstance(last_message, ToolMessage):
logger.info("should_continue: Tool message completed, continuing to agent")
return "agent"
logger.info("should_continue: Routing to post_process")
return "post_process"
# Custom tool node with streaming support and parallel execution
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""Execute tools with streaming events - supports parallel execution"""
messages = state["messages"]
last_message = messages[-1]
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
return {"messages": []}
tool_calls = last_message.tool_calls or []
tool_results = []
new_messages = []
# Tools mapping
tools_map = get_tools_by_name()
async def execute_single_tool(tool_call):
"""Execute a single tool call with enhanced error handling"""
# Get stream callback from context
stream_callback = stream_callback_context.get()
# Apply error handling decorator
@handle_async_errors(
ErrorCategory.TOOL,
ErrorCode.TOOL_ERROR,
stream_callback,
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
)
async def _execute():
# Validate tool_call format
if not isinstance(tool_call, dict):
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
tool_name = tool_call.get("name")
tool_args = tool_call.get("args", {})
tool_id = tool_call.get("id", "unknown")
if not tool_name or tool_name not in tools_map:
raise ValueError(f"Tool '{tool_name}' not found")
logger.info(f"Executing tool: {tool_name}", extra={
"tool_id": tool_id, "tool_name": tool_name
})
# Send start event
if stream_callback:
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
# Execute tool
import time
start_time = time.time()
result = await tools_map[tool_name].ainvoke(tool_args)
execution_time = int((time.time() - start_time) * 1000)
# Process result
if isinstance(result, dict):
result["tool_call_id"] = tool_id
if "results" in result and isinstance(result["results"], list):
for i, search_result in enumerate(result["results"]):
if isinstance(search_result, dict):
search_result["@tool_call_id"] = tool_id
search_result["@order_num"] = i
# Send result event
if stream_callback:
await stream_callback(create_tool_result_event(
tool_id, tool_name, result.get("results", []), execution_time
))
# Create tool message
tool_message = ToolMessage(
content=json.dumps(result, ensure_ascii=False),
tool_call_id=tool_id,
name=tool_name
)
return {
"message": tool_message,
"results": result.get("results", []) if isinstance(result, dict) else [],
"success": True
}
try:
return await _execute()
except Exception as e:
# Handle any errors not caught by decorator
tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"
error_message = ToolMessage(
content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
tool_call_id=tool_id,
name=tool_name
)
return {
"message": error_message,
"results": [],
"success": False
}
# Execute all tool calls in parallel using asyncio.gather
if tool_calls:
logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
tool_execution_results = await asyncio.gather(
*[execute_single_tool(tool_call) for tool_call in tool_calls],
return_exceptions=True
)
# Process results
for execution_result in tool_execution_results:
if execution_result is None:
continue
if isinstance(execution_result, Exception):
logger.error(f"Tool execution exception: {execution_result}")
continue
if not isinstance(execution_result, dict):
logger.error(f"Unexpected execution result type: {type(execution_result)}")
continue
new_messages.append(execution_result["message"])
if execution_result["success"] and execution_result["results"]:
tool_results.extend(execution_result["results"])
logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")
return {
"messages": new_messages,
"tool_results": tool_results
}
# Helper functions for citation processing
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
"""Extract citations mapping CSV from agent response HTML comment"""
try:
# Look for citations_map comment
pattern = r'<!-- citations_map\s*(.*?)\s*-->'
match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
if not match:
logger.warning("No citations_map comment found in agent response")
return {}
csv_content = match.group(1).strip()
citations_mapping = {}
for line in csv_content.split('\n'):
line = line.strip()
if not line:
continue
parts = line.split(',')
if len(parts) >= 3:
try:
citation_num = int(parts[0])
tool_call_id = parts[1].strip()
order_num = int(parts[2])
citations_mapping[citation_num] = {
'tool_call_id': tool_call_id,
'order_num': order_num
}
except (ValueError, IndexError) as e:
logger.warning(f"Failed to parse citation line: {line}, error: {e}")
continue
return citations_mapping
except Exception as e:
logger.error(f"Error extracting citations mapping: {e}")
return {}
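# Expected agent output format (hypothetical ids): a trailing HTML comment holding a
# "citation_number,tool_call_id,order_num" CSV, e.g.
#   <!-- citations_map
#   1,call_abc123,0
#   2,call_abc123,2
#   -->
# which _extract_citations_mapping() turns into
#   {1: {"tool_call_id": "call_abc123", "order_num": 0},
#    2: {"tool_call_id": "call_abc123", "order_num": 2}}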
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
"""Build citation markdown based on mapping and tool results, following build_citations.py logic"""
if not citations_mapping:
return ""
# Get configuration for citation base URL
config = get_cached_config()
cat_base_url = config.citation.base_url
# Collect citation lines first; only emit header if we have at least one valid citation
entries: List[str] = []
for citation_num in sorted(citations_mapping.keys()):
mapping = citations_mapping[citation_num]
tool_call_id = mapping['tool_call_id']
order_num = mapping['order_num']
# Find the corresponding tool result
result = _find_tool_result(tool_results, tool_call_id, order_num)
if not result:
logger.warning(f"No tool result found for citation [{citation_num}]")
continue
# Extract citation information following build_citations.py logic
full_headers = result.get('full_headers', '')
lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
header_display = f": {lowest_header}" if lowest_header else ""
document_code = result.get('document_code', '')
document_category = result.get('document_category', '')
# Determine standard/regulation title (assuming English language)
standard_regulation_title = ''
if document_category == 'Standard':
standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
elif document_category == 'Regulation':
standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')
# Build link
func_uuid = result.get('func_uuid', '')
uuid = result.get('x_Standard_Regulation_Id', '')
document_code_encoded = quote(document_code, safe='') if document_code else ''
standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'
# Format citation line
title = result.get('title', '')
entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")
# If no valid citations were found, do not include the header
if not entries:
return ""
# Build citations section with entries separated by a blank line (matching previous formatting)
md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
return md
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
"""Find tool result by tool_call_id and order_num"""
matching_results = []
for result in tool_results:
if result.get('@tool_call_id') == tool_call_id:
matching_results.append(result)
# Sort by order and return the one at the specified position
if matching_results and 0 <= order_num < len(matching_results):
# If results have @order_num, use it; otherwise use position in list
if '@order_num' in matching_results[0]:
for result in matching_results:
if result.get('@order_num') == order_num:
return result
else:
return matching_results[order_num]
return None
def _remove_citations_comment(agent_response: str) -> str:
"""Remove citations mapping HTML comment from agent response"""
pattern = r'<!-- citations_map\s*.*?\s*-->'
return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()
# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Post-processing node that builds citation list and links based on agent's citations mapping
and tool call results, following the logic from build_citations.py
"""
# Initialize before the try block so the except handler below can reference them safely
agent_response = ""
citations_mapping = {}
try:
logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
# Get the last AI message (agent's response with citations mapping)
for message in reversed(state["messages"]):
if isinstance(message, AIMessage) and message.content:
# Ensure content is a string
if isinstance(message.content, str):
agent_response = message.content
break
if not agent_response:
logger.warning("POST_PROCESS_NODE: No agent response found")
return {"messages": [], "final_answer": ""}
# Extract citations mapping from agent response
citations_mapping = _extract_citations_mapping(agent_response)
logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")
# Build citation markdown
citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])
# Combine agent response (without HTML comment) with citations
clean_response = _remove_citations_comment(agent_response)
final_content = clean_response + citation_markdown
logger.info("POST_PROCESS_NODE: Built complete response with citations")
# Send citation markdown as a single block instead of streaming
stream_callback = stream_callback_context.get()
if stream_callback and citation_markdown:
logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
await stream_callback(create_token_event(citation_markdown))
# Create AI message with complete content
final_ai_message = AIMessage(content=final_content)
return {
"messages": [final_ai_message],
"final_answer": final_content
}
except Exception as e:
logger.error(f"Post-processing error: {e}")
error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."
# Send error message as single block
stream_callback = stream_callback_context.get()
if stream_callback:
await stream_callback(create_token_event(error_message))
error_content = agent_response + error_message if agent_response else error_message
error_ai_message = AIMessage(content=error_content)
return {
"messages": [error_ai_message],
"final_answer": error_ai_message.content
}
# Main workflow class
class AgenticWorkflow:
"""LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""
def __init__(self):
# Build StateGraph with TypedDict state
workflow = StateGraph(AgentState)
# Add nodes following best practices
workflow.add_node("intent_recognition", intent_recognition_node)
workflow.add_node("agent", call_model)
workflow.add_node("user_manual_rag", user_manual_rag_node)
workflow.add_node("tools", run_tools_with_streaming)
workflow.add_node("post_process", post_process_node)
# Set entry point to intent recognition
workflow.set_entry_point("intent_recognition")
# Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
workflow.add_conditional_edges(
"intent_recognition",
intent_router,
{
"Standard_Regulation_RAG": "agent",
"User_Manual_RAG": "user_manual_rag"
}
)
# Standard RAG workflow (existing pattern)
workflow.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
"agent": "agent", # Allow agent to continue for multi-round
"post_process": "post_process"
}
)
# Tools route back to should_continue for multi-round decision
workflow.add_conditional_edges(
"tools",
should_continue,
{
"agent": "agent", # Continue to agent for next round
"post_process": "post_process" # Or finish if max rounds reached
}
)
# User Manual RAG directly goes to END (single turn)
workflow.add_edge("user_manual_rag", END)
# Post-process is terminal
workflow.add_edge("post_process", END)
# Compile graph with PostgreSQL checkpointer for session memory
try:
checkpointer = get_checkpointer()
self.graph = workflow.compile(checkpointer=checkpointer)
logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
except Exception as e:
logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
self.graph = workflow.compile()
async def astream(self, state: TurnState, stream_callback: Callable | None = None):
"""Stream agent execution using LangGraph with PostgreSQL session memory"""
try:
# Get configuration
config = get_cached_config()
# Prepare initial messages for the graph
messages = []
for msg in state.messages:
if msg.role == "user":
messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
messages.append(AIMessage(content=msg.content))
# Create initial agent state (without stream_callback to avoid serialization issues)
initial_state: AgentState = {
"messages": messages,
"session_id": state.session_id,
"intent": None, # Will be determined by intent recognition node
"tool_results": [],
"final_answer": "",
"tool_rounds": 0,
"max_tool_rounds": config.app.max_tool_rounds, # Use configuration value
"max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual # Use configuration value for user manual agent
}
# Set stream callback in context variable (thread-safe)
stream_callback_context.set(stream_callback)
# Create proper RunnableConfig
runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})
# Stream graph execution with session memory
async for step in self.graph.astream(initial_state, config=runnable_config):
if "post_process" in step:
final_state = step["post_process"]
# Extract the tool summary message and update state
state.final_answer = final_state.get("final_answer", "")
# Add the summary as a regular assistant message
if state.final_answer:
state.messages.append(Message(
role="assistant",
content=state.final_answer,
timestamp=datetime.now()
))
yield {"final": state}
break
elif "user_manual_rag" in step:
# Handle user manual RAG completion
final_state = step["user_manual_rag"]
# Extract the response from user manual RAG
state.final_answer = final_state.get("final_answer", "")
# Add the response as a regular assistant message
if state.final_answer:
state.messages.append(Message(
role="assistant",
content=state.final_answer,
timestamp=datetime.now()
))
yield {"final": state}
break
else:
# Process regular steps (intent_recognition, agent, tools)
yield step
except Exception as e:
logger.error(f"AgentWorkflow error: {e}")
state.final_answer = "I apologize, but I encountered an error while processing your request."
yield {"final": state}
def build_graph() -> AgenticWorkflow:
"""Build and return the autonomous agent workflow"""
return AgenticWorkflow()
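# Typical invocation (a sketch; the session id and query are placeholders):
#   workflow = build_graph()
#   turn = TurnState(session_id="demo-session",
#                    messages=[Message(role="user", content="What does GB 1589 cover?")])
#   async def on_event(evt: str) -> None:
#       print(evt, end="")                    # forward SSE events to the client
#   async for step in workflow.astream(turn, stream_callback=on_event):
#       if "final" in step:
#           print(step["final"].final_answer)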

View File

@@ -0,0 +1,136 @@
"""
Intent recognition functionality for the Agentic RAG system.
This module contains the intent classification logic.
"""
import logging
from typing import Dict, Any, Optional, Literal
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
from .state import AgentState
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.error_handler import StructuredLogger
logger = StructuredLogger(__name__)
# Intent Recognition Models
class Intent(BaseModel):
"""Intent classification model for routing user queries"""
label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
confidence: Optional[float] = None
def get_last_user_message(messages) -> str:
"""Extract the content of the most recent user (human) message from conversation history"""
for message in reversed(messages):
# Only consider human messages so earlier assistant replies are not mistaken for the query
if isinstance(message, HumanMessage) and hasattr(message, 'content'):
content = message.content
# Handle both string and list content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract string content from list
return " ".join([str(item) for item in content if isinstance(item, str)])
return ""
def render_conversation_history(messages, max_messages: int = 10) -> str:
"""Render conversation history for context"""
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
lines = []
for msg in recent_messages:
if hasattr(msg, 'content'):
content = msg.content
if isinstance(content, str):
# Determine message type by class name or other attributes
if 'Human' in str(type(msg)):
lines.append(f"<user>{content}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content}</ai>")
elif isinstance(content, list):
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
if 'Human' in str(type(msg)):
lines.append(f"<user>{content_str}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content_str}</ai>")
return "\n".join(lines)
async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Intent recognition node that uses LLM to classify user queries into specific domains
"""
try:
logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
app_config = get_config()
llm_client = LLMClient()
# Get current user query and conversation history
current_query = get_last_user_message(state["messages"])
conversation_context = render_conversation_history(state["messages"])
# Get intent classification prompt from configuration
rag_prompts = app_config.get_rag_prompts()
intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
if not intent_prompt_template:
logger.error("Intent recognition prompt not found in configuration")
return {"intent": "Standard_Regulation_RAG"}
# Format the prompt with instruction to return only the label
system_prompt = intent_prompt_template.format(
current_query=current_query,
conversation_context=conversation_context
) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."
# Classify intent using regular LLM call
intent_result = await llm_client.llm.ainvoke([
SystemMessage(content=system_prompt)
])
# Parse the response to extract the intent label
response_text = ""
if hasattr(intent_result, 'content') and intent_result.content:
if isinstance(intent_result.content, str):
response_text = intent_result.content.strip()
elif isinstance(intent_result.content, list):
# Handle list content by joining string elements
response_text = " ".join([str(item) for item in intent_result.content if isinstance(item, str)]).strip()
# Extract intent label from response
if "User_Manual_RAG" in response_text:
intent_label = "User_Manual_RAG"
elif "Standard_Regulation_RAG" in response_text:
intent_label = "Standard_Regulation_RAG"
else:
# Default fallback
logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
intent_label = "Standard_Regulation_RAG"
logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
return {"intent": intent_label}
except Exception as e:
logger.error(f"Intent recognition error: {e}")
# Default to Standard_Regulation_RAG if classification fails
logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
return {"intent": "Standard_Regulation_RAG"}
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
"""
Route based on intent classification result
"""
intent = state.get("intent")
if intent is None:
logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
return "Standard_Regulation_RAG"
logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
return intent

View File

@@ -0,0 +1,270 @@
"""
Conversation history trimming utilities for managing context length.
"""
import logging
from typing import List, Optional, Sequence, Tuple
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
logger = logging.getLogger(__name__)
class ConversationTrimmer:
"""
Manages conversation history to prevent exceeding LLM context limits.
"""
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
"""
Initialize the conversation trimmer.
Args:
max_context_length: Maximum context length for conversation history (in tokens)
preserve_system: Whether to always preserve system messages
"""
self.max_context_length = max_context_length
self.preserve_system = preserve_system
# Reserve tokens for response generation (use 85% for history, 15% for response)
self.history_token_limit = int(max_context_length * 0.85)
def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
"""
Trim conversation history to fit within token limits.
Args:
messages: List of conversation messages
Returns:
Trimmed list of messages
"""
if not messages:
return list(messages)
try:
# Convert to list for processing
message_list = list(messages)
# First, try multi-round tool call optimization
optimized_messages = self._optimize_multi_round_tool_calls(message_list)
# Check if optimization is sufficient
try:
token_count = count_tokens_approximately(optimized_messages)
if token_count <= self.history_token_limit:
original_count = len(message_list)
optimized_count = len(optimized_messages)
if optimized_count < original_count:
logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
return optimized_messages
except Exception:
# If token counting fails, continue with LangChain trimming
pass
# If still too long, use LangChain's trim_messages utility
trimmed_messages = trim_messages(
optimized_messages,
strategy="last", # Keep most recent messages
token_counter=count_tokens_approximately,
max_tokens=self.history_token_limit,
start_on="human", # Ensure valid conversation start
end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages
include_system=self.preserve_system, # Preserve system messages
allow_partial=False # Don't split individual messages
)
original_count = len(messages)
trimmed_count = len(trimmed_messages)
if trimmed_count < original_count:
logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")
return trimmed_messages
except Exception as e:
logger.error(f"Error trimming conversation history: {e}")
# Fallback: keep last N messages
return self._fallback_trim(list(messages))
def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
"""
Optimize conversation history by removing older tool call results in multi-round scenarios.
This reduces token usage while preserving conversation context.
Strategy:
1. Always preserve system messages
2. Always preserve the original user query
3. Keep the most recent AI-Tool message pairs (for context continuity)
4. Remove older ToolMessage content which typically contains large JSON responses
Args:
messages: List of conversation messages
Returns:
Optimized list of messages
"""
if len(messages) <= 4: # Too short to optimize
return [msg for msg in messages]
# Identify message patterns
tool_rounds = self._identify_tool_rounds(messages)
if len(tool_rounds) <= 1: # Single or no tool round, no optimization needed
return [msg for msg in messages]
logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")
# Build optimized message list
optimized = []
# Always preserve system messages
for msg in messages:
if isinstance(msg, SystemMessage):
optimized.append(msg)
# Preserve initial user query (first human message after system)
first_human_added = False
for msg in messages:
if isinstance(msg, HumanMessage) and not first_human_added:
optimized.append(msg)
first_human_added = True
break
# Keep only the most recent tool round (preserve context for next round)
if tool_rounds:
latest_round_start, latest_round_end = tool_rounds[-1]
# Add messages from the latest tool round
for i in range(latest_round_start, min(latest_round_end + 1, len(messages))):
msg = messages[i]
if not isinstance(msg, SystemMessage) and not (isinstance(msg, HumanMessage) and not first_human_added):
optimized.append(msg)
logger.info(f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages (removed {len(tool_rounds)-1} older tool rounds)")
return optimized
def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
"""
Identify tool calling rounds in the message sequence.
A tool round typically consists of:
- AI message with tool_calls
- One or more ToolMessage responses
Returns:
List of (start_index, end_index) tuples for each tool round
"""
rounds = []
i = 0
while i < len(messages):
msg = messages[i]
# Look for AI message with tool calls
if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
round_start = i
round_end = i
# Find the end of this tool round (look for consecutive ToolMessages)
j = i + 1
while j < len(messages) and isinstance(messages[j], ToolMessage):
round_end = j
j += 1
# Only consider it a tool round if we found at least one ToolMessage
if round_end > round_start:
rounds.append((round_start, round_end))
i = round_end + 1
else:
i += 1
else:
i += 1
return rounds
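# Illustrative round detection (indices are positions in the message list):
#   [System, Human, AI(tool_calls), Tool, Tool, AI(tool_calls), Tool, AI]
#   -> [(2, 4), (5, 6)]   # two rounds; the final AI answer is not part of any round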
def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
"""
Fallback trimming based on message count.
Args:
messages: List of conversation messages
max_messages: Maximum number of messages to keep
Returns:
Trimmed list of messages
"""
if len(messages) <= max_messages:
return [msg for msg in messages] # Convert to BaseMessage
# Preserve system message if it exists
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
# Keep the most recent messages
recent_messages = other_messages[-(max_messages - len(system_messages)):]
result = system_messages + recent_messages
logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
return [msg for msg in result] # Ensure BaseMessage type
def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
"""
Check if conversation history should be trimmed.
Strategy:
1. Always trim if there are multiple tool rounds from previous conversation turns
2. Also trim if approaching token limit
Args:
messages: List of conversation messages
Returns:
True if trimming is needed
"""
try:
# Convert to list for processing
message_list = list(messages)
# Check for multiple tool rounds - if found, always trim to remove old tool results
tool_rounds = self._identify_tool_rounds(message_list)
if len(tool_rounds) > 1:
logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
return True
# Also check token count for traditional trimming
token_count = count_tokens_approximately(message_list)
return token_count > self.history_token_limit
except Exception:
# Fallback to message count
return len(messages) > 30
def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
"""
Create a conversation trimmer with config-based settings.
Args:
max_context_length: Override for maximum context length
Returns:
ConversationTrimmer instance
"""
# If max_context_length is provided, use it directly
if max_context_length is not None:
return ConversationTrimmer(
max_context_length=max_context_length,
preserve_system=True
)
# Try to get from config, fallback to default if config not available
try:
from ..config import get_config
config = get_config()
effective_max_context_length = config.get_max_context_length()
except (RuntimeError, AttributeError):
effective_max_context_length = 96000
return ConversationTrimmer(
max_context_length=effective_max_context_length,
preserve_system=True
)
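# Typical use (a sketch mirroring call_model in graph.py): trim only when needed.
#   trimmer = create_conversation_trimmer()
#   if trimmer.should_trim(messages):
#       messages = trimmer.trim_conversation_history(messages)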

View File

@@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Literal
from datetime import datetime
from typing_extensions import Annotated
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage
class Message(BaseModel):
"""Base message class for conversation history"""
role: str # "user", "assistant", "tool"
content: str
timestamp: Optional[datetime] = None
tool_call_id: Optional[str] = None
tool_name: Optional[str] = None
class Citation(BaseModel):
"""Citation mapping between numbers and result IDs"""
number: int
result_id: str
url: Optional[str] = None
class ToolResult(BaseModel):
"""Normalized tool result schema"""
id: str
title: str
url: Optional[str] = None
score: Optional[float] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
content: Optional[str] = None # For chunk results
# Standard/regulation specific fields
publisher: Optional[str] = None
publish_date: Optional[str] = None
document_code: Optional[str] = None
document_category: Optional[str] = None
class TurnState(BaseModel):
"""State container for LangGraph workflow"""
session_id: str
messages: List[Message] = Field(default_factory=list)
tool_results: List[ToolResult] = Field(default_factory=list)
citations: List[Citation] = Field(default_factory=list)
meta: Dict[str, Any] = Field(default_factory=dict)
# Additional fields for tracking
current_step: int = 0
max_steps: int = 5
final_answer: Optional[str] = None
# LangGraph-native AgentState (MessagesState is LangGraph's TypedDict-based message state)
from langgraph.graph import MessagesState
class AgentState(MessagesState):
"""LangGraph state with intent recognition support"""
session_id: str
intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
tool_results: Annotated[List[Dict[str, Any]], lambda x, y: (x or []) + (y or [])]
final_answer: str
tool_rounds: int
max_tool_rounds: int
max_tool_rounds_user_manual: int

View File

@@ -0,0 +1,98 @@
"""
Tool definitions and schemas for the Agentic RAG system.
This module contains all tool implementations and their corresponding schemas.
"""
import logging
from typing import Dict, Any, List
from langchain_core.tools import tool
from ..retrieval.retrieval import AgenticRetrieval
logger = logging.getLogger(__name__)
# Tool Definitions using @tool decorator (following LangGraph best practices)
@tool
async def retrieve_standard_regulation(query: str) -> Dict[str, Any]:
"""Search for attributes/metadata of China standards and regulations in automobile/manufacturing industry"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_standard_regulation(
query=query
)
return {
"tool_name": "retrieve_standard_regulation",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"Retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
@tool
async def retrieve_doc_chunk_standard_regulation(query: str) -> Dict[str, Any]:
"""Search for detailed document content chunks of China standards and regulations in automobile/manufacturing industry"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_doc_chunk_standard_regulation(
query=query
)
return {
"tool_name": "retrieve_doc_chunk_standard_regulation",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"Doc chunk retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
# Available tools list
tools = [retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation]
def get_tool_schemas() -> List[Dict[str, Any]]:
"""
Generate tool schemas for LLM function calling.
Returns:
List of tool schemas in OpenAI function calling format
"""
tool_schemas = []
for tool in tools:
schema = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for retrieving relevant information"
}
},
"required": ["query"]
}
}
}
tool_schemas.append(schema)
return tool_schemas
def get_tools_by_name() -> Dict[str, Any]:
"""
Create a mapping of tool names to tool functions.
Returns:
Dictionary mapping tool names to tool functions
"""
return {tool.name: tool for tool in tools}
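# Hypothetical usage sketch (not part of this commit): how the schema list and the
# name->tool map are typically consumed together once the LLM has emitted a tool call.
async def _example_dispatch_tool_call(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative only: resolve a tool call by name and execute it."""
    tools_map = get_tools_by_name()
    tool_func = tools_map[tool_call["name"]]  # e.g. "retrieve_standard_regulation"
    return await tool_func.ainvoke(tool_call.get("args", {}))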

View File

@@ -0,0 +1,464 @@
"""
User Manual Agent node for the Agentic RAG system.
This module contains the autonomous user manual agent that can use tools and generate responses.
"""
import logging
from typing import Dict, Any, List, Optional, Callable, Literal
from contextvars import ContextVar
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from .state import AgentState
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
from .message_trimmer import create_conversation_trimmer
from ..llm_client import LLMClient
from ..config import get_config
from ..sse import (
create_tool_start_event,
create_tool_result_event,
create_tool_error_event,
create_token_event,
create_error_event
)
from ..utils.error_handler import (
StructuredLogger, ErrorCategory, ErrorCode,
handle_async_errors, get_user_message
)
logger = StructuredLogger(__name__)
# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None
def get_cached_config():
"""Get cached configuration, loading it if not already cached"""
global _cached_config
if _cached_config is None:
_cached_config = get_config()
return _cached_config
# User Manual Agent node (autonomous function calling agent)
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
User Manual Agent node that autonomously uses user manual tools and generates final answer.
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
1. Always start with non-streaming detection to check for tool needs
2. If tool_calls exist → return immediately for routing to tools
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
"""
app_config = get_cached_config()
llm_client = LLMClient()
# Get stream callback from context variable
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
# Get user manual tool schemas and bind tools for planning phase
tool_schemas = get_user_manual_tool_schemas()
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
# Create conversation trimmer for managing context length
trimmer = create_conversation_trimmer()
# Prepare messages with user manual system prompt
messages = state["messages"].copy()
if not messages or not isinstance(messages[0], SystemMessage):
rag_prompts = app_config.get_rag_prompts()
user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
if not user_manual_prompt:
raise ValueError("user_manual_prompt is null")
# For user manual agent, we need to format the prompt with placeholders
# Extract current query and conversation history
current_query = ""
for message in reversed(messages):
if isinstance(message, HumanMessage):
current_query = message.content
break
conversation_history = ""
if len(messages) > 1:
conversation_history = render_conversation_history(messages[:-1]) # Exclude current query
# Format system prompt (initially with empty context, tools will provide it)
formatted_system_prompt = user_manual_prompt.format(
conversation_history=conversation_history,
context_content="", # Will be filled by tools
current_query=current_query
)
messages = [SystemMessage(content=formatted_system_prompt)] + messages
# Track tool rounds
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds_user_manual from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds_user_manual", None)
if max_rounds is None:
max_rounds = app_config.app.max_tool_rounds_user_manual
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
# This prevents trimming current turn's tool results during multi-round tool calling
if current_round == 0:
# Trim conversation history to manage context length (only for previous conversation turns)
if trimmer.should_trim(messages):
messages = trimmer.trim_conversation_history(messages)
logger.info("Applied conversation history trimming for context management (new conversation turn)")
else:
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
# Check if this should be final synthesis (max rounds reached)
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
is_final_synthesis = has_tool_messages and current_round >= max_rounds
if is_final_synthesis:
logger.info("Starting final synthesis phase - no more tool calls allowed")
# ✅ STEP 1: Final synthesis with tools disabled from the start
# Disable tools to prevent any tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools (bind_tools mutates the client in place)
if not stream_callback:
# No streaming callback, generate final response without tools
draft = await llm_client.ainvoke(list(messages))
return {"messages": [draft]}
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 3: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
else:
logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")
# ✅ STEP 1: Non-streaming detection to check for tool needs
draft = await llm_client.ainvoke_with_tools(list(messages))
# ✅ STEP 2: If draft has tool_calls, execute them within this node
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")
# Create a new state with the tool call message added
tool_call_state = state.copy()
updated_messages = state["messages"].copy()
updated_messages.append(draft)
tool_call_state["messages"] = updated_messages
# Execute the tools using the existing streaming tool execution function
tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
tool_messages = tool_results.get("messages", [])
# Increment tool round counter for next iteration
new_tool_rounds = current_round + 1
logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")
# Continue with another round if under max rounds
if new_tool_rounds < max_rounds:
# Recursive call for next round with all messages
final_messages = updated_messages + tool_messages
recursive_state = state.copy()
recursive_state["messages"] = final_messages
recursive_state["tool_rounds"] = new_tool_rounds
return await user_manual_agent_node(recursive_state)
else:
# Max rounds reached, force final synthesis
logger.info("Max tool rounds reached, forcing final synthesis")
# Update messages for final synthesis
messages = updated_messages + tool_messages
# Continue to final synthesis below
# ✅ STEP 3: No tool_calls needed or max rounds reached → Enter final synthesis with streaming
# Temporarily disable tools to prevent accidental tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
if not stream_callback:
# No streaming callback, use the draft we already have
return {"messages": [draft]}
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 5: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
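# Hypothetical refactor sketch (not part of this commit): the HTML-comment filter above
# is duplicated in both synthesis branches and could be factored into a helper like this.
# Behaviour mirrors the inline code: complete <!-- ... --> comments are dropped, all
# other tokens are forwarded to the stream callback, and the full text is returned.
async def _stream_without_html_comments(token_iter, stream_callback) -> str:
    """Illustrative only: forward streamed tokens while stripping complete HTML comments."""
    response_content = ""
    pending = ""
    async for token in token_iter:
        response_content += token
        pending += token
        # Drop any complete <!-- ... --> comments found so far
        while "<!--" in pending and "-->" in pending:
            start = pending.find("<!--")
            end = pending.find("-->", start)
            if start >= 0 and end >= 0:
                before = pending[:start]
                if stream_callback and before:
                    await stream_callback(create_token_event(before))
                pending = pending[end + 3:]
            else:
                break
        # Flush buffered text when no comment is pending
        if "<!--" not in pending:
            if stream_callback and pending:
                await stream_callback(create_token_event(pending))
            pending = ""
    # Flush any trailing text that is not part of an unterminated comment
    if pending and "<!--" not in pending and stream_callback:
        await stream_callback(create_token_event(pending))
    return response_content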
def render_conversation_history(messages, max_messages: int = 10) -> str:
"""Render conversation history for context"""
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
lines = []
for msg in recent_messages:
if hasattr(msg, 'content'):
content = msg.content
if isinstance(content, str):
# Determine message type by class name or other attributes
if 'Human' in str(type(msg)):
lines.append(f"<user>{content}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content}</ai>")
elif isinstance(content, list):
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
if 'Human' in str(type(msg)):
lines.append(f"<user>{content_str}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content_str}</ai>")
return "\n".join(lines)
# User Manual Tools routing condition
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
"""
Routing logic for user manual agent:
- has tool_calls → route to user_manual_tools
- no tool_calls → route to post_process (final synthesis already completed)
"""
messages = state["messages"]
if not messages:
logger.info("user_manual_should_continue: No messages, routing to post_process")
return "post_process"
last_message = messages[-1]
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds_user_manual from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds_user_manual", None)
if max_rounds is None:
app_config = get_cached_config()
max_rounds = app_config.app.max_tool_rounds_user_manual
logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
# If last message is AI message with tool calls, route to tools
if isinstance(last_message, AIMessage):
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")
if has_tool_calls:
logger.info("user_manual_should_continue: Routing to user_manual_tools")
return "user_manual_tools"
else:
# No tool calls = final synthesis already completed in user_manual_agent_node
logger.info("user_manual_should_continue: No tool calls, routing to post_process")
return "post_process"
# If last message is tool message(s), continue with agent for next round or final synthesis
if isinstance(last_message, ToolMessage):
logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
return "user_manual_agent"
logger.info("user_manual_should_continue: Routing to post_process")
return "post_process"
# User Manual Tools node with streaming support
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""Execute user manual tools with streaming events - supports parallel execution"""
messages = state["messages"]
last_message = messages[-1]
# Get stream callback from context variable
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
return {"messages": []}
tool_calls = last_message.tool_calls or []
tool_results = []
new_messages = []
# User manual tools mapping
tools_map = get_user_manual_tools_by_name()
async def execute_single_tool(tool_call):
"""Execute a single user manual tool call with enhanced error handling"""
# Get stream callback from context
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
# Apply error handling decorator
@handle_async_errors(
ErrorCategory.TOOL,
ErrorCode.TOOL_ERROR,
stream_callback,
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
)
async def _execute():
# Validate tool_call format
if not isinstance(tool_call, dict):
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
tool_name = tool_call.get("name")
tool_args = tool_call.get("args", {})
tool_id = tool_call.get("id", "unknown")
if not tool_name:
raise ValueError("Tool call missing 'name' field")
if tool_name not in tools_map:
available_tools = list(tools_map.keys())
raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")
tool_func = tools_map[tool_name]
# Stream tool start event
if stream_callback:
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
import time
start_time = time.time()
try:
# Execute the user manual tool
result = await tool_func.ainvoke(tool_args)
# Calculate execution time
took_ms = int((time.time() - start_time) * 1000)
# Stream tool result event
if stream_callback:
await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))
# Create tool message
tool_message = ToolMessage(
content=str(result),
tool_call_id=tool_id,
name=tool_name
)
return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}
except Exception as e:
took_ms = int((time.time() - start_time) * 1000)
error_msg = get_user_message(ErrorCategory.TOOL)
# Stream tool error event
if stream_callback:
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
# Create error tool message
tool_message = ToolMessage(
content=f"Error executing {tool_name}: {error_msg}",
tool_call_id=tool_id,
name=tool_name
)
return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}
return await _execute()
# Execute user manual tools (typically just one for user manual retrieval)
import asyncio
tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception):
# Handle execution exception
tool_call = tool_calls[i]
tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
tool_name = tool_call.get("name", "unknown")
error_msg = get_user_message(ErrorCategory.TOOL)
if stream_callback:
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
error_message = ToolMessage(
content=f"Error executing {tool_name}: {error_msg}",
tool_call_id=tool_id,
name=tool_name
)
new_messages.append(error_message)
elif isinstance(result, tuple) and len(result) == 2:
# result is a tuple: (tool_message, tool_result)
tool_message, tool_result = result
new_messages.append(tool_message)
tool_results.append(tool_result)
else:
# Unexpected result format
logger.error(f"Unexpected tool execution result format: {type(result)}")
continue
return {"messages": new_messages, "tool_results": tool_results}
# Legacy function for backward compatibility
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Legacy user manual RAG node - redirects to new agent-based implementation
"""
logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
return await user_manual_agent_node(state, config)

View File

@@ -0,0 +1,77 @@
"""
User manual specific tools for the Agentic RAG system.
This module contains tools specifically for user manual retrieval and processing.
"""
import logging
from typing import Dict, Any, List
from langchain_core.tools import tool
from ..retrieval.retrieval import AgenticRetrieval
logger = logging.getLogger(__name__)
# User Manual Tools
@tool
async def retrieve_system_usermanual(query: str) -> Dict[str, Any]:
"""Search for document content chunks of user manual of this system(CATOnline)"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_doc_chunk_user_manual(
query=query
)
return {
"tool_name": "retrieve_system_usermanual",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"User manual retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
# User manual tools list
user_manual_tools = [retrieve_system_usermanual]
def get_user_manual_tool_schemas() -> List[Dict[str, Any]]:
"""
Generate tool schemas for user manual tools.
Returns:
List of tool schemas in OpenAI function calling format
"""
tool_schemas = []
for tool in user_manual_tools:
schema = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for retrieving relevant information"
}
},
"required": ["query"]
}
}
}
tool_schemas.append(schema)
return tool_schemas
def get_user_manual_tools_by_name() -> Dict[str, Any]:
"""
Create a mapping of user manual tool names to tool functions.
Returns:
Dictionary mapping tool names to tool functions
"""
return {tool.name: tool for tool in user_manual_tools}

View File

@@ -0,0 +1,103 @@
from typing import AsyncIterator, Dict, Any, List, Optional
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
import logging
from .config import get_config
logger = logging.getLogger(__name__)
class LLMClient:
"""Wrapper for OpenAI/Azure OpenAI clients with streaming and function calling support"""
def __init__(self):
self.config = get_config()
self.llm = self._create_llm()
self.llm_with_tools = None
def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
"""Create LLM client based on configuration"""
llm_config = self.config.get_llm_config()
if llm_config["provider"] == "openai":
# Create base parameters
params = {
"base_url": llm_config["base_url"],
"api_key": llm_config["api_key"],
"model": llm_config["model"],
"streaming": True,
}
# Only add temperature if explicitly set
if "temperature" in llm_config:
params["temperature"] = llm_config["temperature"]
return ChatOpenAI(**params)
elif llm_config["provider"] == "azure":
# Create base parameters
params = {
"azure_endpoint": llm_config["base_url"],
"api_key": llm_config["api_key"],
"azure_deployment": llm_config["deployment"],
"api_version": llm_config["api_version"],
"streaming": True,
}
# Only add temperature if explicitly set
if "temperature" in llm_config:
params["temperature"] = llm_config["temperature"]
return AzureChatOpenAI(**params)
else:
raise ValueError(f"Unsupported provider: {llm_config['provider']}")
def bind_tools(self, tools: List[Dict[str, Any]], force_tool_choice: bool = False):
"""Bind tools to LLM for function calling"""
if force_tool_choice:
# Use tool_choice="required" to force tool calling for DeepSeek
self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
else:
self.llm_with_tools = self.llm.bind_tools(tools)
async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
"""Stream LLM response tokens"""
try:
async for chunk in self.llm.astream(messages):
if chunk.content and isinstance(chunk.content, str):
yield chunk.content
except Exception as e:
logger.error(f"LLM streaming error: {e}")
raise
async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
"""Get complete LLM response"""
try:
response = await self.llm.ainvoke(messages)
if isinstance(response, AIMessage):
return response
else:
# Convert to AIMessage if needed
return AIMessage(content=str(response.content) if response.content else "")
except Exception as e:
logger.error(f"LLM invoke error: {e}")
raise
async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
"""Get LLM response with tool calling capability"""
try:
if not self.llm_with_tools:
raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
response = await self.llm_with_tools.ainvoke(messages)
if isinstance(response, AIMessage):
return response
else:
return AIMessage(content=str(response.content) if response.content else "")
except Exception as e:
logger.error(f"LLM with tools invoke error: {e}")
raise
def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
"""Create message list for LLM"""
messages = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=user_prompt))
return messages
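# Hypothetical usage sketch (not part of this commit): the detect-then-stream pattern
# the agent nodes layer on top of this client. `tool_schemas` would normally come from
# get_tool_schemas() in the graph package; that wiring is assumed here.
async def _example_detect_then_stream(messages: list[BaseMessage], tool_schemas: List[Dict[str, Any]]) -> str:
    """Illustrative only: detect tool calls without streaming, then stream the answer."""
    client = LLMClient()
    client.bind_tools(tool_schemas, force_tool_choice=True)
    draft = await client.ainvoke_with_tools(messages)  # non-streaming detection
    if getattr(draft, "tool_calls", None):
        return ""  # the caller would execute tools and loop back here
    chunks: list[str] = []
    async for token in client.astream(messages):  # streaming final synthesis
        chunks.append(token)
    return "".join(chunks)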

View File

@@ -0,0 +1,187 @@
import asyncio
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from .config import load_config, get_config
from .schemas.messages import ChatRequest
from .memory.postgresql_memory import get_memory_manager
from .graph.state import TurnState, Message
from .graph.graph import build_graph
from .sse import create_error_event
from .utils.error_handler import StructuredLogger, ErrorCategory, ErrorCode, handle_async_errors
from .utils.middleware import ErrorMiddleware
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = StructuredLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager"""
# Startup
try:
logger.info("Starting application initialization...")
# Initialize PostgreSQL memory manager
memory_manager = get_memory_manager()
connection_ok = memory_manager.test_connection()
logger.info(f"PostgreSQL memory manager initialized (connected: {connection_ok})")
# Initialize global components
app.state.memory_manager = memory_manager
app.state.graph = build_graph()
logger.info("Application startup complete")
yield
except Exception as e:
logger.error(f"Failed to start application: {e}")
raise
finally:
# Shutdown
logger.info("Application shutdown")
def create_app() -> FastAPI:
"""Application factory"""
# Load configuration first
config = load_config()
logger.info(f"Loaded configuration for provider: {config.provider}")
app = FastAPI(
title="Agentic RAG API",
description="Agentic RAG application for manufacturing standards and regulations",
version="0.1.0",
lifespan=lifespan
)
# Add error handling middleware
app.add_middleware(ErrorMiddleware)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=config.app.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define routes
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
"""Main chat endpoint with SSE streaming"""
try:
return StreamingResponse(
stream_chat_response(request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
}
)
except Exception as e:
logger.error(f"Chat endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ai-sdk/chat")
async def ai_sdk_chat_endpoint(request: ChatRequest):
"""AI SDK compatible chat endpoint"""
try:
# Import here to avoid circular imports
from .ai_sdk_chat import handle_ai_sdk_chat
return await handle_ai_sdk_chat(request, app.state)
except Exception as e:
logger.error(f"AI SDK chat endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "service": "agentic-rag"}
@app.get("/")
async def root():
"""Root endpoint"""
return {"message": "Agentic RAG API for Manufacturing Standards & Regulations"}
return app
# Create the global app instance for uvicorn
app = create_app()
@handle_async_errors(ErrorCategory.LLM, ErrorCode.LLM_ERROR)
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[str, None]:
"""Stream chat response with enhanced error handling"""
config = get_config()
memory_manager = app.state.memory_manager
graph = app.state.graph
# Create conversation state
state = TurnState(session_id=request.session_id)
# Add user message
if request.messages:
last_message = request.messages[-1]
if last_message.get("role") == "user":
user_message = Message(
role="user",
content=last_message.get("content", "")
)
state.messages.append(user_message)
# Create event queue for streaming
event_queue = asyncio.Queue()
async def stream_callback(event_str: str):
await event_queue.put(event_str)
# Execute workflow in background task
async def run_workflow():
try:
async for _ in graph.astream(state, stream_callback):
pass
await event_queue.put(None) # Signal completion
except Exception as e:
logger.error("Workflow execution failed", error=e,
category=ErrorCategory.LLM, error_code=ErrorCode.LLM_ERROR)
await event_queue.put(create_error_event("Processing error: AI service is temporarily unavailable"))
await event_queue.put(None)
# Start workflow task
workflow_task = asyncio.create_task(run_workflow())
# Stream events as they come
try:
while True:
event = await event_queue.get()
if event is None: # Completion signal
break
yield event
finally:
if not workflow_task.done():
workflow_task.cancel()
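# Hypothetical client sketch (not part of this commit): consuming the SSE endpoint with
# httpx. The payload shape follows ChatRequest (session_id plus a list of role/content
# messages); the host, port, and question text are assumptions.
async def _example_call_chat_endpoint() -> None:
    """Illustrative only: POST to /api/chat and print raw SSE lines as they arrive."""
    import httpx
    payload = {
        "session_id": "demo-session",
        "messages": [{"role": "user", "content": "Which GB standards cover seat belts?"}],
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/api/chat", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line:
                    print(line)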
if __name__ == "__main__":
config = load_config() # Load configuration first
uvicorn.run(
"service.main:app",
host=config.app.host,
port=config.app.port,
reload=True,
log_level="info"
)

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,332 @@
"""
PostgreSQL-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
Uses psycopg3 for better compatibility without requiring libpq-dev.
"""
import logging
from typing import Dict, Any, Optional
from urllib.parse import quote_plus
from contextlib import contextmanager
try:
import psycopg
from psycopg.rows import dict_row
PSYCOPG_AVAILABLE = True
except ImportError as e:
logging.warning(f"psycopg3 not available: {e}")
PSYCOPG_AVAILABLE = False
psycopg = None
try:
from langgraph.checkpoint.postgres import PostgresSaver
LANGGRAPH_POSTGRES_AVAILABLE = True
except ImportError as e:
logging.warning(f"LangGraph PostgreSQL checkpoint not available: {e}")
LANGGRAPH_POSTGRES_AVAILABLE = False
PostgresSaver = None
try:
from langgraph.checkpoint.memory import InMemorySaver
LANGGRAPH_MEMORY_AVAILABLE = True
except ImportError as e:
logging.warning(f"LangGraph memory checkpoint not available: {e}")
LANGGRAPH_MEMORY_AVAILABLE = False
InMemorySaver = None
from ..config import get_config
logger = logging.getLogger(__name__)
POSTGRES_AVAILABLE = PSYCOPG_AVAILABLE and LANGGRAPH_POSTGRES_AVAILABLE
class PostgreSQLCheckpointerWrapper:
"""
Wrapper for PostgresSaver that manages the context properly.
"""
def __init__(self, conn_string: str):
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
self.conn_string = conn_string
self._initialized = False
def _ensure_setup(self):
"""Ensure the database schema is set up."""
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
if not self._initialized:
with PostgresSaver.from_conn_string(self.conn_string) as saver:
saver.setup()
self._initialized = True
logger.info("PostgreSQL schema initialized")
@contextmanager
def get_saver(self):
"""Get a PostgresSaver instance as context manager."""
if not LANGGRAPH_POSTGRES_AVAILABLE or PostgresSaver is None:
raise RuntimeError("PostgresSaver not available")
self._ensure_setup()
with PostgresSaver.from_conn_string(self.conn_string) as saver:
yield saver
def list(self, config):
"""List checkpoints."""
with self.get_saver() as saver:
return list(saver.list(config))
def get(self, config):
"""Get a checkpoint."""
with self.get_saver() as saver:
return saver.get(config)
def get_tuple(self, config):
"""Get a checkpoint tuple."""
with self.get_saver() as saver:
return saver.get_tuple(config)
def put(self, config, checkpoint, metadata, new_versions):
"""Put a checkpoint."""
with self.get_saver() as saver:
return saver.put(config, checkpoint, metadata, new_versions)
def put_writes(self, config, writes, task_id):
"""Put writes."""
with self.get_saver() as saver:
return saver.put_writes(config, writes, task_id)
def get_next_version(self, current, channel):
"""Get next version."""
with self.get_saver() as saver:
return saver.get_next_version(current, channel)
def delete_thread(self, thread_id):
"""Delete thread."""
with self.get_saver() as saver:
return saver.delete_thread(thread_id)
# Async methods
async def alist(self, config):
"""Async list checkpoints."""
with self.get_saver() as saver:
async for item in saver.alist(config):
yield item
async def aget(self, config):
"""Async get a checkpoint."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aget(config)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.get, config
)
async def aget_tuple(self, config):
"""Async get a checkpoint tuple."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aget_tuple(config)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.get_tuple, config
)
async def aput(self, config, checkpoint, metadata, new_versions):
"""Async put a checkpoint."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aput(config, checkpoint, metadata, new_versions)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.put, config, checkpoint, metadata, new_versions
)
async def aput_writes(self, config, writes, task_id):
"""Async put writes."""
with self.get_saver() as saver:
# PostgresSaver might not have async version, try sync first
try:
return await saver.aput_writes(config, writes, task_id)
except NotImplementedError:
# Fall back to sync version in a thread
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, saver.put_writes, config, writes, task_id
)
async def adelete_thread(self, thread_id):
"""Async delete thread."""
with self.get_saver() as saver:
return await saver.adelete_thread(thread_id)
@property
def config_specs(self):
"""Get config specs."""
with self.get_saver() as saver:
return saver.config_specs
@property
def serde(self):
"""Get serde."""
with self.get_saver() as saver:
return saver.serde
class PostgreSQLMemoryManager:
"""
PostgreSQL-based memory manager using LangGraph's built-in components.
Falls back to in-memory storage if PostgreSQL is not available.
"""
def __init__(self):
self.config = get_config()
self.pg_config = self.config.postgresql
self._checkpointer: Optional[Any] = None
self._postgres_available = POSTGRES_AVAILABLE
def _get_connection_string(self) -> str:
"""Get PostgreSQL connection string."""
if not self._postgres_available:
return ""
# URL encode password to handle special characters
encoded_password = quote_plus(self.pg_config.password)
return (
f"postgresql://{self.pg_config.username}:{encoded_password}@"
f"{self.pg_config.host}:{self.pg_config.port}/{self.pg_config.database}"
)
def _test_connection(self) -> bool:
"""Test PostgreSQL connection."""
if not self._postgres_available:
return False
if not PSYCOPG_AVAILABLE or psycopg is None:
return False
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string) as conn:
with conn.cursor() as cur:
cur.execute("SELECT 1")
result = cur.fetchone()
logger.info("PostgreSQL connection test successful")
return True
except Exception as e:
logger.error(f"PostgreSQL connection test failed: {e}")
return False
def _setup_ttl_cleanup(self):
"""Setup TTL cleanup for old records."""
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
return
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string, autocommit=True) as conn:
with conn.cursor() as cur:
# Create a function to clean up old records for LangGraph tables
# Note: LangGraph tables don't have created_at, so we'll use a different approach
cleanup_sql = f"""
CREATE OR REPLACE FUNCTION cleanup_old_checkpoints()
RETURNS void AS $$
BEGIN
-- LangGraph tables don't have created_at columns
-- We can clean based on checkpoint_id pattern or use a different strategy
-- For now, just return successfully without actual cleanup
-- You can implement custom logic based on your requirements
RAISE NOTICE 'Cleanup function called - custom cleanup logic needed';
END;
$$ LANGUAGE plpgsql;
"""
cur.execute(cleanup_sql)
logger.info(f"TTL cleanup function created with {self.pg_config.ttl_days}-day retention")
except Exception as e:
logger.warning(f"Failed to setup TTL cleanup (this is optional): {e}")
def cleanup_old_data(self):
"""Manually trigger cleanup of old data."""
if not self._postgres_available or not PSYCOPG_AVAILABLE or psycopg is None:
return
try:
conn_string = self._get_connection_string()
with psycopg.connect(conn_string, autocommit=True) as conn:
with conn.cursor() as cur:
cur.execute("SELECT cleanup_old_checkpoints()")
logger.info("Manual cleanup of old data completed")
except Exception as e:
logger.error(f"Failed to cleanup old data: {e}")
def get_checkpointer(self):
"""Get checkpointer for conversation history (PostgreSQL if available, else in-memory)."""
if self._checkpointer is None:
if self._postgres_available:
try:
# Test connection first
if not self._test_connection():
raise Exception("PostgreSQL connection test failed")
# Setup TTL cleanup function
self._setup_ttl_cleanup()
# Create checkpointer wrapper
conn_string = self._get_connection_string()
if LANGGRAPH_POSTGRES_AVAILABLE:
self._checkpointer = PostgreSQLCheckpointerWrapper(conn_string)
else:
raise Exception("LangGraph PostgreSQL checkpoint not available")
logger.info(f"PostgreSQL checkpointer initialized with {self.pg_config.ttl_days}-day TTL")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL checkpointer, falling back to in-memory: {e}")
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
self._checkpointer = InMemorySaver()
else:
logger.error("InMemorySaver not available - no checkpointer available")
self._checkpointer = None
else:
logger.info("PostgreSQL not available, using in-memory checkpointer")
if LANGGRAPH_MEMORY_AVAILABLE and InMemorySaver is not None:
self._checkpointer = InMemorySaver()
else:
logger.error("InMemorySaver not available - no checkpointer available")
self._checkpointer = None
return self._checkpointer
def test_connection(self) -> bool:
"""Test PostgreSQL connection and return True if successful."""
return self._test_connection()
# Global memory manager instance
_memory_manager: Optional[PostgreSQLMemoryManager] = None
def get_memory_manager() -> PostgreSQLMemoryManager:
"""Get global PostgreSQL memory manager instance."""
global _memory_manager
if _memory_manager is None:
_memory_manager = PostgreSQLMemoryManager()
return _memory_manager
def get_checkpointer():
"""Get checkpointer for conversation history."""
return get_memory_manager().get_checkpointer()
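# Hypothetical usage sketch (not part of this commit): how the checkpointer is normally
# handed to LangGraph so history is keyed per session. The workflow object and the
# thread_id value are assumptions; graph construction itself lives elsewhere.
def _example_compile_with_checkpointer(workflow):
    """Illustrative only: compile a StateGraph with the shared checkpointer."""
    graph = workflow.compile(checkpointer=get_checkpointer())
    config = {"configurable": {"thread_id": "demo-session"}}  # one thread per session_id
    return graph, config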

View File

@@ -0,0 +1,137 @@
"""
Redis-based memory implementation using LangGraph built-in components.
Provides session-level chat history with 7-day TTL.
"""
import logging
import ssl
from typing import Dict, Any, Optional
try:
import redis
from redis.exceptions import ConnectionError, TimeoutError
from langgraph.checkpoint.redis import RedisSaver
REDIS_AVAILABLE = True
except ImportError as e:
logging.warning(f"Redis packages not available: {e}")
REDIS_AVAILABLE = False
redis = None
RedisSaver = None
from langgraph.checkpoint.memory import InMemorySaver
from ..config import get_config
logger = logging.getLogger(__name__)
class RedisMemoryManager:
"""
Redis-based memory manager using LangGraph's built-in components.
Falls back to in-memory storage if Redis is not available.
"""
def __init__(self):
self.config = get_config()
self.redis_config = self.config.redis
self._checkpointer: Optional[Any] = None
self._redis_available = REDIS_AVAILABLE
def _get_redis_client_kwargs(self) -> Dict[str, Any]:
"""Get Redis client configuration for Azure Redis compatibility."""
if not self._redis_available:
return {}
kwargs = {
"host": self.redis_config.host,
"port": self.redis_config.port,
"password": self.redis_config.password,
"db": self.redis_config.db,
"decode_responses": False, # Required for RedisSaver
"socket_timeout": 30,
"socket_connect_timeout": 10,
"retry_on_timeout": True,
"health_check_interval": 30,
}
if self.redis_config.use_ssl:
kwargs.update({
"ssl": True,
"ssl_cert_reqs": ssl.CERT_REQUIRED,
"ssl_check_hostname": True,
})
return kwargs
def _get_ttl_config(self) -> Dict[str, Any]:
"""Get TTL configuration for automatic cleanup."""
ttl_days = self.redis_config.ttl_days
ttl_minutes = ttl_days * 24 * 60 # Convert days to minutes
return {
"default_ttl": ttl_minutes,
"refresh_on_read": True, # Refresh TTL when accessed
}
def get_checkpointer(self):
"""Get checkpointer for conversation history (Redis if available, else in-memory)."""
if self._checkpointer is None:
if self._redis_available:
try:
ttl_config = self._get_ttl_config()
# Create Redis client with proper configuration for Azure Redis
redis_client = redis.Redis(**self._get_redis_client_kwargs())
# Test connection
redis_client.ping()
logger.info("Redis connection established successfully")
# Create checkpointer with TTL support
self._checkpointer = RedisSaver(
redis_client=redis_client,
ttl=ttl_config
)
# Initialize indices (required for first-time setup)
self._checkpointer.setup()
logger.info(f"Redis checkpointer initialized with {self.redis_config.ttl_days}-day TTL")
except Exception as e:
logger.error(f"Failed to initialize Redis checkpointer, falling back to in-memory: {e}")
self._checkpointer = InMemorySaver()
else:
logger.info("Redis not available, using in-memory checkpointer")
self._checkpointer = InMemorySaver()
return self._checkpointer
def test_connection(self) -> bool:
"""Test Redis connection and return True if successful."""
if not self._redis_available:
logger.warning("Redis packages not available")
return False
try:
redis_client = redis.Redis(**self._get_redis_client_kwargs())
redis_client.ping()
logger.info("Redis connection test successful")
return True
except Exception as e:
logger.error(f"Redis connection test failed: {e}")
return False
# Global memory manager instance
_memory_manager: Optional[RedisMemoryManager] = None
def get_memory_manager() -> RedisMemoryManager:
"""Get global Redis memory manager instance."""
global _memory_manager
if _memory_manager is None:
_memory_manager = RedisMemoryManager()
return _memory_manager
def get_checkpointer():
"""Get checkpointer for conversation history."""
return get_memory_manager().get_checkpointer()

View File

@@ -0,0 +1,113 @@
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import logging
from .postgresql_memory import get_memory_manager, get_checkpointer
from ..graph.state import TurnState, Message
logger = logging.getLogger(__name__)
class InMemoryStore:
"""Simple in-memory store with TTL for conversation history"""
def __init__(self, ttl_days: float = 7.0):
self.ttl_days = ttl_days
self.store: Dict[str, Dict[str, Any]] = {}
def _is_expired(self, timestamp: datetime) -> bool:
"""Check if a record has expired"""
return datetime.now() - timestamp > timedelta(days=self.ttl_days)
def _cleanup_expired(self) -> None:
"""Remove expired records"""
expired_keys = []
for session_id, data in self.store.items():
if self._is_expired(data.get("last_updated", datetime.min)):
expired_keys.append(session_id)
for key in expired_keys:
del self.store[key]
logger.info(f"Cleaned up expired session: {key}")
def get(self, session_id: str) -> Optional[TurnState]:
"""Get conversation state for a session"""
self._cleanup_expired()
if session_id not in self.store:
return None
data = self.store[session_id]
if self._is_expired(data.get("last_updated", datetime.min)):
del self.store[session_id]
return None
try:
# Reconstruct TurnState from stored data
state_data = data["state"]
return TurnState(**state_data)
except Exception as e:
logger.error(f"Failed to deserialize state for session {session_id}: {e}")
return None
def put(self, session_id: str, state: TurnState) -> None:
"""Store conversation state for a session"""
try:
self.store[session_id] = {
"state": state.model_dump(),
"last_updated": datetime.now()
}
logger.debug(f"Stored state for session: {session_id}")
except Exception as e:
logger.error(f"Failed to store state for session {session_id}: {e}")
def trim(self, session_id: str, max_messages: int = 20) -> None:
"""Trim old messages to stay within token limits"""
state = self.get(session_id)
if not state:
return
if len(state.messages) > max_messages:
# Keep system message (if any) and recent user/assistant pairs
trimmed_messages = state.messages[-max_messages:]
# Try to preserve complete conversation turns
if len(trimmed_messages) > 1 and trimmed_messages[0].role == "assistant":
trimmed_messages = trimmed_messages[1:]
state.messages = trimmed_messages
self.put(session_id, state)
logger.info(f"Trimmed messages for session {session_id} to {len(trimmed_messages)}")
def create_new_session(self, session_id: str) -> TurnState:
"""Create a new conversation session"""
state = TurnState(session_id=session_id)
self.put(session_id, state)
return state
def add_message(self, session_id: str, message: Message) -> None:
"""Add a message to the conversation history"""
state = self.get(session_id)
if not state:
state = self.create_new_session(session_id)
state.messages.append(message)
self.put(session_id, state)
def get_conversation_history(self, session_id: str, max_turns: int = 10) -> str:
"""Get formatted conversation history for prompts"""
state = self.get(session_id)
if not state or not state.messages:
return ""
# Get recent messages, keeping complete turns
recent_messages = state.messages[-(max_turns * 2):]
history_parts = []
for msg in recent_messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content}")
elif msg.role == "assistant" and not msg.tool_call_id:
history_parts.append(f"Assistant: {msg.content}")
return "\n".join(history_parts)

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,181 @@
"""
Azure AI Search client utilities for retrieval operations.
Contains shared functionality for interacting with Azure AI Search and embedding services.
"""
import httpx
import logging
from typing import Dict, Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from ..config import get_config
logger = logging.getLogger(__name__)
class RetrievalAPIError(Exception):
"""Custom exception for retrieval API errors"""
pass
class AzureSearchClient:
"""Shared Azure AI Search client for embedding and search operations"""
def __init__(self):
self.config = get_config()
self.search_endpoint = self.config.retrieval.endpoint
self.api_key = self.config.retrieval.api_key
self.api_version = self.config.retrieval.api_version
self.semantic_configuration = self.config.retrieval.semantic_configuration
self.embedding_client = httpx.AsyncClient(timeout=30.0)
self.search_client = httpx.AsyncClient(timeout=30.0)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.embedding_client.aclose()
await self.search_client.aclose()
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding vector for text using the configured embedding service"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.config.retrieval.embedding.api_key}"
}
payload = {
"input": text,
"model": self.config.retrieval.embedding.model
}
try:
req_url = f"{self.config.retrieval.embedding.base_url}/embeddings"
if self.config.retrieval.embedding.api_version:
req_url += f"?api-version={self.config.retrieval.embedding.api_version}"
response = await self.embedding_client.post(req_url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
return result["data"][0]["embedding"]
except Exception as e:
logger.error(f"Failed to get embedding: {e}")
raise RetrievalAPIError(f"Embedding generation failed: {str(e)}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((httpx.HTTPStatusError, httpx.TimeoutException))
)
async def search_azure_ai(
self,
index_name: str,
search_text: str,
vector_fields: str,
select_fields: str,
search_fields: str,
filter_query: Optional[str] = None,
top_k: int = 10,
score_threshold: float = 1.5
) -> Dict[str, Any]:
"""Make hybrid search request to Azure AI Search with semantic ranking"""
# Get embedding vector for the query
query_vector = await self.get_embedding(search_text)
# Build vector queries based on the vector fields
vector_queries = []
for field in vector_fields.split(","):
field = field.strip()
vector_queries.append({
"kind": "vector",
"vector": query_vector,
"fields": field,
"k": top_k
})
# Build the search request payload
search_payload = {
"search": search_text,
"select": select_fields,
"searchFields": search_fields,
"top": top_k,
"queryType": "semantic",
"semanticConfiguration": self.semantic_configuration,
"vectorQueries": vector_queries
}
if filter_query:
search_payload["filter"] = filter_query
headers = {
"Content-Type": "application/json",
"api-key": self.api_key
}
search_url = f"{self.search_endpoint}/indexes/{index_name}/docs/search"
try:
response = await self.search_client.post(
search_url,
json=search_payload,
headers=headers,
params={"api-version": self.api_version}
)
response.raise_for_status()
result = response.json()
# Filter results by reranker score and add order numbers
filtered_results = []
for i, item in enumerate(result.get("value", [])):
reranker_score = item.get("@search.rerankerScore", 0)
if reranker_score >= score_threshold:
# Add order number
item["@order_num"] = i + 1
# Normalize the result (removes unwanted fields and empty values)
normalized_item = normalize_search_result(item)
filtered_results.append(normalized_item)
return {"value": filtered_results}
except httpx.HTTPStatusError as e:
logger.error(f"Azure AI Search HTTP error {e.response.status_code}: {e.response.text}")
raise RetrievalAPIError(f"Azure AI Search request failed: {e.response.status_code}")
except httpx.TimeoutException:
logger.error("Azure AI Search request timeout")
raise RetrievalAPIError("Azure AI Search request timeout")
except Exception as e:
logger.error(f"Azure AI Search unexpected error: {e}")
raise RetrievalAPIError(f"Azure AI Search unexpected error: {str(e)}")
def normalize_search_result(raw_result: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize raw Azure AI Search result to clean dynamic structure
Args:
raw_result: Raw result from Azure AI Search
Returns:
Cleaned and normalized result dictionary
"""
# Fields to remove if they exist (belt and suspenders approach)
fields_to_remove = {
"@search.score",
"@search.rerankerScore",
"@search.captions",
"@subquery_id"
}
# Create a copy and remove unwanted fields
result = raw_result.copy()
for field in fields_to_remove:
result.pop(field, None)
# Remove empty fields (None, empty string, empty list, empty dict)
result = {
key: value for key, value in result.items()
if value is not None and value != "" and value != [] and value != {}
}
return result
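# Hypothetical example (not part of this commit): what normalize_search_result() does
# to a raw Azure AI Search hit; the field values below are made up for illustration.
def _example_normalize() -> Dict[str, Any]:
    """Illustrative only: search metadata and empty fields are stripped."""
    raw = {
        "@search.score": 0.031,
        "@search.rerankerScore": 2.4,
        "@order_num": 1,
        "title": "GB 14166-2013",
        "publisher": "",        # empty values are dropped
        "document_code": None,  # None values are dropped
        "content": "Safety-belts ...",
    }
    # -> {"@order_num": 1, "title": "GB 14166-2013", "content": "Safety-belts ..."}
    return normalize_search_result(raw)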

View File

@@ -0,0 +1,58 @@
import logging
import time
from ..config import get_config
from .clients import AzureSearchClient
from .model import RetrievalResponse
logger = logging.getLogger(__name__)
class GenericChunkRetrieval:
def __init__(self) -> None:
self.config = get_config()
self.search_client = AzureSearchClient()
async def retrieve_doc_chunk(
self,
query: str,
conversation_history: str = "",
**kwargs
) -> RetrievalResponse:
"""Search CATOnline system user manual document chunks"""
start_time = time.time()
# Use the new Azure AI Search approach
index_name = self.config.retrieval.index.chunk_user_manual_index
vector_fields = "contentVector"
select_fields = "content, title, full_headers"
search_fields = "content, title, full_headers"
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 1.5)
try:
response_data = await self.search_client.search_azure_ai(
index_name=index_name,
search_text=query,
vector_fields=vector_fields,
select_fields=select_fields,
search_fields=search_fields,
top_k=top_k,
score_threshold=score_threshold
)
results = response_data.get("value", [])
took_ms = int((time.time() - start_time) * 1000)
return RetrievalResponse(
results=results,
took_ms=took_ms,
total_count=len(results)
)
except Exception as e:
logger.error(f"retrieve_doc_chunk_user_manual failed: {e}")
raise

View File

@@ -0,0 +1,11 @@
from typing import Any, Optional
from pydantic import BaseModel
class RetrievalResponse(BaseModel):
"""Simple response container for tool results"""
results: list[dict[str, Any]]
took_ms: Optional[int] = None
total_count: Optional[int] = None

View File

@@ -0,0 +1,158 @@
import httpx
import time
import json
from typing import Dict, Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import logging
from .model import RetrievalResponse
from ..config import get_config
from .clients import AzureSearchClient, RetrievalAPIError
logger = logging.getLogger(__name__)
class AgenticRetrieval:
"""Azure AI Search client for retrieval tools"""
def __init__(self):
self.config = get_config()
self.search_client = AzureSearchClient()
async def __aenter__(self):
await self.search_client.__aenter__()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.search_client.__aexit__(exc_type, exc_val, exc_tb)
async def retrieve_standard_regulation(
self,
query: str,
conversation_history: str = "",
**kwargs
) -> RetrievalResponse:
"""Search standard/regulation attributes"""
start_time = time.time()
# Use the new Azure AI Search approach
index_name = self.config.retrieval.index.standard_regulation_index
vector_fields = "full_metadata_vector"
select_fields = "id, func_uuid, title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"
search_fields = "title, publisher, document_category, document_code, x_Standard_Regulation_Id, x_Attachment_Type, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Published_State, x_Standard_Drafting_Status, x_Standard_Published_State_EN, x_Standard_Drafting_Status_EN, x_Standard_Range, x_Standard_Kind, x_Standard_No, x_Standard_Technical_Committee, x_Standard_Vehicle_Type, x_Standard_Power_Type, x_Standard_CCS, x_Standard_ICS, x_Standard_Published_Date, x_Standard_Effective_Date, x_Regulation_Status, x_Regulation_Status_EN, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Regulation_Document_No, x_Regulation_Issued_Date, x_Classification, x_Work_Group, x_Reference_Standard, x_Replaced_by, x_Refer_To, update_time, status"
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 1.5)
try:
response_data = await self.search_client.search_azure_ai(
index_name=index_name,
search_text=query,
vector_fields=vector_fields,
select_fields=select_fields,
search_fields=search_fields,
top_k=top_k,
score_threshold=score_threshold
)
results = response_data.get("value", [])
took_ms = int((time.time() - start_time) * 1000)
return RetrievalResponse(
results=results,
took_ms=took_ms,
total_count=len(results)
)
except Exception as e:
logger.error(f"retrieve_standard_regulation failed: {e}")
raise
async def retrieve_doc_chunk_standard_regulation(
self,
query: str,
conversation_history: str = "",
**kwargs
) -> RetrievalResponse:
"""Search standard/regulation document chunks"""
start_time = time.time()
# Use the new Azure AI Search approach
index_name = self.config.retrieval.index.chunk_index
vector_fields = "contentVector, full_metadata_vector"
select_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type, id, metadata, func_uuid, filepath, x_Standard_Regulation_Id"
search_fields = "content, title, full_headers, document_code, document_category, publisher, x_Regulation_Title_CN, x_Regulation_Title_EN, x_Standard_Title_CN, x_Standard_Title_EN, x_Standard_Kind, x_Standard_CCS, x_Standard_ICS, x_Standard_Vehicle_Type, x_Standard_Power_Type"
filter_query = "(document_category eq 'Standard' or document_category eq 'Regulation') and (status eq '已发布') and (x_Standard_Published_State_EN eq 'Effective' or x_Standard_Published_State_EN eq 'Publication' or x_Standard_Published_State_EN eq 'Implementation' or x_Regulation_Status_EN eq 'Publication' or x_Regulation_Status_EN eq 'Implementation') and (x_Attachment_Type eq '标准附件(PUBLISHED_STANDARDS)' or x_Attachment_Type eq '已发布法规附件(ISSUED_REGULATION)')"
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 1.5)
try:
response_data = await self.search_client.search_azure_ai(
index_name=index_name,
search_text=query,
vector_fields=vector_fields,
select_fields=select_fields,
search_fields=search_fields,
filter_query=filter_query,
top_k=top_k,
score_threshold=score_threshold
)
results = response_data.get("value", [])
took_ms = int((time.time() - start_time) * 1000)
return RetrievalResponse(
results=results,
took_ms=took_ms,
total_count=len(results)
)
except Exception as e:
logger.error(f"retrieve_doc_chunk_standard_regulation failed: {e}")
raise
async def retrieve_doc_chunk_user_manual(
self,
query: str,
conversation_history: str = "",
**kwargs
) -> RetrievalResponse:
"""Search CATOnline system user manual document chunks"""
start_time = time.time()
# Use the new Azure AI Search approach
index_name = self.config.retrieval.index.chunk_user_manual_index
vector_fields = "contentVector"
select_fields = "content, title, full_headers"
search_fields = "content, title, full_headers"
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 1.5)
try:
response_data = await self.search_client.search_azure_ai(
index_name=index_name,
search_text=query,
vector_fields=vector_fields,
select_fields=select_fields,
search_fields=search_fields,
top_k=top_k,
score_threshold=score_threshold
)
results = response_data.get("value", [])
took_ms = int((time.time() - start_time) * 1000)
return RetrievalResponse(
results=results,
took_ms=took_ms,
total_count=len(results)
)
except Exception as e:
logger.error(f"retrieve_doc_chunk_user_manual failed: {e}")
raise
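
A minimal usage sketch (assumes get_config() returns valid Azure AI Search settings; the query and top_k values are illustrative):

import asyncio

async def demo() -> None:
    # The async context manager opens and closes the underlying AzureSearchClient
    async with AgenticRetrieval() as retrieval:
        response = await retrieval.retrieve_doc_chunk_user_manual(
            query="How do I submit a test report in CATOnline?",
            top_k=5,
        )
        print(response.total_count, response.took_ms)

# asyncio.run(demo())  # requires reachable Azure AI Search endpoints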

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,34 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel
class UserMessage(BaseModel):
content: str
timestamp: Optional[str] = None
class AssistantMessage(BaseModel):
content: str
citations_mapping_csv: Optional[str] = None
timestamp: Optional[str] = None
class ToolMessage(BaseModel):
tool_name: str
tool_call_id: str
content: str # Usually JSON string of results
timestamp: Optional[str] = None
class ChatRequest(BaseModel):
session_id: str
messages: list[Dict[str, Any]]
client_hints: Optional[Dict[str, Any]] = None
class ChatResponse(BaseModel):
"""Base response for non-streaming endpoints"""
answer: str
citations_mapping_csv: str
tool_results: list[Dict[str, Any]]
session_id: str
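
A small construction sketch (the message dict shape is an assumption consistent with list[Dict[str, Any]]):

request = ChatRequest(
    session_id="session-123",
    messages=[{"role": "user", "content": "Which standards apply to LED headlamps?"}],
    client_hints={"locale": "en-US"},
)

response = ChatResponse(
    answer="...",
    citations_mapping_csv="",
    tool_results=[],
    session_id=request.session_id,
)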

View File

@@ -0,0 +1,72 @@
import asyncio
import json
from typing import AsyncGenerator, Dict, Any
def format_sse_event(event: str, data: Dict[str, Any]) -> str:
"""Format data as Server-Sent Events"""
return f"event: {event}\ndata: {json.dumps(data)}\n\n"
async def send_heartbeat(interval_seconds: float = 15.0) -> AsyncGenerator[str, None]:
    """Yield periodic heartbeat events to keep the SSE connection alive"""
    while True:
        yield format_sse_event("heartbeat", {"timestamp": "now"})
        await asyncio.sleep(interval_seconds)
def create_token_event(delta: str, tool_call_id: str | None = None) -> str:
"""Create a token streaming event"""
return format_sse_event("tokens", {
"delta": delta,
"tool_call_id": tool_call_id
})
def create_tool_start_event(tool_id: str, name: str, args: Dict[str, Any]) -> str:
"""Create a tool start event"""
return format_sse_event("tool_start", {
"id": tool_id,
"name": name,
"args": args
})
def create_tool_progress_event(tool_id: str, message: str) -> str:
"""Create a tool progress event"""
return format_sse_event("tool_progress", {
"id": tool_id,
"message": message
})
def create_tool_result_event(tool_id: str, name: str, results: list, took_ms: int) -> str:
"""Create a tool result event"""
return format_sse_event("tool_result", {
"id": tool_id,
"name": name,
"results": results,
"took_ms": took_ms
})
def create_tool_error_event(tool_id: str, name: str, error: str) -> str:
"""Create a tool error event"""
return format_sse_event("tool_error", {
"id": tool_id,
"name": name,
"error": error
})
# def create_agent_done_event() -> str:
# """Create agent completion event"""
# return format_sse_event("agent_done", {"answer_done": True})
def create_error_event(error: str, details: Dict[str, Any] | None = None) -> str:
"""Create an error event"""
event_data: Dict[str, Any] = {"error": error}
if details:
event_data["details"] = details
return format_sse_event("error", event_data)
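
For reference, a token event produced by these helpers serializes as shown in this sketch (the delta text is illustrative):

event = create_token_event("Hello", tool_call_id=None)
# event == 'event: tokens\ndata: {"delta": "Hello", "tool_call_id": null}\n\n'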

View File

@@ -0,0 +1 @@
# Empty __init__.py to make this a package

View File

@@ -0,0 +1,165 @@
"""
DRY Error Handling and Logging Utilities
"""
import json
import logging
import traceback
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional, Callable
from functools import wraps
from ..sse import create_error_event, create_tool_error_event
class ErrorCode(Enum):
"""Error codes for different types of failures"""
# Client errors (4xxx)
INVALID_REQUEST = 4001
MISSING_PARAMETERS = 4002
INVALID_SESSION = 4003
# Server errors (5xxx)
LLM_ERROR = 5001
TOOL_ERROR = 5002
DATABASE_ERROR = 5003
MEMORY_ERROR = 5004
EXTERNAL_API_ERROR = 5005
INTERNAL_ERROR = 5000
class ErrorCategory(Enum):
"""Error categories for better organization"""
VALIDATION = "validation"
LLM = "llm"
TOOL = "tool"
DATABASE = "database"
MEMORY = "memory"
EXTERNAL_API = "external_api"
INTERNAL = "internal"
class StructuredLogger:
"""DRY structured logging with automatic error handling"""
def __init__(self, name: str):
self.logger = logging.getLogger(name)
def error(self, msg: str, error: Optional[Exception] = None, category: ErrorCategory = ErrorCategory.INTERNAL,
error_code: ErrorCode = ErrorCode.INTERNAL_ERROR, extra: Optional[Dict[str, Any]] = None):
"""Log structured error with stack trace"""
data: Dict[str, Any] = {
"message": msg,
"category": category.value,
"error_code": error_code.value,
"timestamp": datetime.now(timezone.utc).isoformat()
}
if error:
data.update({
"error_type": type(error).__name__,
"error_message": str(error),
"stack_trace": traceback.format_exc()
})
if extra:
data["extra"] = extra
self.logger.error(json.dumps(data))
def info(self, msg: str, extra: Optional[Dict[str, Any]] = None):
"""Log structured info"""
data: Dict[str, Any] = {"message": msg, "timestamp": datetime.now(timezone.utc).isoformat()}
if extra:
data["extra"] = extra
self.logger.info(json.dumps(data))
def warning(self, msg: str, extra: Optional[Dict[str, Any]] = None):
"""Log structured warning"""
data: Dict[str, Any] = {"message": msg, "timestamp": datetime.now(timezone.utc).isoformat()}
if extra:
data["extra"] = extra
self.logger.warning(json.dumps(data))
def get_user_message(category: ErrorCategory) -> str:
"""Get user-friendly error messages in English"""
messages = {
ErrorCategory.VALIDATION: "Invalid request parameters. Please check your input.",
ErrorCategory.LLM: "AI service is temporarily unavailable. Please try again later.",
ErrorCategory.TOOL: "Tool execution failed. Please retry your request.",
ErrorCategory.DATABASE: "Database service is temporarily unavailable.",
ErrorCategory.MEMORY: "Session storage issue occurred. Please refresh the page.",
ErrorCategory.EXTERNAL_API: "External service connection failed.",
ErrorCategory.INTERNAL: "Internal server error. We are working to resolve this."
}
return messages.get(category, "Unknown error occurred. Please contact technical support.")
def handle_async_errors(category: ErrorCategory, error_code: ErrorCode,
stream_callback: Optional[Callable] = None, tool_id: Optional[str] = None):
"""DRY decorator for async error handling with streaming support"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
logger = StructuredLogger(func.__module__)
try:
return await func(*args, **kwargs)
except Exception as e:
user_msg = get_user_message(category)
logger.error(
f"Error in {func.__name__}: {str(e)}",
error=e,
category=category,
error_code=error_code,
extra={"function": func.__name__, "args_count": len(args)}
)
# Send error event if streaming
if stream_callback:
if tool_id:
await stream_callback(create_tool_error_event(tool_id, func.__name__, user_msg))
else:
await stream_callback(create_error_event(user_msg))
# Re-raise with user-friendly message for API responses
raise Exception(user_msg) from e
return wrapper
return decorator
def handle_sync_errors(category: ErrorCategory, error_code: ErrorCode):
"""DRY decorator for sync error handling"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
logger = StructuredLogger(func.__module__)
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(
f"Error in {func.__name__}: {str(e)}",
error=e,
category=category,
error_code=error_code,
extra={"function": func.__name__}
)
raise Exception(get_user_message(category)) from e
return wrapper
return decorator
def create_error_response(category: ErrorCategory, error_code: ErrorCode,
technical_msg: Optional[str] = None) -> Dict[str, Any]:
"""Create consistent error response format"""
return {
"user_message": get_user_message(category),
"error_code": error_code.value,
"category": category.value,
"technical_message": technical_msg,
"timestamp": datetime.now(timezone.utc).isoformat()
}
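
A minimal sketch of applying the decorator (the function and its failure are illustrative):

@handle_async_errors(ErrorCategory.TOOL, ErrorCode.TOOL_ERROR)
async def flaky_tool_call() -> dict:
    # Any exception raised here is logged with a stack trace and re-raised
    # with the user-facing message from get_user_message(ErrorCategory.TOOL)
    raise RuntimeError("upstream timeout")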

View File

@@ -0,0 +1,94 @@
import logging
import json
import time
from typing import Dict, Any, Optional
from datetime import datetime, timezone
def setup_logging(level: str = "INFO", format_type: str = "json") -> None:
"""Setup structured logging"""
if format_type == "json":
formatter = JsonFormatter()
else:
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, level.upper()))
root_logger.addHandler(handler)
class JsonFormatter(logging.Formatter):
"""JSON log formatter"""
def format(self, record: logging.LogRecord) -> str:
log_data = {
"timestamp": datetime.utcnow().isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
# Add extra fields
if hasattr(record, "request_id"):
log_data["request_id"] = getattr(record, "request_id")
if hasattr(record, "session_id"):
log_data["session_id"] = getattr(record, "session_id")
if hasattr(record, "duration_ms"):
log_data["duration_ms"] = getattr(record, "duration_ms")
return json.dumps(log_data)
class Timer:
"""Simple timer context manager"""
def __init__(self):
self.start_time = None
self.end_time = None
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.end_time = time.time()
@property
def elapsed_ms(self) -> int:
if self.start_time and self.end_time:
return int((self.end_time - self.start_time) * 1000)
return 0
def redact_secrets(data: Dict[str, Any], secret_keys: list[str] | None = None) -> Dict[str, Any]:
"""Redact sensitive information from logs"""
if secret_keys is None:
secret_keys = ["api_key", "password", "token", "secret", "key"]
redacted = {}
for key, value in data.items():
if any(secret in key.lower() for secret in secret_keys):
redacted[key] = "***REDACTED***"
elif isinstance(value, dict):
redacted[key] = redact_secrets(value, secret_keys)
else:
redacted[key] = value
return redacted
def generate_request_id() -> str:
"""Generate unique request ID"""
return f"req_{int(time.time() * 1000)}_{hash(time.time()) % 10000:04d}"
def truncate_text(text: str, max_length: int = 1000, suffix: str = "...") -> str:
"""Truncate text to maximum length"""
if len(text) <= max_length:
return text
return text[:max_length - len(suffix)] + suffix
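
A short usage sketch combining the helpers above (values are illustrative):

setup_logging(level="DEBUG", format_type="json")

with Timer() as t:
    payload = redact_secrets({"api_key": "abc123", "query": "brake standards"})

logging.getLogger(__name__).info(
    json.dumps({
        "request_id": generate_request_id(),
        "duration_ms": t.elapsed_ms,
        "payload": payload,
    })
)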

View File

@@ -0,0 +1,51 @@
"""
Lightweight Error Handling Middleware
"""
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from .error_handler import StructuredLogger, ErrorCategory, ErrorCode, create_error_response
class ErrorMiddleware(BaseHTTPMiddleware):
"""Concise error handling middleware following DRY principles"""
def __init__(self, app):
super().__init__(app)
self.logger = StructuredLogger(__name__)
async def dispatch(self, request: Request, call_next):
try:
return await call_next(request)
except HTTPException as e:
# HTTP exceptions - map to appropriate categories
category = ErrorCategory.VALIDATION if e.status_code < 500 else ErrorCategory.INTERNAL
error_code = ErrorCode.INVALID_REQUEST if e.status_code < 500 else ErrorCode.INTERNAL_ERROR
self.logger.error(
f"HTTP {e.status_code}: {e.detail}",
category=category,
error_code=error_code,
extra={"path": str(request.url), "method": request.method}
)
return JSONResponse(
status_code=e.status_code,
content=create_error_response(category, error_code, e.detail)
)
except Exception as e:
# Unexpected errors
self.logger.error(
f"Unhandled error: {str(e)}",
error=e,
category=ErrorCategory.INTERNAL,
error_code=ErrorCode.INTERNAL_ERROR,
extra={"path": str(request.url), "method": request.method}
)
return JSONResponse(
status_code=500,
content=create_error_response(ErrorCategory.INTERNAL, ErrorCode.INTERNAL_ERROR)
)
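
A minimal wiring sketch (the FastAPI app construction here is an assumption; the real application registers the middleware elsewhere):

from fastapi import FastAPI

app = FastAPI()
app.add_middleware(ErrorMiddleware)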

View File

@@ -0,0 +1,103 @@
"""
Template utilities for Jinja2 template rendering with LangChain
"""
import logging
from typing import Dict, Any
from jinja2 import Environment, BaseLoader, TemplateError
logger = logging.getLogger(__name__)
class TemplateRenderer:
"""Jinja2 template renderer for LLM prompts"""
def __init__(self):
self.env = Environment(
loader=BaseLoader(),
            # Prompts are plain text, so HTML autoescaping is not needed
            autoescape=False,
            # Standard Jinja2 delimiters, declared explicitly for clarity
variable_start_string='{{',
variable_end_string='}}',
block_start_string='{%',
block_end_string='%}',
comment_start_string='{#',
comment_end_string='#}',
# Keep linebreaks
keep_trailing_newline=True,
# Remove unnecessary whitespace
trim_blocks=True,
lstrip_blocks=True
)
def render_template(self, template_string: str, variables: Dict[str, Any]) -> str:
"""
Render a Jinja2 template string with provided variables
Args:
template_string: The template string with Jinja2 syntax
variables: Dictionary of variables to substitute
Returns:
Rendered template string
Raises:
TemplateError: If template rendering fails
"""
try:
template = self.env.from_string(template_string)
rendered = template.render(**variables)
logger.debug(f"Template rendered successfully with variables: {list(variables.keys())}")
return rendered
except TemplateError as e:
logger.error(f"Template rendering failed: {e}")
logger.error(f"Template: {template_string[:200]}...")
logger.error(f"Variables: {variables}")
raise
except Exception as e:
logger.error(f"Unexpected error during template rendering: {e}")
            raise TemplateError(f"Template rendering failed: {e}") from e
def render_system_prompt(self, template_string: str, variables: Dict[str, Any]) -> str:
"""
Render system prompt template
Args:
template_string: System prompt template
variables: Variables for substitution
Returns:
Rendered system prompt
"""
return self.render_template(template_string, variables)
def render_user_prompt(self, template_string: str, variables: Dict[str, Any]) -> str:
"""
Render user prompt template
Args:
template_string: User prompt template
variables: Variables for substitution
Returns:
Rendered user prompt
"""
return self.render_template(template_string, variables)
# Global template renderer instance
template_renderer = TemplateRenderer()
def render_prompt_template(template_string: str, variables: Dict[str, Any]) -> str:
"""
Convenience function to render prompt templates
Args:
template_string: Template string with Jinja2 syntax
variables: Dictionary of variables to substitute
Returns:
Rendered template string
"""
return template_renderer.render_template(template_string, variables)
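
A short rendering sketch (the template string and variables are illustrative):

prompt = render_prompt_template(
    "You are a {{ role }}. Answer in {{ language }}.",
    {"role": "regulatory assistant", "language": "English"},
)
# prompt == "You are a regulatory assistant. Answer in English."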