# catonline_ai/vw-agentic-rag/service/graph/graph.py
import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar
from pydantic import BaseModel
from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig
from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
    StructuredLogger, ErrorCategory, ErrorCode,
    handle_async_errors, get_user_message
)
from ..sse import (
    create_tool_start_event,
    create_tool_result_event,
    create_tool_error_event,
    create_token_event,
    create_error_event
)
logger = StructuredLogger(__name__)
# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None


def get_cached_config():
    """Get cached configuration, loading it if not already cached."""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config


# Context variable for streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
# Agent node (autonomous function calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Agent node that autonomously uses tools and generates the final answer.

    Implements the "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with a non-streaming detection pass to check for tool needs
    2. If tool_calls exist → return immediately for routing to tools
    3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    # Get tool schemas and bind tools for the planning phase
    tool_schemas = get_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()

    # Prepare messages with system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        system_prompt = rag_prompts.get("agent_system_prompt", "")
        if not system_prompt:
            raise ValueError("agent_system_prompt is missing from the RAG prompts configuration")
        messages = [SystemMessage(content=system_prompt)] + messages

    # Track tool rounds
    current_round = state.get("tool_rounds", 0)

    # Get max_tool_rounds from state, fall back to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds

    # Only apply trimming at the start of a new conversation turn (when tool_rounds == 0).
    # This prevents trimming the current turn's tool results during multi-round tool calling.
    if current_round == 0:
        # Trim conversation history to manage context length (only for previous conversation turns)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Check whether this should be the final synthesis (max rounds reached)
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
        # ✅ STEP 1: Final synthesis with tools disabled from the start
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools to prevent tool calling during synthesis
            if not stream_callback:
                # No streaming callback, generate final response without tools
                draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""
            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)
                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))
                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if there is no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless we are in the middle of a comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}
        finally:
            # ✅ STEP 3: Restore tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)
    else:
        logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If the draft has tool_calls, return immediately (let routing handle it)
        if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
            # Increment tool round counter for the next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
            return {"messages": [draft], "tool_rounds": new_tool_rounds}

        # ✅ STEP 3: No tool_calls needed → enter final synthesis with streaming
        # Temporarily disable tools to prevent accidental tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools
            if not stream_callback:
                # No streaming callback, use the draft we already have
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""
            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)
                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))
                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if there is no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless we are in the middle of a comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}
        finally:
            # ✅ STEP 5: Restore tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)
# Tools routing condition (simplified for the "detect-first-then-stream" strategy)
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
    """
    Simplified routing logic for the "detect-first-then-stream" strategy:
    - has tool_calls → route to tools
    - no tool_calls → route to post_process (final synthesis already completed)
    """
    messages = state["messages"]
    if not messages:
        logger.info("should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)

    # Get max_tool_rounds from state, fall back to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds

    logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If the last message is an AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
        logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")
        if has_tool_calls:
            logger.info("should_continue: Routing to tools")
            return "tools"
        else:
            # No tool calls = final synthesis already completed in call_model
            logger.info("should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If the last message is a tool message, continue with the agent for the next round or final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("should_continue: Tool message completed, continuing to agent")
        return "agent"

    logger.info("should_continue: Routing to post_process")
    return "post_process"
# Custom tool node with streaming support and parallel execution
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute tools with streaming events - supports parallel execution"""
    messages = state["messages"]
    last_message = messages[-1]

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # Tools mapping
    tools_map = get_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single tool call with enhanced error handling"""
        # Get stream callback from context
        stream_callback = stream_callback_context.get()

        # Apply error handling decorator
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be a dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name or tool_name not in tools_map:
                raise ValueError(f"Tool '{tool_name}' not found")

            logger.info(f"Executing tool: {tool_name}", extra={
                "tool_id": tool_id, "tool_name": tool_name
            })

            # Send start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            # Execute tool
            import time
            start_time = time.time()
            result = await tools_map[tool_name].ainvoke(tool_args)
            execution_time = int((time.time() - start_time) * 1000)

            # Process result: tag each search result with its tool call id and position
            if isinstance(result, dict):
                result["tool_call_id"] = tool_id
                if "results" in result and isinstance(result["results"], list):
                    for i, search_result in enumerate(result["results"]):
                        if isinstance(search_result, dict):
                            search_result["@tool_call_id"] = tool_id
                            search_result["@order_num"] = i

            # Send result event
            if stream_callback:
                await stream_callback(create_tool_result_event(
                    tool_id, tool_name, result.get("results", []), execution_time
                ))

            # Create tool message
            tool_message = ToolMessage(
                content=json.dumps(result, ensure_ascii=False),
                tool_call_id=tool_id,
                name=tool_name
            )

            return {
                "message": tool_message,
                "results": result.get("results", []) if isinstance(result, dict) else [],
                "success": True
            }

        try:
            return await _execute()
        except Exception as e:
            # Handle any errors not caught by the decorator
            logger.error(f"Tool execution failed: {e}")
            tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
            tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"
            error_message = ToolMessage(
                content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
                tool_call_id=tool_id,
                name=tool_name
            )
            return {
                "message": error_message,
                "results": [],
                "success": False
            }

    # Execute all tool calls in parallel using asyncio.gather
    if tool_calls:
        logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
        tool_execution_results = await asyncio.gather(
            *[execute_single_tool(tool_call) for tool_call in tool_calls],
            return_exceptions=True
        )

        # Process results
        for execution_result in tool_execution_results:
            if execution_result is None:
                continue
            if isinstance(execution_result, Exception):
                logger.error(f"Tool execution exception: {execution_result}")
                continue
            if not isinstance(execution_result, dict):
                logger.error(f"Unexpected execution result type: {type(execution_result)}")
                continue
            new_messages.append(execution_result["message"])
            if execution_result["success"] and execution_result["results"]:
                tool_results.extend(execution_result["results"])

        logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")

    return {
        "messages": new_messages,
        "tool_results": tool_results
    }
# Helper functions for citation processing
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
    """Extract the citations mapping CSV from the agent response's HTML comment"""
    try:
        # Look for the citations_map comment
        pattern = r'<!-- citations_map\s*(.*?)\s*-->'
        match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
        if not match:
            logger.warning("No citations_map comment found in agent response")
            return {}

        csv_content = match.group(1).strip()
        citations_mapping = {}
        for line in csv_content.split('\n'):
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            if len(parts) >= 3:
                try:
                    citation_num = int(parts[0])
                    tool_call_id = parts[1].strip()
                    order_num = int(parts[2])
                    citations_mapping[citation_num] = {
                        'tool_call_id': tool_call_id,
                        'order_num': order_num
                    }
                except (ValueError, IndexError) as e:
                    logger.warning(f"Failed to parse citation line: {line}, error: {e}")
                    continue

        return citations_mapping
    except Exception as e:
        logger.error(f"Error extracting citations mapping: {e}")
        return {}
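# Illustrative example of the citations_map comment parsed above (the id values are
# hypothetical; real tool_call_id strings come from the LLM's tool calls):
#
#   <!-- citations_map
#   1,call_abc123,0
#   2,call_abc123,2
#   3,call_def456,1
#   -->
#
# which _extract_citations_mapping turns into:
#   {1: {'tool_call_id': 'call_abc123', 'order_num': 0},
#    2: {'tool_call_id': 'call_abc123', 'order_num': 2},
#    3: {'tool_call_id': 'call_def456', 'order_num': 1}}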
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
    """Build citation markdown from the mapping and tool results, following build_citations.py logic"""
    if not citations_mapping:
        return ""

    # Get configuration for the citation base URL
    config = get_cached_config()
    cat_base_url = config.citation.base_url

    # Collect citation lines first; only emit the header if there is at least one valid citation
    entries: List[str] = []
    for citation_num in sorted(citations_mapping.keys()):
        mapping = citations_mapping[citation_num]
        tool_call_id = mapping['tool_call_id']
        order_num = mapping['order_num']

        # Find the corresponding tool result
        result = _find_tool_result(tool_results, tool_call_id, order_num)
        if not result:
            logger.warning(f"No tool result found for citation [{citation_num}]")
            continue

        # Extract citation information following build_citations.py logic
        full_headers = result.get('full_headers', '')
        lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
        header_display = f": {lowest_header}" if lowest_header else ""
        document_code = result.get('document_code', '')
        document_category = result.get('document_category', '')

        # Determine standard/regulation title (assuming English language)
        standard_regulation_title = ''
        if document_category == 'Standard':
            standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
        elif document_category == 'Regulation':
            standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')

        # Build link
        func_uuid = result.get('func_uuid', '')
        uuid = result.get('x_Standard_Regulation_Id', '')
        document_code_encoded = quote(document_code, safe='') if document_code else ''
        standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
        link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
        link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'

        # Format citation line
        title = result.get('title', '')
        entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")

    # If no valid citations were found, do not include the header
    if not entries:
        return ""

    # Build the citations section with entries separated by a blank line (matching previous formatting)
    md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
    return md
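# Illustrative shape of one rendered citation entry (all field values are hypothetical;
# <...> placeholders stand for the URL components assembled above):
#
#   [3] Braking requirements: 5.2.1 | [Uniform provisions concerning braking | ECE R13](<base_url>?funcUuid=<func_uuid>&uuid=<uuid>&name=<encoded_code>(<encoded_title>))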
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
    """Find a tool result by tool_call_id and order_num"""
    matching_results = []
    for result in tool_results:
        if result.get('@tool_call_id') == tool_call_id:
            matching_results.append(result)

    # Return the result at the specified position
    if matching_results and 0 <= order_num < len(matching_results):
        # If results carry @order_num, match on it; otherwise use the position in the list
        if '@order_num' in matching_results[0]:
            for result in matching_results:
                if result.get('@order_num') == order_num:
                    return result
        else:
            return matching_results[order_num]
    return None
def _remove_citations_comment(agent_response: str) -> str:
    """Remove the citations mapping HTML comment from the agent response"""
    pattern = r'<!-- citations_map\s*.*?\s*-->'
    return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()
# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Post-processing node that builds the citation list and links based on the agent's
    citations mapping and tool call results, following the logic from build_citations.py
    """
    # Initialize up front so the error handler below can safely reference it
    agent_response = ""
    try:
        logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")

        # Get stream callback from context variable
        stream_callback = stream_callback_context.get()

        # Get the last AI message (the agent's response carrying the citations mapping)
        citations_mapping = {}
        for message in reversed(state["messages"]):
            if isinstance(message, AIMessage) and message.content:
                # Ensure content is a string
                if isinstance(message.content, str):
                    agent_response = message.content
                    break

        if not agent_response:
            logger.warning("POST_PROCESS_NODE: No agent response found")
            return {"messages": [], "final_answer": ""}

        # Extract citations mapping from agent response
        citations_mapping = _extract_citations_mapping(agent_response)
        logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")

        # Build citation markdown
        citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])

        # Combine agent response (without the HTML comment) with citations
        clean_response = _remove_citations_comment(agent_response)
        final_content = clean_response + citation_markdown
        logger.info("POST_PROCESS_NODE: Built complete response with citations")

        # Send citation markdown as a single block instead of streaming it
        if stream_callback and citation_markdown:
            logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
            await stream_callback(create_token_event(citation_markdown))

        # Create AI message with complete content
        final_ai_message = AIMessage(content=final_content)
        return {
            "messages": [final_ai_message],
            "final_answer": final_content
        }
    except Exception as e:
        logger.error(f"Post-processing error: {e}")
        error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."

        # Send error message as a single block
        stream_callback = stream_callback_context.get()
        if stream_callback:
            await stream_callback(create_token_event(error_message))

        error_content = agent_response + error_message if agent_response else error_message
        error_ai_message = AIMessage(content=error_content)
        return {
            "messages": [error_ai_message],
            "final_answer": error_ai_message.content
        }
# Main workflow class
class AgenticWorkflow:
    """LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""

    def __init__(self):
        # Build StateGraph with TypedDict state
        workflow = StateGraph(AgentState)

        # Add nodes following best practices
        workflow.add_node("intent_recognition", intent_recognition_node)
        workflow.add_node("agent", call_model)
        workflow.add_node("user_manual_rag", user_manual_rag_node)
        workflow.add_node("tools", run_tools_with_streaming)
        workflow.add_node("post_process", post_process_node)

        # Set entry point to intent recognition
        workflow.set_entry_point("intent_recognition")

        # Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
        workflow.add_conditional_edges(
            "intent_recognition",
            intent_router,
            {
                "Standard_Regulation_RAG": "agent",
                "User_Manual_RAG": "user_manual_rag"
            }
        )

        # Standard RAG workflow (existing pattern)
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",  # Allow agent to continue for multi-round
                "post_process": "post_process"
            }
        )

        # Tools route back through should_continue for the multi-round decision
        workflow.add_conditional_edges(
            "tools",
            should_continue,
            {
                "agent": "agent",  # Continue to agent for the next round
                "post_process": "post_process"  # Or finish if max rounds reached
            }
        )

        # User Manual RAG goes directly to END (single turn)
        workflow.add_edge("user_manual_rag", END)

        # Post-process is terminal
        workflow.add_edge("post_process", END)
        # Compile the graph with a PostgreSQL checkpointer for session memory
        try:
            checkpointer = get_checkpointer()
            self.graph = workflow.compile(checkpointer=checkpointer)
            logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
        except Exception as e:
            logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
            self.graph = workflow.compile()
    async def astream(self, state: TurnState, stream_callback: Callable | None = None):
        """Stream agent execution using LangGraph with PostgreSQL session memory"""
        try:
            # Get configuration
            config = get_cached_config()

            # Prepare initial messages for the graph
            messages = []
            for msg in state.messages:
                if msg.role == "user":
                    messages.append(HumanMessage(content=msg.content))
                elif msg.role == "assistant":
                    messages.append(AIMessage(content=msg.content))

            # Create initial agent state (without stream_callback to avoid serialization issues)
            initial_state: AgentState = {
                "messages": messages,
                "session_id": state.session_id,
                "intent": None,  # Will be determined by the intent recognition node
                "tool_results": [],
                "final_answer": "",
                "tool_rounds": 0,
                "max_tool_rounds": config.app.max_tool_rounds,  # Use configuration value
                "max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual  # Configuration value for the user manual agent
            }

            # Set stream callback in the context variable (thread-safe)
            stream_callback_context.set(stream_callback)

            # Create a proper RunnableConfig
            runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})

            # Stream graph execution with session memory
            async for step in self.graph.astream(initial_state, config=runnable_config):
                if "post_process" in step:
                    final_state = step["post_process"]
                    # Extract the final answer and update state
                    state.final_answer = final_state.get("final_answer", "")
                    # Add the answer as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))
                    yield {"final": state}
                    break
                elif "user_manual_rag" in step:
                    # Handle user manual RAG completion
                    final_state = step["user_manual_rag"]
                    # Extract the response from user manual RAG
                    state.final_answer = final_state.get("final_answer", "")
                    # Add the response as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))
                    yield {"final": state}
                    break
                else:
                    # Process regular steps (intent_recognition, agent, tools)
                    yield step
        except Exception as e:
            logger.error(f"AgenticWorkflow error: {e}")
            state.final_answer = "I apologize, but I encountered an error while processing your request."
            yield {"final": state}
def build_graph() -> AgenticWorkflow:
    """Build and return the autonomous agent workflow"""
    return AgenticWorkflow()
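
# Minimal local usage sketch, guarded so it never runs on import.
# Assumption: TurnState can be constructed with session_id= and messages= keyword
# arguments (check .state for the actual definition); Message(role=, content=,
# timestamp=) matches how it is built in astream() above, and the stream callback
# is any async callable receiving the SSE event dicts produced in ..sse.
if __name__ == "__main__":
    async def _demo() -> None:
        async def forward_event(event: Dict[str, Any]) -> None:
            # In the service this callback would write SSE events to the client.
            print(event)

        turn = TurnState(  # hypothetical construction; see the assumption above
            session_id="demo-session",
            messages=[Message(role="user",
                              content="Which standard covers brake testing?",
                              timestamp=datetime.now())],
        )
        workflow = build_graph()
        async for step in workflow.astream(turn, stream_callback=forward_event):
            if "final" in step:
                print(step["final"].final_answer)

    asyncio.run(_demo())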