2025-09-26 17:15:54 +08:00
commit db0e5965ec
211 changed files with 40437 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Empty __init__.py files to make packages

View File

@@ -0,0 +1,746 @@
import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar
from pydantic import BaseModel
from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig
from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
StructuredLogger, ErrorCategory, ErrorCode,
handle_async_errors, get_user_message
)
from ..sse import (
create_tool_start_event,
create_tool_result_event,
create_tool_error_event,
create_token_event,
create_error_event
)
logger = StructuredLogger(__name__)
# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None
def get_cached_config():
"""Get cached configuration, loading it if not already cached"""
global _cached_config
if _cached_config is None:
_cached_config = get_config()
return _cached_config
# Context variable for streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
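# --- Hedged example (illustrative only; the names below are assumptions, not part of this module) ---
# A stream callback is any async callable that accepts a single SSE event dict produced by the
# ..sse helpers (create_token_event, create_tool_start_event, ...). The HTTP layer typically
# installs one per request, for example by forwarding events to a queue it owns:
#
#     event_queue: asyncio.Queue = asyncio.Queue()          # hypothetical transport queue
#
#     async def example_stream_callback(event: Dict[str, Any]) -> None:
#         await event_queue.put(event)                      # hand the event to the SSE writer
#
#     stream_callback_context.set(example_stream_callback)  # set per request before graph execution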
# Agent node (autonomous function calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Agent node that autonomously uses tools and generates final answer.
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
1. Always start with non-streaming detection to check for tool needs
2. If tool_calls exist → return immediately for routing to tools
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
"""
app_config = get_cached_config()
llm_client = LLMClient()
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
# Get tool schemas and bind tools for planning phase
tool_schemas = get_tool_schemas()
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
# Create conversation trimmer for managing context length
trimmer = create_conversation_trimmer()
# Prepare messages with system prompt
messages = state["messages"].copy()
if not messages or not isinstance(messages[0], SystemMessage):
rag_prompts = app_config.get_rag_prompts()
system_prompt = rag_prompts.get("agent_system_prompt", "")
if not system_prompt:
raise ValueError("system_prompt is null")
messages = [SystemMessage(content=system_prompt)] + messages
# Track tool rounds
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds", None)
if max_rounds is None:
max_rounds = app_config.app.max_tool_rounds
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
# This prevents trimming current turn's tool results during multi-round tool calling
if current_round == 0:
# Trim conversation history to manage context length (only for previous conversation turns)
if trimmer.should_trim(messages):
messages = trimmer.trim_conversation_history(messages)
logger.info("Applied conversation history trimming for context management (new conversation turn)")
else:
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
# Check if this should be final synthesis (max rounds reached)
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
is_final_synthesis = has_tool_messages and current_round >= max_rounds
if is_final_synthesis:
logger.info("Starting final synthesis phase - no more tool calls allowed")
# ✅ STEP 1: Final synthesis with tools disabled from the start
# Disable tools to prevent any tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools during synthesis
if not stream_callback:
# No streaming callback, generate final response without tools
draft = await llm_client.ainvoke(list(messages))
return {"messages": [draft]}
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 3: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
else:
logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")
# ✅ STEP 1: Non-streaming detection to check for tool needs
draft = await llm_client.ainvoke_with_tools(list(messages))
# ✅ STEP 2: If draft has tool_calls, return immediately (let routing handle it)
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
# Increment tool round counter for next iteration
new_tool_rounds = current_round + 1
logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
return {"messages": [draft], "tool_rounds": new_tool_rounds}
# ✅ STEP 3: No tool_calls needed → Enter final synthesis with streaming
# Temporarily disable tools to prevent accidental tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
if not stream_callback:
# No streaming callback, use the draft we already have
return {"messages": [draft]}
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 5: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
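# Note on the streaming loops above (worked example; the streamed text is hypothetical):
# the full model output is always accumulated in response_content, while only comment-free
# portions are forwarded to the client. Given the streamed text
#     "Answer body <!-- citations_map\n1,call_abc,0\n--> trailing text"
# the client receives "Answer body " followed by " trailing text"; the citations_map comment
# is suppressed from the stream, but response_content still contains the complete text so
# post_process_node can parse the mapping afterwards. Content that sits inside an unterminated
# "<!--" fragment is held back until the closing "-->" arrives; if the stream ends inside a
# comment, the held fragment is not forwarded.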
# Tools routing condition (simplified for "detect-first-then-stream" strategy)
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
"""
Simplified routing logic for the "detect-first-then-stream" strategy:
- last AI message has tool_calls → route to tools
- last AI message has no tool_calls → route to post_process (final synthesis already completed)
- last message is a ToolMessage → route back to agent for the next round
"""
messages = state["messages"]
if not messages:
logger.info("should_continue: No messages, routing to post_process")
return "post_process"
last_message = messages[-1]
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds", None)
if max_rounds is None:
app_config = get_cached_config()
max_rounds = app_config.app.max_tool_rounds
logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
# If last message is AI message with tool calls, route to tools
if isinstance(last_message, AIMessage):
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")
if has_tool_calls:
logger.info("should_continue: Routing to tools")
return "tools"
else:
# No tool calls = final synthesis already completed in call_model
logger.info("should_continue: No tool calls, routing to post_process")
return "post_process"
# If last message is tool message(s), continue with agent for next round or final synthesis
if isinstance(last_message, ToolMessage):
logger.info("should_continue: Tool message completed, continuing to agent")
return "agent"
logger.info("should_continue: Routing to post_process")
return "post_process"
# Custom tool node with streaming support and parallel execution
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""Execute tools with streaming events - supports parallel execution"""
messages = state["messages"]
last_message = messages[-1]
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
return {"messages": []}
tool_calls = last_message.tool_calls or []
tool_results = []
new_messages = []
# Tools mapping
tools_map = get_tools_by_name()
async def execute_single_tool(tool_call):
"""Execute a single tool call with enhanced error handling"""
# Get stream callback from context
stream_callback = stream_callback_context.get()
# Apply error handling decorator
@handle_async_errors(
ErrorCategory.TOOL,
ErrorCode.TOOL_ERROR,
stream_callback,
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
)
async def _execute():
# Validate tool_call format
if not isinstance(tool_call, dict):
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
tool_name = tool_call.get("name")
tool_args = tool_call.get("args", {})
tool_id = tool_call.get("id", "unknown")
if not tool_name or tool_name not in tools_map:
raise ValueError(f"Tool '{tool_name}' not found")
logger.info(f"Executing tool: {tool_name}", extra={
"tool_id": tool_id, "tool_name": tool_name
})
# Send start event
if stream_callback:
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
# Execute tool
import time
start_time = time.time()
result = await tools_map[tool_name].ainvoke(tool_args)
execution_time = int((time.time() - start_time) * 1000)
# Process result
if isinstance(result, dict):
result["tool_call_id"] = tool_id
if "results" in result and isinstance(result["results"], list):
for i, search_result in enumerate(result["results"]):
if isinstance(search_result, dict):
search_result["@tool_call_id"] = tool_id
search_result["@order_num"] = i
# Send result event
if stream_callback:
await stream_callback(create_tool_result_event(
tool_id, tool_name, result.get("results", []), execution_time
))
# Create tool message
tool_message = ToolMessage(
content=json.dumps(result, ensure_ascii=False),
tool_call_id=tool_id,
name=tool_name
)
return {
"message": tool_message,
"results": result.get("results", []) if isinstance(result, dict) else [],
"success": True
}
try:
return await _execute()
except Exception as e:
# Handle any errors not caught by decorator
tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"
error_message = ToolMessage(
content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
tool_call_id=tool_id,
name=tool_name
)
return {
"message": error_message,
"results": [],
"success": False
}
# Execute all tool calls in parallel using asyncio.gather
if tool_calls:
logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
tool_execution_results = await asyncio.gather(
*[execute_single_tool(tool_call) for tool_call in tool_calls],
return_exceptions=True
)
# Process results
for execution_result in tool_execution_results:
if execution_result is None:
continue
if isinstance(execution_result, Exception):
logger.error(f"Tool execution exception: {execution_result}")
continue
if not isinstance(execution_result, dict):
logger.error(f"Unexpected execution result type: {type(execution_result)}")
continue
new_messages.append(execution_result["message"])
if execution_result["success"] and execution_result["results"]:
tool_results.extend(execution_result["results"])
logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")
return {
"messages": new_messages,
"tool_results": tool_results
}
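# Shape of an annotated search result after run_tools_with_streaming (hedged example; the
# field values are hypothetical and the remaining fields come from the retrieval backend):
#
#     {
#         "title": "...",
#         "document_code": "GB 1589-2016",
#         "@tool_call_id": "call_abc123",   # injected above so citations can be traced back
#         "@order_num": 0,                  # position within that tool call's result list
#         ...
#     }
#
# _find_tool_result() below relies on exactly these two injected keys.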
# Helper functions for citation processing
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
"""Extract citations mapping CSV from agent response HTML comment"""
try:
# Look for citations_map comment
pattern = r'<!-- citations_map\s*(.*?)\s*-->'
match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
if not match:
logger.warning("No citations_map comment found in agent response")
return {}
csv_content = match.group(1).strip()
citations_mapping = {}
for line in csv_content.split('\n'):
line = line.strip()
if not line:
continue
parts = line.split(',')
if len(parts) >= 3:
try:
citation_num = int(parts[0])
tool_call_id = parts[1].strip()
order_num = int(parts[2])
citations_mapping[citation_num] = {
'tool_call_id': tool_call_id,
'order_num': order_num
}
except (ValueError, IndexError) as e:
logger.warning(f"Failed to parse citation line: {line}, error: {e}")
continue
return citations_mapping
except Exception as e:
logger.error(f"Error extracting citations mapping: {e}")
return {}
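# Example of the citations_map comment this parser expects (hedged; the concrete IDs are
# hypothetical and the comment itself is produced by the agent per its system prompt):
#
#     <!-- citations_map
#     1,call_abc123,0
#     2,call_abc123,2
#     3,call_def456,1
#     -->
#
# which _extract_citations_mapping() turns into
#     {1: {"tool_call_id": "call_abc123", "order_num": 0},
#      2: {"tool_call_id": "call_abc123", "order_num": 2},
#      3: {"tool_call_id": "call_def456", "order_num": 1}}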
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
"""Build citation markdown based on mapping and tool results, following build_citations.py logic"""
if not citations_mapping:
return ""
# Get configuration for citation base URL
config = get_cached_config()
cat_base_url = config.citation.base_url
# Collect citation lines first; only emit header if we have at least one valid citation
entries: List[str] = []
for citation_num in sorted(citations_mapping.keys()):
mapping = citations_mapping[citation_num]
tool_call_id = mapping['tool_call_id']
order_num = mapping['order_num']
# Find the corresponding tool result
result = _find_tool_result(tool_results, tool_call_id, order_num)
if not result:
logger.warning(f"No tool result found for citation [{citation_num}]")
continue
# Extract citation information following build_citations.py logic
full_headers = result.get('full_headers', '')
lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
header_display = f": {lowest_header}" if lowest_header else ""
document_code = result.get('document_code', '')
document_category = result.get('document_category', '')
# Determine standard/regulation title (assuming English language)
standard_regulation_title = ''
if document_category == 'Standard':
standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
elif document_category == 'Regulation':
standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')
# Build link
func_uuid = result.get('func_uuid', '')
uuid = result.get('x_Standard_Regulation_Id', '')
document_code_encoded = quote(document_code, safe='') if document_code else ''
standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'
# Format citation line
title = result.get('title', '')
entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")
# If no valid citations were found, do not include the header
if not entries:
return ""
# Build citations section with entries separated by a blank line (matching previous formatting)
md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
return md
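# Template of a single rendered citation line (hedged; placeholder names mirror the fields
# read above, and document_code/title are URL-encoded inside the link name):
#
#     [<n>] <title>: <lowest_header> | [<standard_regulation_title> | <document_code>](<base_url>?funcUuid=<func_uuid>&uuid=<x_Standard_Regulation_Id>&name=<document_code>(<standard_regulation_title>))
#
# The whole block is prefixed with "### 📘 Citations:" and entries are separated by blank lines.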
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
"""Find tool result by tool_call_id and order_num"""
matching_results = []
for result in tool_results:
if result.get('@tool_call_id') == tool_call_id:
matching_results.append(result)
# Sort by order and return the one at the specified position
if matching_results and 0 <= order_num < len(matching_results):
# If results have @order_num, use it; otherwise use position in list
if '@order_num' in matching_results[0]:
for result in matching_results:
if result.get('@order_num') == order_num:
return result
else:
return matching_results[order_num]
return None
def _remove_citations_comment(agent_response: str) -> str:
"""Remove citations mapping HTML comment from agent response"""
pattern = r'<!-- citations_map\s*.*?\s*-->'
return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()
# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Post-processing node that builds citation list and links based on agent's citations mapping
and tool call results, following the logic from build_citations.py
"""
try:
logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")
# Get stream callback from context variable
stream_callback = stream_callback_context.get()
# Get the last AI message (agent's response with citations mapping)
agent_response = ""
citations_mapping = {}
for message in reversed(state["messages"]):
if isinstance(message, AIMessage) and message.content:
# Ensure content is a string
if isinstance(message.content, str):
agent_response = message.content
break
if not agent_response:
logger.warning("POST_PROCESS_NODE: No agent response found")
return {"messages": [], "final_answer": ""}
# Extract citations mapping from agent response
citations_mapping = _extract_citations_mapping(agent_response)
logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")
# Build citation markdown
citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])
# Combine agent response (without HTML comment) with citations
clean_response = _remove_citations_comment(agent_response)
final_content = clean_response + citation_markdown
logger.info("POST_PROCESS_NODE: Built complete response with citations")
# Send citation markdown as a single block instead of streaming
stream_callback = stream_callback_context.get()
if stream_callback and citation_markdown:
logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
await stream_callback(create_token_event(citation_markdown))
# Create AI message with complete content
final_ai_message = AIMessage(content=final_content)
return {
"messages": [final_ai_message],
"final_answer": final_content
}
except Exception as e:
logger.error(f"Post-processing error: {e}")
error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."
# Send error message as single block
stream_callback = stream_callback_context.get()
if stream_callback:
await stream_callback(create_token_event(error_message))
error_content = agent_response + error_message if agent_response else error_message
error_ai_message = AIMessage(content=error_content)
return {
"messages": [error_ai_message],
"final_answer": error_ai_message.content
}
# Main workflow class
class AgenticWorkflow:
"""LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""
def __init__(self):
# Build StateGraph with TypedDict state
workflow = StateGraph(AgentState)
# Add nodes following best practices
workflow.add_node("intent_recognition", intent_recognition_node)
workflow.add_node("agent", call_model)
workflow.add_node("user_manual_rag", user_manual_rag_node)
workflow.add_node("tools", run_tools_with_streaming)
workflow.add_node("post_process", post_process_node)
# Set entry point to intent recognition
workflow.set_entry_point("intent_recognition")
# Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
workflow.add_conditional_edges(
"intent_recognition",
intent_router,
{
"Standard_Regulation_RAG": "agent",
"User_Manual_RAG": "user_manual_rag"
}
)
# Standard RAG workflow (existing pattern)
workflow.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
"agent": "agent", # Allow agent to continue for multi-round
"post_process": "post_process"
}
)
# Tools route back to should_continue for multi-round decision
workflow.add_conditional_edges(
"tools",
should_continue,
{
"agent": "agent", # Continue to agent for next round
"post_process": "post_process" # Or finish if max rounds reached
}
)
# User Manual RAG directly goes to END (single turn)
workflow.add_edge("user_manual_rag", END)
# Post-process is terminal
workflow.add_edge("post_process", END)
# Compile graph with PostgreSQL checkpointer for session memory
try:
checkpointer = get_checkpointer()
self.graph = workflow.compile(checkpointer=checkpointer)
logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
except Exception as e:
logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
self.graph = workflow.compile()
async def astream(self, state: TurnState, stream_callback: Callable | None = None):
"""Stream agent execution using LangGraph with PostgreSQL session memory"""
try:
# Get configuration
config = get_cached_config()
# Prepare initial messages for the graph
messages = []
for msg in state.messages:
if msg.role == "user":
messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
messages.append(AIMessage(content=msg.content))
# Create initial agent state (without stream_callback to avoid serialization issues)
initial_state: AgentState = {
"messages": messages,
"session_id": state.session_id,
"intent": None, # Will be determined by intent recognition node
"tool_results": [],
"final_answer": "",
"tool_rounds": 0,
"max_tool_rounds": config.app.max_tool_rounds, # Use configuration value
"max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual # Use configuration value for user manual agent
}
# Set stream callback in context variable (thread-safe)
stream_callback_context.set(stream_callback)
# Create proper RunnableConfig
runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})
# Stream graph execution with session memory
async for step in self.graph.astream(initial_state, config=runnable_config):
if "post_process" in step:
final_state = step["post_process"]
# Extract the tool summary message and update state
state.final_answer = final_state.get("final_answer", "")
# Add the summary as a regular assistant message
if state.final_answer:
state.messages.append(Message(
role="assistant",
content=state.final_answer,
timestamp=datetime.now()
))
yield {"final": state}
break
elif "user_manual_rag" in step:
# Handle user manual RAG completion
final_state = step["user_manual_rag"]
# Extract the response from user manual RAG
state.final_answer = final_state.get("final_answer", "")
# Add the response as a regular assistant message
if state.final_answer:
state.messages.append(Message(
role="assistant",
content=state.final_answer,
timestamp=datetime.now()
))
yield {"final": state}
break
else:
# Process regular steps (intent_recognition, agent, tools)
yield step
except Exception as e:
logger.error(f"AgentWorkflow error: {e}")
state.final_answer = "I apologize, but I encountered an error while processing your request."
yield {"final": state}
def build_graph() -> AgenticWorkflow:
"""Build and return the autonomous agent workflow"""
return AgenticWorkflow()
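# --- Hedged usage sketch (illustrative only; not part of this module) ---
# Driving the workflow end to end roughly looks like the following; the session id, query,
# and print callback are assumptions for the sketch:
#
#     workflow = build_graph()
#     turn = TurnState(
#         session_id="session-123",
#         messages=[Message(role="user", content="What does GB 1589 regulate?",
#                           timestamp=datetime.now())],
#     )
#
#     async def print_event(event):          # any async callable accepting an SSE event dict
#         print(event)
#
#     async for step in workflow.astream(turn, stream_callback=print_event):
#         if "final" in step:
#             print(step["final"].final_answer)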

View File

@@ -0,0 +1,136 @@
"""
Intent recognition functionality for the Agentic RAG system.
This module contains the intent classification logic.
"""
import logging
from typing import Dict, Any, Optional, Literal
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
from .state import AgentState
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.error_handler import StructuredLogger
logger = StructuredLogger(__name__)
# Intent Recognition Models
class Intent(BaseModel):
"""Intent classification model for routing user queries"""
label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
confidence: Optional[float] = None
def get_last_user_message(messages) -> str:
"""Extract the last user message from conversation history"""
for message in reversed(messages):
if isinstance(message, HumanMessage) and hasattr(message, 'content'):
content = message.content
# Handle both string and list content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract string content from list
return " ".join([str(item) for item in content if isinstance(item, str)])
return ""
def render_conversation_history(messages, max_messages: int = 10) -> str:
"""Render conversation history for context"""
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
lines = []
for msg in recent_messages:
if hasattr(msg, 'content'):
content = msg.content
if isinstance(content, str):
# Determine message type by class name or other attributes
if 'Human' in str(type(msg)):
lines.append(f"<user>{content}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content}</ai>")
elif isinstance(content, list):
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
if 'Human' in str(type(msg)):
lines.append(f"<user>{content_str}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content_str}</ai>")
return "\n".join(lines)
async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Intent recognition node that uses LLM to classify user queries into specific domains
"""
try:
logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
app_config = get_config()
llm_client = LLMClient()
# Get current user query and conversation history
current_query = get_last_user_message(state["messages"])
conversation_context = render_conversation_history(state["messages"])
# Get intent classification prompt from configuration
rag_prompts = app_config.get_rag_prompts()
intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
if not intent_prompt_template:
logger.error("Intent recognition prompt not found in configuration")
return {"intent": "Standard_Regulation_RAG"}
# Format the prompt with instruction to return only the label
system_prompt = intent_prompt_template.format(
current_query=current_query,
conversation_context=conversation_context
) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."
# Classify intent using regular LLM call
intent_result = await llm_client.llm.ainvoke([
SystemMessage(content=system_prompt)
])
# Parse the response to extract the intent label
response_text = ""
if hasattr(intent_result, 'content') and intent_result.content:
if isinstance(intent_result.content, str):
response_text = intent_result.content.strip()
elif isinstance(intent_result.content, list):
# Handle list content by joining string elements
response_text = " ".join([str(item) for item in intent_result.content if isinstance(item, str)]).strip()
# Extract intent label from response
if "User_Manual_RAG" in response_text:
intent_label = "User_Manual_RAG"
elif "Standard_Regulation_RAG" in response_text:
intent_label = "Standard_Regulation_RAG"
else:
# Default fallback
logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
intent_label = "Standard_Regulation_RAG"
logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
return {"intent": intent_label}
except Exception as e:
logger.error(f"Intent recognition error: {e}")
# Default to Standard_Regulation_RAG if classification fails
logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
return {"intent": "Standard_Regulation_RAG"}
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
"""
Route based on intent classification result
"""
intent = state.get("intent")
if intent is None:
logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
return "Standard_Regulation_RAG"
logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
return intent

View File

@@ -0,0 +1,270 @@
"""
Conversation history trimming utilities for managing context length.
"""
import logging
from typing import List, Optional, Sequence, Tuple
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
logger = logging.getLogger(__name__)
class ConversationTrimmer:
"""
Manages conversation history to prevent exceeding LLM context limits.
"""
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
"""
Initialize the conversation trimmer.
Args:
max_context_length: Maximum context length for conversation history (in tokens)
preserve_system: Whether to always preserve system messages
"""
self.max_context_length = max_context_length
self.preserve_system = preserve_system
# Reserve tokens for response generation (use 85% for history, 15% for response)
self.history_token_limit = int(max_context_length * 0.85)
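# Budget example with the defaults: for max_context_length=96000,
# history_token_limit = int(96000 * 0.85) = 81600 tokens, leaving roughly
# 14400 tokens of headroom for response generation.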
def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
"""
Trim conversation history to fit within token limits.
Args:
messages: List of conversation messages
Returns:
Trimmed list of messages
"""
if not messages:
return list(messages)
try:
# Convert to list for processing
message_list = list(messages)
# First, try multi-round tool call optimization
optimized_messages = self._optimize_multi_round_tool_calls(message_list)
# Check if optimization is sufficient
try:
token_count = count_tokens_approximately(optimized_messages)
if token_count <= self.history_token_limit:
original_count = len(message_list)
optimized_count = len(optimized_messages)
if optimized_count < original_count:
logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
return optimized_messages
except Exception:
# If token counting fails, continue with LangChain trimming
pass
# If still too long, use LangChain's trim_messages utility
trimmed_messages = trim_messages(
optimized_messages,
strategy="last", # Keep most recent messages
token_counter=count_tokens_approximately,
max_tokens=self.history_token_limit,
start_on="human", # Ensure valid conversation start
end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages
include_system=self.preserve_system, # Preserve system messages
allow_partial=False # Don't split individual messages
)
original_count = len(messages)
trimmed_count = len(trimmed_messages)
if trimmed_count < original_count:
logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")
return trimmed_messages
except Exception as e:
logger.error(f"Error trimming conversation history: {e}")
# Fallback: keep last N messages
return self._fallback_trim(list(messages))
def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
"""
Optimize conversation history by removing older tool call results in multi-round scenarios.
This reduces token usage while preserving conversation context.
Strategy:
1. Always preserve system messages
2. Always preserve the original user query
3. Keep the most recent AI-Tool message pairs (for context continuity)
4. Remove older ToolMessage content which typically contains large JSON responses
Args:
messages: List of conversation messages
Returns:
Optimized list of messages
"""
if len(messages) <= 4: # Too short to optimize
return [msg for msg in messages]
# Identify message patterns
tool_rounds = self._identify_tool_rounds(messages)
if len(tool_rounds) <= 1: # Single or no tool round, no optimization needed
return [msg for msg in messages]
logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")
# Build optimized message list
optimized = []
# Always preserve system messages
for msg in messages:
if isinstance(msg, SystemMessage):
optimized.append(msg)
# Preserve initial user query (first human message after system)
first_human_added = False
for msg in messages:
if isinstance(msg, HumanMessage) and not first_human_added:
optimized.append(msg)
first_human_added = True
break
# Keep only the most recent tool round (preserve context for next round)
if tool_rounds:
latest_round_start, latest_round_end = tool_rounds[-1]
# Add messages from the latest tool round
for i in range(latest_round_start, min(latest_round_end + 1, len(messages))):
msg = messages[i]
if not isinstance(msg, SystemMessage) and not (isinstance(msg, HumanMessage) and not first_human_added):
optimized.append(msg)
logger.info(f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages (removed {len(tool_rounds)-1} older tool rounds)")
return optimized
def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
"""
Identify tool calling rounds in the message sequence.
A tool round typically consists of:
- AI message with tool_calls
- One or more ToolMessage responses
Returns:
List of (start_index, end_index) tuples for each tool round
"""
rounds = []
i = 0
while i < len(messages):
msg = messages[i]
# Look for AI message with tool calls
if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
round_start = i
round_end = i
# Find the end of this tool round (look for consecutive ToolMessages)
j = i + 1
while j < len(messages) and isinstance(messages[j], ToolMessage):
round_end = j
j += 1
# Only consider it a tool round if we found at least one ToolMessage
if round_end > round_start:
rounds.append((round_start, round_end))
i = round_end + 1
else:
i += 1
else:
i += 1
return rounds
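# Worked example (hypothetical message sequence): for
#     [System, Human, AI(tool_calls), Tool, Tool, AI(tool_calls), Tool, AI]
# _identify_tool_rounds returns [(2, 4), (5, 6)] — each tuple spans one AI message with
# tool_calls plus its consecutive ToolMessage responses; the trailing plain AI message
# is not part of any round.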
def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
"""
Fallback trimming based on message count.
Args:
messages: List of conversation messages
max_messages: Maximum number of messages to keep
Returns:
Trimmed list of messages
"""
if len(messages) <= max_messages:
return [msg for msg in messages] # Convert to BaseMessage
# Preserve system message if it exists
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
# Keep the most recent messages
recent_messages = other_messages[-(max_messages - len(system_messages)):]
result = system_messages + recent_messages
logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
return [msg for msg in result] # Ensure BaseMessage type
def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
"""
Check if conversation history should be trimmed.
Strategy:
1. Always trim if there are multiple tool rounds from previous conversation turns
2. Also trim if approaching token limit
Args:
messages: List of conversation messages
Returns:
True if trimming is needed
"""
try:
# Convert to list for processing
message_list = list(messages)
# Check for multiple tool rounds - if found, always trim to remove old tool results
tool_rounds = self._identify_tool_rounds(message_list)
if len(tool_rounds) > 1:
logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
return True
# Also check token count for traditional trimming
token_count = count_tokens_approximately(message_list)
return token_count > self.history_token_limit
except Exception:
# Fallback to message count
return len(messages) > 30
def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
"""
Create a conversation trimmer with config-based settings.
Args:
max_context_length: Override for maximum context length
Returns:
ConversationTrimmer instance
"""
# If max_context_length is provided, use it directly
if max_context_length is not None:
return ConversationTrimmer(
max_context_length=max_context_length,
preserve_system=True
)
# Try to get from config, fallback to default if config not available
try:
from ..config import get_config
config = get_config()
effective_max_context_length = config.get_max_context_length()
except (RuntimeError, AttributeError):
effective_max_context_length = 96000
return ConversationTrimmer(
max_context_length=effective_max_context_length,
preserve_system=True
)

View File

@@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Literal
from datetime import datetime
from typing_extensions import Annotated
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage
class Message(BaseModel):
"""Base message class for conversation history"""
role: str # "user", "assistant", "tool"
content: str
timestamp: Optional[datetime] = None
tool_call_id: Optional[str] = None
tool_name: Optional[str] = None
class Citation(BaseModel):
"""Citation mapping between numbers and result IDs"""
number: int
result_id: str
url: Optional[str] = None
class ToolResult(BaseModel):
"""Normalized tool result schema"""
id: str
title: str
url: Optional[str] = None
score: Optional[float] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
content: Optional[str] = None # For chunk results
# Standard/regulation specific fields
publisher: Optional[str] = None
publish_date: Optional[str] = None
document_code: Optional[str] = None
document_category: Optional[str] = None
class TurnState(BaseModel):
"""State container for LangGraph workflow"""
session_id: str
messages: List[Message] = Field(default_factory=list)
tool_results: List[ToolResult] = Field(default_factory=list)
citations: List[Citation] = Field(default_factory=list)
meta: Dict[str, Any] = Field(default_factory=dict)
# Additional fields for tracking
current_step: int = 0
max_steps: int = 5
final_answer: Optional[str] = None
# TypedDict for LangGraph AgentState (LangGraph native format)
from typing import TypedDict
from langgraph.graph import MessagesState
class AgentState(MessagesState):
"""LangGraph state with intent recognition support"""
session_id: str
intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
tool_results: Annotated[List[Dict[str, Any]], lambda x, y: (x or []) + (y or [])]
final_answer: str
tool_rounds: int
max_tool_rounds: int
max_tool_rounds_user_manual: int
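# Example of an initial AgentState as constructed by AgenticWorkflow.astream (hedged; the
# session id and round limits are illustrative values, not configuration defaults):
#
#     {
#         "messages": [HumanMessage(content="...")],
#         "session_id": "session-123",
#         "intent": None,                       # filled in by the intent recognition node
#         "tool_results": [],
#         "final_answer": "",
#         "tool_rounds": 0,
#         "max_tool_rounds": 3,
#         "max_tool_rounds_user_manual": 2,
#     }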

View File

@@ -0,0 +1,98 @@
"""
Tool definitions and schemas for the Agentic RAG system.
This module contains all tool implementations and their corresponding schemas.
"""
import logging
from typing import Dict, Any, List
from langchain_core.tools import tool
from ..retrieval.retrieval import AgenticRetrieval
logger = logging.getLogger(__name__)
# Tool Definitions using @tool decorator (following LangGraph best practices)
@tool
async def retrieve_standard_regulation(query: str) -> Dict[str, Any]:
"""Search for attributes/metadata of China standards and regulations in automobile/manufacturing industry"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_standard_regulation(
query=query
)
return {
"tool_name": "retrieve_standard_regulation",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"Retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
@tool
async def retrieve_doc_chunk_standard_regulation(query: str) -> Dict[str, Any]:
"""Search for detailed document content chunks of China standards and regulations in automobile/manufacturing industry"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_doc_chunk_standard_regulation(
query=query
)
return {
"tool_name": "retrieve_doc_chunk_standard_regulation",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"Doc chunk retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
# Available tools list
tools = [retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation]
def get_tool_schemas() -> List[Dict[str, Any]]:
"""
Generate tool schemas for LLM function calling.
Returns:
List of tool schemas in OpenAI function calling format
"""
tool_schemas = []
for tool in tools:
schema = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for retrieving relevant information"
}
},
"required": ["query"]
}
}
}
tool_schemas.append(schema)
return tool_schemas
def get_tools_by_name() -> Dict[str, Any]:
"""
Create a mapping of tool names to tool functions.
Returns:
Dictionary mapping tool names to tool functions
"""
return {tool.name: tool for tool in tools}
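# --- Hedged usage sketch (illustrative only) ---
# Each @tool function can be invoked directly for testing; the query below is hypothetical:
#
#     result = await retrieve_standard_regulation.ainvoke(
#         {"query": "axle load limits for trucks"}
#     )
#     # result is a dict: {"tool_name": ..., "results_count": ..., "results": [...], "took_ms": ...}
#     # or {"error": "...", "results_count": 0, "results": []} when retrieval fails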

View File

@@ -0,0 +1,464 @@
"""
User Manual Agent node for the Agentic RAG system.
This module contains the autonomous user manual agent that can use tools and generate responses.
"""
import logging
from typing import Dict, Any, List, Optional, Callable, Literal
from contextvars import ContextVar
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from .state import AgentState
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
from .message_trimmer import create_conversation_trimmer
from ..llm_client import LLMClient
from ..config import get_config
from ..sse import (
create_tool_start_event,
create_tool_result_event,
create_tool_error_event,
create_token_event,
create_error_event
)
from ..utils.error_handler import (
StructuredLogger, ErrorCategory, ErrorCode,
handle_async_errors, get_user_message
)
logger = StructuredLogger(__name__)
# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None
def get_cached_config():
"""Get cached configuration, loading it if not already cached"""
global _cached_config
if _cached_config is None:
_cached_config = get_config()
return _cached_config
# User Manual Agent node (autonomous function calling agent)
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
User Manual Agent node that autonomously uses user manual tools and generates final answer.
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
1. Always start with non-streaming detection to check for tool needs
2. If tool_calls exist → return immediately for routing to tools
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
"""
app_config = get_cached_config()
llm_client = LLMClient()
# Get stream callback from context variable
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
# Get user manual tool schemas and bind tools for planning phase
tool_schemas = get_user_manual_tool_schemas()
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
# Create conversation trimmer for managing context length
trimmer = create_conversation_trimmer()
# Prepare messages with user manual system prompt
messages = state["messages"].copy()
if not messages or not isinstance(messages[0], SystemMessage):
rag_prompts = app_config.get_rag_prompts()
user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
if not user_manual_prompt:
raise ValueError("user_manual_prompt is null")
# For user manual agent, we need to format the prompt with placeholders
# Extract current query and conversation history
current_query = ""
for message in reversed(messages):
if isinstance(message, HumanMessage):
current_query = message.content
break
conversation_history = ""
if len(messages) > 1:
conversation_history = render_conversation_history(messages[:-1]) # Exclude current query
# Format system prompt (initially with empty context, tools will provide it)
formatted_system_prompt = user_manual_prompt.format(
conversation_history=conversation_history,
context_content="", # Will be filled by tools
current_query=current_query
)
messages = [SystemMessage(content=formatted_system_prompt)] + messages
# Track tool rounds
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds_user_manual from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds_user_manual", None)
if max_rounds is None:
max_rounds = app_config.app.max_tool_rounds_user_manual
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
# This prevents trimming current turn's tool results during multi-round tool calling
if current_round == 0:
# Trim conversation history to manage context length (only for previous conversation turns)
if trimmer.should_trim(messages):
messages = trimmer.trim_conversation_history(messages)
logger.info("Applied conversation history trimming for context management (new conversation turn)")
else:
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
# Check if this should be final synthesis (max rounds reached)
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
is_final_synthesis = has_tool_messages and current_round >= max_rounds
if is_final_synthesis:
logger.info("Starting final synthesis phase - no more tool calls allowed")
# ✅ STEP 1: Final synthesis with tools disabled from the start
# Disable tools to prevent any tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools during synthesis
if not stream_callback:
# No streaming callback, generate final response without tools
draft = await llm_client.ainvoke(list(messages))
return {"messages": [draft]}
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 3: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
else:
logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")
# ✅ STEP 1: Non-streaming detection to check for tool needs
draft = await llm_client.ainvoke_with_tools(list(messages))
# ✅ STEP 2: If draft has tool_calls, execute them within this node
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")
# Create a new state with the tool call message added
tool_call_state = state.copy()
updated_messages = state["messages"].copy()
updated_messages.append(draft)
tool_call_state["messages"] = updated_messages
# Execute the tools using the existing streaming tool execution function
tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
tool_messages = tool_results.get("messages", [])
# Increment tool round counter for next iteration
new_tool_rounds = current_round + 1
logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")
# Continue with another round if under max rounds
if new_tool_rounds < max_rounds:
# Recursive call for next round with all messages
final_messages = updated_messages + tool_messages
recursive_state = state.copy()
recursive_state["messages"] = final_messages
recursive_state["tool_rounds"] = new_tool_rounds
return await user_manual_agent_node(recursive_state)
else:
# Max rounds reached, force final synthesis
logger.info("Max tool rounds reached, forcing final synthesis")
# Update messages for final synthesis
messages = updated_messages + tool_messages
# Continue to final synthesis below
# ✅ STEP 3: No tool_calls needed or max rounds reached → Enter final synthesis with streaming
# Temporarily disable tools to prevent accidental tool calling during synthesis
try:
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
if not stream_callback:
# No streaming callback: if the draft still carries tool calls (max-rounds path),
# regenerate a final answer from the accumulated messages; otherwise reuse the draft
if isinstance(draft, AIMessage) and getattr(draft, 'tool_calls', None):
draft = await llm_client.ainvoke(list(messages))
return {"messages": [draft]}
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
response_content = ""
accumulated_content = ""
async for token in llm_client.astream(list(messages)):
accumulated_content += token
response_content += token
# Check for complete HTML comments in accumulated content
while "<!--" in accumulated_content and "-->" in accumulated_content:
comment_start = accumulated_content.find("<!--")
comment_end = accumulated_content.find("-->", comment_start)
if comment_start >= 0 and comment_end >= 0:
# Send content before comment
before_comment = accumulated_content[:comment_start]
if stream_callback and before_comment:
await stream_callback(create_token_event(before_comment))
# Skip the comment and continue with content after
accumulated_content = accumulated_content[comment_end + 3:]
else:
break
# Send accumulated content if no pending comment
if "<!--" not in accumulated_content:
if stream_callback and accumulated_content:
await stream_callback(create_token_event(accumulated_content))
accumulated_content = ""
# Send any remaining content (if not in middle of comment)
if accumulated_content and "<!--" not in accumulated_content:
if stream_callback:
await stream_callback(create_token_event(accumulated_content))
return {"messages": [AIMessage(content=response_content)]}
finally:
# ✅ STEP 5: Restore tool binding for next interaction
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
def render_conversation_history(messages, max_messages: int = 10) -> str:
"""Render conversation history for context"""
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
lines = []
for msg in recent_messages:
if hasattr(msg, 'content'):
content = msg.content
if isinstance(content, str):
# Determine message type by class name or other attributes
if 'Human' in str(type(msg)):
lines.append(f"<user>{content}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content}</ai>")
elif isinstance(content, list):
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
if 'Human' in str(type(msg)):
lines.append(f"<user>{content_str}</user>")
elif 'AI' in str(type(msg)):
lines.append(f"<ai>{content_str}</ai>")
return "\n".join(lines)
# User Manual Tools routing condition
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
"""
Routing logic for the user manual agent:
- last AI message has tool_calls → route to user_manual_tools
- last AI message has no tool_calls → route to post_process (final synthesis already completed)
- last message is a ToolMessage → route back to user_manual_agent for the next round
"""
messages = state["messages"]
if not messages:
logger.info("user_manual_should_continue: No messages, routing to post_process")
return "post_process"
last_message = messages[-1]
current_round = state.get("tool_rounds", 0)
# Get max_tool_rounds_user_manual from state, fallback to config if not set
max_rounds = state.get("max_tool_rounds_user_manual", None)
if max_rounds is None:
app_config = get_cached_config()
max_rounds = app_config.app.max_tool_rounds_user_manual
logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
# If last message is AI message with tool calls, route to tools
if isinstance(last_message, AIMessage):
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")
if has_tool_calls:
logger.info("user_manual_should_continue: Routing to user_manual_tools")
return "user_manual_tools"
else:
# No tool calls = final synthesis already completed in user_manual_agent_node
logger.info("user_manual_should_continue: No tool calls, routing to post_process")
return "post_process"
# If last message is tool message(s), continue with agent for next round or final synthesis
if isinstance(last_message, ToolMessage):
logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
return "user_manual_agent"
logger.info("user_manual_should_continue: Routing to post_process")
return "post_process"
# User Manual Tools node with streaming support
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""Execute user manual tools with streaming events - supports parallel execution"""
messages = state["messages"]
last_message = messages[-1]
# Get stream callback from context variable
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
return {"messages": []}
tool_calls = last_message.tool_calls or []
tool_results = []
new_messages = []
# User manual tools mapping
tools_map = get_user_manual_tools_by_name()
async def execute_single_tool(tool_call):
"""Execute a single user manual tool call with enhanced error handling"""
# Get stream callback from context
from .graph import stream_callback_context
stream_callback = stream_callback_context.get()
# Apply error handling decorator
@handle_async_errors(
ErrorCategory.TOOL,
ErrorCode.TOOL_ERROR,
stream_callback,
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
)
async def _execute():
# Validate tool_call format
if not isinstance(tool_call, dict):
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
tool_name = tool_call.get("name")
tool_args = tool_call.get("args", {})
tool_id = tool_call.get("id", "unknown")
if not tool_name:
raise ValueError("Tool call missing 'name' field")
if tool_name not in tools_map:
available_tools = list(tools_map.keys())
raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")
tool_func = tools_map[tool_name]
# Stream tool start event
if stream_callback:
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
import time
start_time = time.time()
try:
# Execute the user manual tool
result = await tool_func.ainvoke(tool_args)
# Calculate execution time
took_ms = int((time.time() - start_time) * 1000)
# Stream tool result event
if stream_callback:
await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))
# Create tool message
tool_message = ToolMessage(
content=str(result),
tool_call_id=tool_id,
name=tool_name
)
return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}
except Exception as e:
took_ms = int((time.time() - start_time) * 1000)
error_msg = get_user_message(ErrorCategory.TOOL)
# Stream tool error event
if stream_callback:
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
# Create error tool message
tool_message = ToolMessage(
content=f"Error executing {tool_name}: {error_msg}",
tool_call_id=tool_id,
name=tool_name
)
return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}
return await _execute()
# Execute user manual tools (typically just one for user manual retrieval)
import asyncio
tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception):
# Handle execution exception
tool_call = tool_calls[i]
tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
tool_name = tool_call.get("name", "unknown")
error_msg = get_user_message(ErrorCategory.TOOL)
if stream_callback:
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
error_message = ToolMessage(
content=f"Error executing {tool_name}: {error_msg}",
tool_call_id=tool_id,
name=tool_name
)
new_messages.append(error_message)
elif isinstance(result, tuple) and len(result) == 2:
# result is a tuple: (tool_message, tool_result)
tool_message, tool_result = result
new_messages.append(tool_message)
tool_results.append(tool_result)
else:
# Unexpected result format
logger.error(f"Unexpected tool execution result format: {type(result)}")
continue
return {"messages": new_messages, "tool_results": tool_results}
# Legacy function for backward compatibility
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Legacy user manual RAG node - redirects to new agent-based implementation
"""
logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
return await user_manual_agent_node(state, config)

View File

@@ -0,0 +1,77 @@
"""
User manual specific tools for the Agentic RAG system.
This module contains tools specifically for user manual retrieval and processing.
"""
import logging
from typing import Dict, Any, List
from langchain_core.tools import tool
from ..retrieval.retrieval import AgenticRetrieval
logger = logging.getLogger(__name__)
# User Manual Tools
@tool
async def retrieve_system_usermanual(query: str) -> Dict[str, Any]:
"""Search for document content chunks of user manual of this system(CATOnline)"""
async with AgenticRetrieval() as retrieval:
try:
result = await retrieval.retrieve_doc_chunk_user_manual(
query=query
)
return {
"tool_name": "retrieve_system_usermanual",
"results_count": len(result.results),
"results": result.results, # Already dict objects, no need for model_dump()
"took_ms": result.took_ms
}
except Exception as e:
logger.error(f"User manual retrieval error: {e}")
return {"error": str(e), "results_count": 0, "results": []}
# User manual tools list
user_manual_tools = [retrieve_system_usermanual]
def get_user_manual_tool_schemas() -> List[Dict[str, Any]]:
"""
Generate tool schemas for user manual tools.
Returns:
List of tool schemas in OpenAI function calling format
"""
tool_schemas = []
for tool in user_manual_tools:
schema = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for retrieving relevant information"
}
},
"required": ["query"]
}
}
}
tool_schemas.append(schema)
return tool_schemas
def get_user_manual_tools_by_name() -> Dict[str, Any]:
"""
Create a mapping of user manual tool names to tool functions.
Returns:
Dictionary mapping tool names to tool functions
"""
return {tool.name: tool for tool in user_manual_tools}