import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar
from pydantic import BaseModel

from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig

from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
    StructuredLogger, ErrorCategory, ErrorCode,
    handle_async_errors, get_user_message
)
from ..sse import (
    create_tool_start_event,
    create_tool_result_event,
    create_tool_error_event,
    create_token_event,
    create_error_event
)

logger = StructuredLogger(__name__)

# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None


def get_cached_config():
    """Get cached configuration, loading it if not already cached"""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config


# Context variable for streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)

# Agent node (autonomous function calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Agent node that autonomously uses tools and generates the final answer.

    Implements the "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with a non-streaming call to detect whether tools are needed
    2. If tool_calls exist → return immediately so routing can dispatch to the tools node
    3. If no tool_calls → temporarily disable tools and stream the final synthesis
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    # Get tool schemas and bind tools for the planning phase
    tool_schemas = get_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()

    # Prepare messages with system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        system_prompt = rag_prompts.get("agent_system_prompt", "")
        if not system_prompt:
            raise ValueError("agent_system_prompt is missing from the RAG prompts configuration")

        messages = [SystemMessage(content=system_prompt)] + messages

    # Track tool rounds
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds

    # Only apply trimming at the start of a new conversation turn (when tool_rounds == 0).
    # This prevents trimming the current turn's tool results during multi-round tool calling.
    if current_round == 0:
        # Trim conversation history to manage context length (previous conversation turns only)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Check whether this should be the final synthesis (max rounds reached)
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
        # ✅ STEP 1: Final synthesis with tools disabled from the start
        # Disable tools to prevent any tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback, generate final response without tools
                draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in the middle of a comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 3: Restore tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    else:
        logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If the draft has tool_calls, return immediately (let routing handle it)
        if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
            # Increment tool round counter for the next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
            return {"messages": [draft], "tool_rounds": new_tool_rounds}

        # ✅ STEP 3: No tool_calls needed → enter final synthesis with streaming
        # Temporarily disable tools to prevent accidental tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback, use the draft we already have
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in the middle of a comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 5: Restore tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)

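# Illustrative note (added for clarity; not used at runtime): the streaming filter above
# withholds HTML comments such as the citations_map block from the client stream while
# keeping them in the stored response. For example, assuming the model emits
#     "Brake systems must ... [1]<!-- citations_map\n1,call_abc,0\n--> Further details"
# the client would receive "Brake systems must ... [1] Further details", while the full
# text (including the comment) is preserved in the returned AIMessage for post-processing.
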
# Tools routing condition (simplified for "detect-first-then-stream" strategy)
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
    """
    Simplified routing logic for "detect-first-then-stream" strategy:
    - has tool_calls → route to tools
    - no tool_calls → route to post_process (final synthesis already completed)
    """
    messages = state["messages"]
    if not messages:
        logger.info("should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds from state, fallback to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds

    logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If last message is AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
        logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")

        if has_tool_calls:
            logger.info("should_continue: Routing to tools")
            return "tools"
        else:
            # No tool calls = final synthesis already completed in call_model
            logger.info("should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If last message is a tool message, continue with agent for next round or final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("should_continue: Tool message completed, continuing to agent")
        return "agent"

    logger.info("should_continue: Routing to post_process")
    return "post_process"

# Custom tool node with streaming support and parallel execution
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute tools with streaming events - supports parallel execution"""
    messages = state["messages"]
    last_message = messages[-1]

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # Tools mapping
    tools_map = get_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single tool call with enhanced error handling"""
        # Get stream callback from context
        stream_callback = stream_callback_context.get()

        # Apply error handling decorator
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name or tool_name not in tools_map:
                raise ValueError(f"Tool '{tool_name}' not found")

            logger.info(f"Executing tool: {tool_name}", extra={
                "tool_id": tool_id, "tool_name": tool_name
            })

            # Send start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            # Execute tool
            import time
            start_time = time.time()
            result = await tools_map[tool_name].ainvoke(tool_args)
            execution_time = int((time.time() - start_time) * 1000)

            # Process result
            if isinstance(result, dict):
                result["tool_call_id"] = tool_id
                if "results" in result and isinstance(result["results"], list):
                    for i, search_result in enumerate(result["results"]):
                        if isinstance(search_result, dict):
                            search_result["@tool_call_id"] = tool_id
                            search_result["@order_num"] = i

            # Send result event
            if stream_callback:
                await stream_callback(create_tool_result_event(
                    tool_id, tool_name, result.get("results", []), execution_time
                ))

            # Create tool message
            tool_message = ToolMessage(
                content=json.dumps(result, ensure_ascii=False),
                tool_call_id=tool_id,
                name=tool_name
            )

            return {
                "message": tool_message,
                "results": result.get("results", []) if isinstance(result, dict) else [],
                "success": True
            }

        try:
            return await _execute()
        except Exception as e:
            # Handle any errors not caught by the decorator
            tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
            tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"

            error_message = ToolMessage(
                content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
                tool_call_id=tool_id,
                name=tool_name
            )

            return {
                "message": error_message,
                "results": [],
                "success": False
            }

    # Execute all tool calls in parallel using asyncio.gather
    if tool_calls:
        logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
        tool_execution_results = await asyncio.gather(
            *[execute_single_tool(tool_call) for tool_call in tool_calls],
            return_exceptions=True
        )

        # Process results
        for execution_result in tool_execution_results:
            if execution_result is None:
                continue
            if isinstance(execution_result, Exception):
                logger.error(f"Tool execution exception: {execution_result}")
                continue
            if not isinstance(execution_result, dict):
                logger.error(f"Unexpected execution result type: {type(execution_result)}")
                continue

            new_messages.append(execution_result["message"])
            if execution_result["success"] and execution_result["results"]:
                tool_results.extend(execution_result["results"])

        logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")

    return {
        "messages": new_messages,
        "tool_results": tool_results
    }

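# Illustrative note (added for clarity; not used at runtime): each entry in
# last_message.tool_calls is expected to be a dict in the LangChain tool-call shape, e.g.
#     {"name": "search_standards", "args": {"query": "ISO 26262 ASIL"}, "id": "call_123"}
# where the tool name, args, and id above are hypothetical values; execute_single_tool reads
# the "name", "args", and "id" keys and emits tool_start / tool_result SSE events around the call.
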
# Helper functions for citation processing
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
    """Extract citations mapping CSV from agent response HTML comment"""
    try:
        # Look for citations_map comment
        pattern = r'<!-- citations_map\s*(.*?)\s*-->'
        match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)

        if not match:
            logger.warning("No citations_map comment found in agent response")
            return {}

        csv_content = match.group(1).strip()
        citations_mapping = {}

        for line in csv_content.split('\n'):
            line = line.strip()
            if not line:
                continue

            parts = line.split(',')
            if len(parts) >= 3:
                try:
                    citation_num = int(parts[0])
                    tool_call_id = parts[1].strip()
                    order_num = int(parts[2])

                    citations_mapping[citation_num] = {
                        'tool_call_id': tool_call_id,
                        'order_num': order_num
                    }
                except (ValueError, IndexError) as e:
                    logger.warning(f"Failed to parse citation line: {line}, error: {e}")
                    continue

        return citations_mapping

    except Exception as e:
        logger.error(f"Error extracting citations mapping: {e}")
        return {}

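# Illustrative note (added for clarity; not used at runtime): the agent is expected to append
# a mapping comment of the form
#     <!-- citations_map
#     1,call_abc123,0
#     2,call_abc123,2
#     -->
# with one "citation_num,tool_call_id,order_num" row per line (the ids above are hypothetical).
# _extract_citations_mapping would parse this into
#     {1: {'tool_call_id': 'call_abc123', 'order_num': 0},
#      2: {'tool_call_id': 'call_abc123', 'order_num': 2}}
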
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
    """Build citation markdown based on the mapping and tool results, following the build_citations.py logic"""
    if not citations_mapping:
        return ""

    # Get configuration for the citation base URL
    config = get_cached_config()
    cat_base_url = config.citation.base_url

    # Collect citation lines first; only emit the header if we have at least one valid citation
    entries: List[str] = []

    for citation_num in sorted(citations_mapping.keys()):
        mapping = citations_mapping[citation_num]
        tool_call_id = mapping['tool_call_id']
        order_num = mapping['order_num']

        # Find the corresponding tool result
        result = _find_tool_result(tool_results, tool_call_id, order_num)
        if not result:
            logger.warning(f"No tool result found for citation [{citation_num}]")
            continue

        # Extract citation information following the build_citations.py logic
        full_headers = result.get('full_headers', '')
        lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
        header_display = f": {lowest_header}" if lowest_header else ""

        document_code = result.get('document_code', '')
        document_category = result.get('document_category', '')

        # Determine standard/regulation title (prefer the English title, fall back to Chinese)
        standard_regulation_title = ''
        if document_category == 'Standard':
            standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
        elif document_category == 'Regulation':
            standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')

        # Build link
        func_uuid = result.get('func_uuid', '')
        uuid = result.get('x_Standard_Regulation_Id', '')
        document_code_encoded = quote(document_code, safe='') if document_code else ''
        standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
        link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
        link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'

        # Format citation line
        title = result.get('title', '')
        entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")

    # If no valid citations were found, do not include the header
    if not entries:
        return ""

    # Build citations section with entries separated by a blank line (matching previous formatting)
    md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
    return md

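# Illustrative note (added for clarity; not used at runtime): a single rendered entry looks like
#     [1] Braking requirements: 3.2 Scope | [Road vehicles - Functional safety | ISO 26262-1](<base_url>?funcUuid=...&uuid=...&name=...)
# where the title, header, and document code are hypothetical values taken from a tool result,
# and <base_url> stands for config.citation.base_url.
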
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
    """Find tool result by tool_call_id and order_num"""
    matching_results = []

    for result in tool_results:
        if result.get('@tool_call_id') == tool_call_id:
            matching_results.append(result)

    # Sort by order and return the one at the specified position
    if matching_results and 0 <= order_num < len(matching_results):
        # If results have @order_num, use it; otherwise use the position in the list
        if '@order_num' in matching_results[0]:
            for result in matching_results:
                if result.get('@order_num') == order_num:
                    return result
        else:
            return matching_results[order_num]

    return None

def _remove_citations_comment(agent_response: str) -> str:
    """Remove citations mapping HTML comment from agent response"""
    pattern = r'<!-- citations_map\s*.*?\s*-->'
    return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()

# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Post-processing node that builds the citation list and links from the agent's citations
    mapping and the tool call results, following the logic from build_citations.py.
    """
    agent_response = ""

    try:
        logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")

        # Get stream callback from context variable
        stream_callback = stream_callback_context.get()

        # Get the last AI message (agent's response with citations mapping)
        for message in reversed(state["messages"]):
            if isinstance(message, AIMessage) and message.content:
                # Ensure content is a string
                if isinstance(message.content, str):
                    agent_response = message.content
                    break

        if not agent_response:
            logger.warning("POST_PROCESS_NODE: No agent response found")
            return {"messages": [], "final_answer": ""}

        # Extract citations mapping from agent response
        citations_mapping = _extract_citations_mapping(agent_response)
        logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")

        # Build citation markdown
        citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])

        # Combine agent response (without HTML comment) with citations
        clean_response = _remove_citations_comment(agent_response)
        final_content = clean_response + citation_markdown

        logger.info("POST_PROCESS_NODE: Built complete response with citations")

        # Send citation markdown as a single block instead of streaming it
        if stream_callback and citation_markdown:
            logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
            await stream_callback(create_token_event(citation_markdown))

        # Create AI message with complete content
        final_ai_message = AIMessage(content=final_content)

        return {
            "messages": [final_ai_message],
            "final_answer": final_content
        }

    except Exception as e:
        logger.error(f"Post-processing error: {e}")
        error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."

        # Send error message as single block
        stream_callback = stream_callback_context.get()
        if stream_callback:
            await stream_callback(create_token_event(error_message))

        error_content = agent_response + error_message if agent_response else error_message
        error_ai_message = AIMessage(content=error_content)
        return {
            "messages": [error_ai_message],
            "final_answer": error_ai_message.content
        }

# Main workflow class
class AgenticWorkflow:
    """LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""

    def __init__(self):
        # Build StateGraph with TypedDict state
        workflow = StateGraph(AgentState)

        # Add nodes following best practices
        workflow.add_node("intent_recognition", intent_recognition_node)
        workflow.add_node("agent", call_model)
        workflow.add_node("user_manual_rag", user_manual_rag_node)
        workflow.add_node("tools", run_tools_with_streaming)
        workflow.add_node("post_process", post_process_node)

        # Set entry point to intent recognition
        workflow.set_entry_point("intent_recognition")

        # Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
        workflow.add_conditional_edges(
            "intent_recognition",
            intent_router,
            {
                "Standard_Regulation_RAG": "agent",
                "User_Manual_RAG": "user_manual_rag"
            }
        )

        # Standard RAG workflow (existing pattern)
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",  # Allow agent to continue for multi-round
                "post_process": "post_process"
            }
        )

        # Tools route back to should_continue for multi-round decision
        workflow.add_conditional_edges(
            "tools",
            should_continue,
            {
                "agent": "agent",  # Continue to agent for next round
                "post_process": "post_process"  # Or finish if max rounds reached
            }
        )

        # User Manual RAG goes directly to END (single turn)
        workflow.add_edge("user_manual_rag", END)

        # Post-process is terminal
        workflow.add_edge("post_process", END)

        # Compile graph with PostgreSQL checkpointer for session memory
        try:
            checkpointer = get_checkpointer()
            self.graph = workflow.compile(checkpointer=checkpointer)
            logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
        except Exception as e:
            logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
            self.graph = workflow.compile()

    async def astream(self, state: TurnState, stream_callback: Callable | None = None):
        """Stream agent execution using LangGraph with PostgreSQL session memory"""
        try:
            # Get configuration
            config = get_cached_config()

            # Prepare initial messages for the graph
            messages = []
            for msg in state.messages:
                if msg.role == "user":
                    messages.append(HumanMessage(content=msg.content))
                elif msg.role == "assistant":
                    messages.append(AIMessage(content=msg.content))

            # Create initial agent state (without stream_callback to avoid serialization issues)
            initial_state: AgentState = {
                "messages": messages,
                "session_id": state.session_id,
                "intent": None,  # Will be determined by the intent recognition node
                "tool_results": [],
                "final_answer": "",
                "tool_rounds": 0,
                "max_tool_rounds": config.app.max_tool_rounds,  # Use configuration value
                "max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual  # Configuration value for the user manual agent
            }

            # Set stream callback in context variable (thread-safe)
            stream_callback_context.set(stream_callback)

            # Create proper RunnableConfig
            runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})

            # Stream graph execution with session memory
            async for step in self.graph.astream(initial_state, config=runnable_config):
                if "post_process" in step:
                    final_state = step["post_process"]

                    # Extract the final answer and update state
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the answer as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                elif "user_manual_rag" in step:
                    # Handle user manual RAG completion
                    final_state = step["user_manual_rag"]

                    # Extract the response from user manual RAG
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the response as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                else:
                    # Process regular steps (intent_recognition, agent, tools)
                    yield step

        except Exception as e:
            logger.error(f"AgenticWorkflow error: {e}")
            state.final_answer = "I apologize, but I encountered an error while processing your request."
            yield {"final": state}

def build_graph() -> AgenticWorkflow:
    """Build and return the autonomous agent workflow"""
    return AgenticWorkflow()

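# --- Illustrative usage sketch (added for documentation; not executed by this module) ---
# A minimal driver, assuming TurnState can be constructed with `session_id` and `messages`
# and that the stream callback simply receives the SSE event payloads; adjust to the real
# TurnState/Message APIs defined in .state.
#
# async def _demo():
#     workflow = build_graph()
#     state = TurnState(
#         session_id="demo-session",
#         messages=[Message(role="user", content="Which clauses cover brake testing?", timestamp=datetime.now())],
#     )
#
#     async def on_event(event):
#         print(event)  # token / tool_start / tool_result events created by the ..sse helpers
#
#     async for step in workflow.astream(state, stream_callback=on_event):
#         if "final" in step:
#             print(step["final"].final_answer)
#
# asyncio.run(_demo())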