"""
User Manual Agent node for the Agentic RAG system.
This module contains the autonomous user manual agent that can use tools and generate responses.
"""

import asyncio
import time
from typing import Dict, Any, Optional, Literal

from langchain_core.messages import AIMessage, SystemMessage, ToolMessage, HumanMessage
from langchain_core.runnables import RunnableConfig

from .state import AgentState
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
from .message_trimmer import create_conversation_trimmer
from ..llm_client import LLMClient
from ..config import get_config
from ..sse import (
    create_tool_start_event,
    create_tool_result_event,
    create_tool_error_event,
    create_token_event,
)
from ..utils.error_handler import (
    StructuredLogger, ErrorCategory, ErrorCode,
    handle_async_errors, get_user_message,
)

logger = StructuredLogger(__name__)

# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None


def get_cached_config():
    """Get the cached configuration, loading it if not already cached."""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config
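
# Note: this module-level cache assumes the configuration is immutable for the
# lifetime of the process; a restart is required to pick up config changes.
# (This is an inference from the caching pattern, not documented behavior.)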


# User Manual Agent node (autonomous function calling agent)
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    User Manual Agent node that autonomously uses user manual tools and generates the final answer.

    Implements a "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with a non-streaming detection call to check for tool needs.
    2. If tool_calls exist, return immediately for routing to tools.
    3. If there are no tool_calls, temporarily disable tools and perform a streaming final synthesis.
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get the stream callback from the context variable (imported here to avoid a circular import)
    from .graph import stream_callback_context
    stream_callback = stream_callback_context.get()

    # Get user manual tool schemas and bind tools for the planning phase
    tool_schemas = get_user_manual_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create a conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()
    # Prepare messages with the user manual system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
        if not user_manual_prompt:
            raise ValueError("user_manual_prompt is missing from the RAG prompt configuration")

        # The user manual prompt is a template: fill its placeholders from the
        # current query and the conversation history
        current_query = ""
        for message in reversed(messages):
            if isinstance(message, HumanMessage):
                current_query = message.content
                break

        conversation_history = ""
        if len(messages) > 1:
            conversation_history = render_conversation_history(messages[:-1])  # Exclude the current query

        # Format the system prompt (initially with empty context; tools will provide it)
        formatted_system_prompt = user_manual_prompt.format(
            conversation_history=conversation_history,
            context_content="",  # Will be filled by tool results
            current_query=current_query
        )
        messages = [SystemMessage(content=formatted_system_prompt)] + messages
    # Track tool rounds
    current_round = state.get("tool_rounds", 0)

    # Get max_tool_rounds_user_manual from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds_user_manual", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds_user_manual

    # Only apply trimming at the start of a new conversation turn (tool_rounds == 0).
    # This prevents trimming the current turn's tool results during multi-round tool calling.
    if current_round == 0:
        # Trim conversation history to manage context length (previous turns only)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Final synthesis applies once tool results exist and the round budget is spent
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
        # ✅ STEP 1: Final synthesis with tools disabled from the start,
        # so no tool calls can be emitted during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback: generate the final response without tools
                draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 2: Streaming final synthesis with HTML comment filtering.
            # Tokens are buffered in accumulated_content because a comment like
            # "<!-- ... -->" may be split across token boundaries; only content
            # that is provably outside a comment is forwarded to the client.
            response_content = ""
            accumulated_content = ""
            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Strip every complete HTML comment from the buffered content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)
                    if comment_start >= 0 and comment_end >= 0:
                        # Send the content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))
                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Flush the buffer if no comment opener is pending
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless it is an unterminated comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}
        finally:
            # ✅ STEP 3: Restore the tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)
    else:
        logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection call to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If the draft has tool_calls, execute them within this node
        if isinstance(draft, AIMessage) and getattr(draft, "tool_calls", None):
            logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")

            # Create a new state with the tool call message appended
            tool_call_state = state.copy()
            updated_messages = state["messages"].copy()
            updated_messages.append(draft)
            tool_call_state["messages"] = updated_messages

            # Execute the tools using the streaming tool execution function
            tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
            tool_messages = tool_results.get("messages", [])

            # Increment the tool round counter for the next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")

            # Continue with another round if still under the round budget
            if new_tool_rounds < max_rounds:
                # Recursive call for the next round with all messages so far
                final_messages = updated_messages + tool_messages
                recursive_state = state.copy()
                recursive_state["messages"] = final_messages
                recursive_state["tool_rounds"] = new_tool_rounds
                return await user_manual_agent_node(recursive_state, config)
            else:
                # Max rounds reached, force final synthesis
                logger.info("Max tool rounds reached, forcing final synthesis")
                # Update the message list used by the final synthesis below
                messages = updated_messages + tool_messages
        # ✅ STEP 3: No tool_calls needed or max rounds reached → streaming final synthesis.
        # Temporarily disable tools to prevent accidental tool calling during synthesis.
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback: reuse the draft unless it still carries
                # tool calls (the max-rounds case above), in which case generate
                # a fresh response with tools disabled
                if getattr(draft, "tool_calls", None):
                    draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with the same HTML comment
            # filtering as the final synthesis branch above
            response_content = ""
            accumulated_content = ""
            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Strip every complete HTML comment from the buffered content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)
                    if comment_start >= 0 and comment_end >= 0:
                        # Send the content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))
                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Flush the buffer if no comment opener is pending
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless it is an unterminated comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}
        finally:
            # ✅ STEP 5: Restore the tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)
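
# Illustration of the HTML comment filtering above on a hypothetical token
# stream (values are made up for the example):
#   tokens arrive as:  "See the<!" , "-- source: manual §3 --" , "> pairing steps"
#   events emitted:    "See the" ... " pairing steps"
# A pending "<!--" opener is held in the buffer until its "-->" arrives, so
# internal annotations never reach the client, even when split across tokens.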


def render_conversation_history(messages, max_messages: int = 10) -> str:
    """Render recent conversation history as tagged lines for prompt context."""
    recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
    lines = []
    for msg in recent_messages:
        if not hasattr(msg, 'content'):
            continue
        content = msg.content
        if isinstance(content, list):
            # Flatten multi-part content, keeping only the string parts
            content = " ".join(item for item in content if isinstance(item, str))
        if not isinstance(content, str):
            continue
        if isinstance(msg, HumanMessage):
            lines.append(f"<user>{content}</user>")
        elif isinstance(msg, AIMessage):
            lines.append(f"<ai>{content}</ai>")
    return "\n".join(lines)


# User Manual Tools routing condition
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
    """
    Routing logic for the user manual agent:
    - last AI message has tool_calls → route to user_manual_tools
    - last AI message has no tool_calls → route to post_process (final synthesis already completed)
    - last message is a tool result → route back to user_manual_agent
    """
    messages = state["messages"]
    if not messages:
        logger.info("user_manual_should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)

    # Get max_tool_rounds_user_manual from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds_user_manual", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds_user_manual

    logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If the last message is an AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = bool(getattr(last_message, "tool_calls", None))
        logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")
        if has_tool_calls:
            logger.info("user_manual_should_continue: Routing to user_manual_tools")
            return "user_manual_tools"
        else:
            # No tool calls = final synthesis already completed in user_manual_agent_node
            logger.info("user_manual_should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If the last message is a tool message, continue with the agent for the
    # next round or the final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
        return "user_manual_agent"

    logger.info("user_manual_should_continue: Routing to post_process")
    return "post_process"


# User Manual Tools node with streaming support
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute user manual tools with streaming events; supports parallel execution."""
    messages = state["messages"]
    last_message = messages[-1]

    # Get the stream callback from the context variable (imported here to avoid a circular import)
    from .graph import stream_callback_context
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # User manual tools mapping
    tools_map = get_user_manual_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single user manual tool call with enhanced error handling"""
        # Re-read the stream callback from context inside the task
        from .graph import stream_callback_context
        stream_callback = stream_callback_context.get()

        # The decorator surfaces failures as structured error events tied to
        # this tool call's id
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate the tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be a dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name:
                raise ValueError("Tool call missing 'name' field")
            if tool_name not in tools_map:
                available_tools = list(tools_map.keys())
                raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")

            tool_func = tools_map[tool_name]

            # Stream the tool start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            start_time = time.time()
            try:
                # Execute the user manual tool
                result = await tool_func.ainvoke(tool_args)
                took_ms = int((time.time() - start_time) * 1000)

                # Stream the tool result event
                if stream_callback:
                    await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))

                # Create the tool message
                tool_message = ToolMessage(
                    content=str(result),
                    tool_call_id=tool_id,
                    name=tool_name
                )
                return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}
            except Exception:
                took_ms = int((time.time() - start_time) * 1000)
                error_msg = get_user_message(ErrorCategory.TOOL)

                # Stream the tool error event
                if stream_callback:
                    await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))

                # Create an error tool message
                tool_message = ToolMessage(
                    content=f"Error executing {tool_name}: {error_msg}",
                    tool_call_id=tool_id,
                    name=tool_name
                )
                return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}

        return await _execute()
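
    # Each task resolves to a (ToolMessage, result-dict) tuple. With
    # return_exceptions=True below, an error that escapes the per-tool handler
    # arrives as an Exception object in the results list rather than
    # cancelling the sibling tool calls.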
    # Execute user manual tools in parallel (typically just one retrieval call)
    tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # Handle an execution exception
            tool_call = tool_calls[i]
            tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
            tool_name = tool_call.get("name", "unknown")
            error_msg = get_user_message(ErrorCategory.TOOL)
            if stream_callback:
                await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
            error_message = ToolMessage(
                content=f"Error executing {tool_name}: {error_msg}",
                tool_call_id=tool_id,
                name=tool_name
            )
            new_messages.append(error_message)
        elif isinstance(result, tuple) and len(result) == 2:
            # result is a (tool_message, tool_result) tuple
            tool_message, tool_result = result
            new_messages.append(tool_message)
            tool_results.append(tool_result)
        else:
            # Unexpected result format
            logger.error(f"Unexpected tool execution result format: {type(result)}")
            continue

    return {"messages": new_messages, "tool_results": tool_results}


# Legacy function for backward compatibility
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Legacy user manual RAG node; redirects to the new agent-based implementation.
    """
    logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
    return await user_manual_agent_node(state, config)