465 lines
21 KiB
Python
465 lines
21 KiB
Python
|
|
"""
|
||
|
|
User Manual Agent node for the Agentic RAG system.
|
||
|
|
This module contains the autonomous user manual agent that can use tools and generate responses.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from typing import Dict, Any, List, Optional, Callable, Literal
|
||
|
|
from contextvars import ContextVar
|
||
|
|
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
|
||
|
|
from langchain_core.runnables import RunnableConfig
|
||
|
|
|
||
|
|
from .state import AgentState
|
||
|
|
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||
|
|
from .message_trimmer import create_conversation_trimmer
|
||
|
|
from ..llm_client import LLMClient
|
||
|
|
from ..config import get_config
|
||
|
|
from ..sse import (
|
||
|
|
create_tool_start_event,
|
||
|
|
create_tool_result_event,
|
||
|
|
create_tool_error_event,
|
||
|
|
create_token_event,
|
||
|
|
create_error_event
|
||
|
|
)
|
||
|
|
from ..utils.error_handler import (
|
||
|
|
StructuredLogger, ErrorCategory, ErrorCode,
|
||
|
|
handle_async_errors, get_user_message
|
||
|
|
)
|
||
|
|
|
||
|
|
logger = StructuredLogger(__name__)
|
||
|
|
|
||
|
|
# Cache configuration at module level to avoid repeated get_config() calls
|
||
|
|
_cached_config = None
|
||
|
|
|
||
|
|
def get_cached_config():
|
||
|
|
"""Get cached configuration, loading it if not already cached"""
|
||
|
|
global _cached_config
|
||
|
|
if _cached_config is None:
|
||
|
|
_cached_config = get_config()
|
||
|
|
return _cached_config
|
||
|
|
|
||
|
|
|
||
|
|
# User Manual Agent node (autonomous function calling agent)
|
||
|
|
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||
|
|
"""
|
||
|
|
User Manual Agent node that autonomously uses user manual tools and generates final answer.
|
||
|
|
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
|
||
|
|
1. Always start with non-streaming detection to check for tool needs
|
||
|
|
2. If tool_calls exist → return immediately for routing to tools
|
||
|
|
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
|
||
|
|
"""
|
||
|
|
app_config = get_cached_config()
|
||
|
|
llm_client = LLMClient()
|
||
|
|
|
||
|
|
# Get stream callback from context variable
|
||
|
|
from .graph import stream_callback_context
|
||
|
|
stream_callback = stream_callback_context.get()
|
||
|
|
|
||
|
|
# Get user manual tool schemas and bind tools for planning phase
|
||
|
|
tool_schemas = get_user_manual_tool_schemas()
|
||
|
|
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||
|
|
|
||
|
|
# Create conversation trimmer for managing context length
|
||
|
|
trimmer = create_conversation_trimmer()
|
||
|
|
|
||
|
|
# Prepare messages with user manual system prompt
|
||
|
|
messages = state["messages"].copy()
|
||
|
|
if not messages or not isinstance(messages[0], SystemMessage):
|
||
|
|
rag_prompts = app_config.get_rag_prompts()
|
||
|
|
user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
|
||
|
|
if not user_manual_prompt:
|
||
|
|
raise ValueError("user_manual_prompt is null")
|
||
|
|
|
||
|
|
# For user manual agent, we need to format the prompt with placeholders
|
||
|
|
# Extract current query and conversation history
|
||
|
|
current_query = ""
|
||
|
|
for message in reversed(messages):
|
||
|
|
if isinstance(message, HumanMessage):
|
||
|
|
current_query = message.content
|
||
|
|
break
|
||
|
|
|
||
|
|
conversation_history = ""
|
||
|
|
if len(messages) > 1:
|
||
|
|
conversation_history = render_conversation_history(messages[:-1]) # Exclude current query
|
||
|
|
|
||
|
|
# Format system prompt (initially with empty context, tools will provide it)
|
||
|
|
formatted_system_prompt = user_manual_prompt.format(
|
||
|
|
conversation_history=conversation_history,
|
||
|
|
context_content="", # Will be filled by tools
|
||
|
|
current_query=current_query
|
||
|
|
)
|
||
|
|
|
||
|
|
messages = [SystemMessage(content=formatted_system_prompt)] + messages
|
||
|
|
|
||
|
|
# Track tool rounds
|
||
|
|
current_round = state.get("tool_rounds", 0)
|
||
|
|
# Get max_tool_rounds_user_manual from state, fallback to config if not set
|
||
|
|
max_rounds = state.get("max_tool_rounds_user_manual", None)
|
||
|
|
if max_rounds is None:
|
||
|
|
max_rounds = app_config.app.max_tool_rounds_user_manual
|
||
|
|
|
||
|
|
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
|
||
|
|
# This prevents trimming current turn's tool results during multi-round tool calling
|
||
|
|
if current_round == 0:
|
||
|
|
# Trim conversation history to manage context length (only for previous conversation turns)
|
||
|
|
if trimmer.should_trim(messages):
|
||
|
|
messages = trimmer.trim_conversation_history(messages)
|
||
|
|
logger.info("Applied conversation history trimming for context management (new conversation turn)")
|
||
|
|
else:
|
||
|
|
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
|
||
|
|
|
||
|
|
logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
|
||
|
|
|
||
|
|
# Check if this should be final synthesis (max rounds reached)
|
||
|
|
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
|
||
|
|
is_final_synthesis = has_tool_messages and current_round >= max_rounds
|
||
|
|
|
||
|
|
if is_final_synthesis:
|
||
|
|
logger.info("Starting final synthesis phase - no more tool calls allowed")
|
||
|
|
# ✅ STEP 1: Final synthesis with tools disabled from the start
|
||
|
|
# Disable tools to prevent any tool calling during synthesis
|
||
|
|
try:
|
||
|
|
original_tools = llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||
|
|
|
||
|
|
if not stream_callback:
|
||
|
|
# No streaming callback, generate final response without tools
|
||
|
|
draft = await llm_client.ainvoke(list(messages))
|
||
|
|
return {"messages": [draft]}
|
||
|
|
|
||
|
|
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
|
||
|
|
response_content = ""
|
||
|
|
accumulated_content = ""
|
||
|
|
|
||
|
|
async for token in llm_client.astream(list(messages)):
|
||
|
|
accumulated_content += token
|
||
|
|
response_content += token
|
||
|
|
|
||
|
|
# Check for complete HTML comments in accumulated content
|
||
|
|
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||
|
|
comment_start = accumulated_content.find("<!--")
|
||
|
|
comment_end = accumulated_content.find("-->", comment_start)
|
||
|
|
|
||
|
|
if comment_start >= 0 and comment_end >= 0:
|
||
|
|
# Send content before comment
|
||
|
|
before_comment = accumulated_content[:comment_start]
|
||
|
|
if stream_callback and before_comment:
|
||
|
|
await stream_callback(create_token_event(before_comment))
|
||
|
|
|
||
|
|
# Skip the comment and continue with content after
|
||
|
|
accumulated_content = accumulated_content[comment_end + 3:]
|
||
|
|
else:
|
||
|
|
break
|
||
|
|
|
||
|
|
# Send accumulated content if no pending comment
|
||
|
|
if "<!--" not in accumulated_content:
|
||
|
|
if stream_callback and accumulated_content:
|
||
|
|
await stream_callback(create_token_event(accumulated_content))
|
||
|
|
accumulated_content = ""
|
||
|
|
|
||
|
|
# Send any remaining content (if not in middle of comment)
|
||
|
|
if accumulated_content and "<!--" not in accumulated_content:
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_token_event(accumulated_content))
|
||
|
|
|
||
|
|
return {"messages": [AIMessage(content=response_content)]}
|
||
|
|
|
||
|
|
finally:
|
||
|
|
# ✅ STEP 3: Restore tool binding for next interaction
|
||
|
|
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||
|
|
|
||
|
|
else:
|
||
|
|
logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")
|
||
|
|
|
||
|
|
# ✅ STEP 1: Non-streaming detection to check for tool needs
|
||
|
|
draft = await llm_client.ainvoke_with_tools(list(messages))
|
||
|
|
|
||
|
|
# ✅ STEP 2: If draft has tool_calls, execute them within this node
|
||
|
|
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
|
||
|
|
logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")
|
||
|
|
|
||
|
|
# Create a new state with the tool call message added
|
||
|
|
tool_call_state = state.copy()
|
||
|
|
updated_messages = state["messages"].copy()
|
||
|
|
updated_messages.append(draft)
|
||
|
|
tool_call_state["messages"] = updated_messages
|
||
|
|
|
||
|
|
# Execute the tools using the existing streaming tool execution function
|
||
|
|
tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
|
||
|
|
tool_messages = tool_results.get("messages", [])
|
||
|
|
|
||
|
|
# Increment tool round counter for next iteration
|
||
|
|
new_tool_rounds = current_round + 1
|
||
|
|
logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")
|
||
|
|
|
||
|
|
# Continue with another round if under max rounds
|
||
|
|
if new_tool_rounds < max_rounds:
|
||
|
|
# Recursive call for next round with all messages
|
||
|
|
final_messages = updated_messages + tool_messages
|
||
|
|
recursive_state = state.copy()
|
||
|
|
recursive_state["messages"] = final_messages
|
||
|
|
recursive_state["tool_rounds"] = new_tool_rounds
|
||
|
|
return await user_manual_agent_node(recursive_state)
|
||
|
|
else:
|
||
|
|
# Max rounds reached, force final synthesis
|
||
|
|
logger.info("Max tool rounds reached, forcing final synthesis")
|
||
|
|
# Update messages for final synthesis
|
||
|
|
messages = updated_messages + tool_messages
|
||
|
|
# Continue to final synthesis below
|
||
|
|
|
||
|
|
# ✅ STEP 3: No tool_calls needed or max rounds reached → Enter final synthesis with streaming
|
||
|
|
# Temporarily disable tools to prevent accidental tool calling during synthesis
|
||
|
|
try:
|
||
|
|
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||
|
|
|
||
|
|
if not stream_callback:
|
||
|
|
# No streaming callback, use the draft we already have
|
||
|
|
return {"messages": [draft]}
|
||
|
|
|
||
|
|
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
|
||
|
|
response_content = ""
|
||
|
|
accumulated_content = ""
|
||
|
|
|
||
|
|
async for token in llm_client.astream(list(messages)):
|
||
|
|
accumulated_content += token
|
||
|
|
response_content += token
|
||
|
|
|
||
|
|
# Check for complete HTML comments in accumulated content
|
||
|
|
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||
|
|
comment_start = accumulated_content.find("<!--")
|
||
|
|
comment_end = accumulated_content.find("-->", comment_start)
|
||
|
|
|
||
|
|
if comment_start >= 0 and comment_end >= 0:
|
||
|
|
# Send content before comment
|
||
|
|
before_comment = accumulated_content[:comment_start]
|
||
|
|
if stream_callback and before_comment:
|
||
|
|
await stream_callback(create_token_event(before_comment))
|
||
|
|
|
||
|
|
# Skip the comment and continue with content after
|
||
|
|
accumulated_content = accumulated_content[comment_end + 3:]
|
||
|
|
else:
|
||
|
|
break
|
||
|
|
|
||
|
|
# Send accumulated content if no pending comment
|
||
|
|
if "<!--" not in accumulated_content:
|
||
|
|
if stream_callback and accumulated_content:
|
||
|
|
await stream_callback(create_token_event(accumulated_content))
|
||
|
|
accumulated_content = ""
|
||
|
|
|
||
|
|
# Send any remaining content (if not in middle of comment)
|
||
|
|
if accumulated_content and "<!--" not in accumulated_content:
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_token_event(accumulated_content))
|
||
|
|
|
||
|
|
return {"messages": [AIMessage(content=response_content)]}
|
||
|
|
|
||
|
|
finally:
|
||
|
|
# ✅ STEP 5: Restore tool binding for next interaction
|
||
|
|
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||
|
|
|
||
|
|
|
||
|
|
def render_conversation_history(messages, max_messages: int = 10) -> str:
|
||
|
|
"""Render conversation history for context"""
|
||
|
|
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
||
|
|
lines = []
|
||
|
|
for msg in recent_messages:
|
||
|
|
if hasattr(msg, 'content'):
|
||
|
|
content = msg.content
|
||
|
|
if isinstance(content, str):
|
||
|
|
# Determine message type by class name or other attributes
|
||
|
|
if 'Human' in str(type(msg)):
|
||
|
|
lines.append(f"<user>{content}</user>")
|
||
|
|
elif 'AI' in str(type(msg)):
|
||
|
|
lines.append(f"<ai>{content}</ai>")
|
||
|
|
elif isinstance(content, list):
|
||
|
|
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
|
||
|
|
if 'Human' in str(type(msg)):
|
||
|
|
lines.append(f"<user>{content_str}</user>")
|
||
|
|
elif 'AI' in str(type(msg)):
|
||
|
|
lines.append(f"<ai>{content_str}</ai>")
|
||
|
|
return "\n".join(lines)
|
||
|
|
|
||
|
|
|
||
|
|
# User Manual Tools routing condition
|
||
|
|
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
|
||
|
|
"""
|
||
|
|
Routing logic for user manual agent:
|
||
|
|
- has tool_calls → route to user_manual_tools
|
||
|
|
- no tool_calls → route to post_process (final synthesis already completed)
|
||
|
|
"""
|
||
|
|
messages = state["messages"]
|
||
|
|
if not messages:
|
||
|
|
logger.info("user_manual_should_continue: No messages, routing to post_process")
|
||
|
|
return "post_process"
|
||
|
|
|
||
|
|
last_message = messages[-1]
|
||
|
|
current_round = state.get("tool_rounds", 0)
|
||
|
|
# Get max_tool_rounds_user_manual from state, fallback to config if not set
|
||
|
|
max_rounds = state.get("max_tool_rounds_user_manual", None)
|
||
|
|
if max_rounds is None:
|
||
|
|
app_config = get_cached_config()
|
||
|
|
max_rounds = app_config.app.max_tool_rounds_user_manual
|
||
|
|
|
||
|
|
logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
|
||
|
|
|
||
|
|
# If last message is AI message with tool calls, route to tools
|
||
|
|
if isinstance(last_message, AIMessage):
|
||
|
|
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
|
||
|
|
logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")
|
||
|
|
|
||
|
|
if has_tool_calls:
|
||
|
|
logger.info("user_manual_should_continue: Routing to user_manual_tools")
|
||
|
|
return "user_manual_tools"
|
||
|
|
else:
|
||
|
|
# No tool calls = final synthesis already completed in user_manual_agent_node
|
||
|
|
logger.info("user_manual_should_continue: No tool calls, routing to post_process")
|
||
|
|
return "post_process"
|
||
|
|
|
||
|
|
# If last message is tool message(s), continue with agent for next round or final synthesis
|
||
|
|
if isinstance(last_message, ToolMessage):
|
||
|
|
logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
|
||
|
|
return "user_manual_agent"
|
||
|
|
|
||
|
|
logger.info("user_manual_should_continue: Routing to post_process")
|
||
|
|
return "post_process"
|
||
|
|
|
||
|
|
|
||
|
|
# User Manual Tools node with streaming support
|
||
|
|
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||
|
|
"""Execute user manual tools with streaming events - supports parallel execution"""
|
||
|
|
messages = state["messages"]
|
||
|
|
last_message = messages[-1]
|
||
|
|
|
||
|
|
# Get stream callback from context variable
|
||
|
|
from .graph import stream_callback_context
|
||
|
|
stream_callback = stream_callback_context.get()
|
||
|
|
|
||
|
|
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
|
||
|
|
return {"messages": []}
|
||
|
|
|
||
|
|
tool_calls = last_message.tool_calls or []
|
||
|
|
tool_results = []
|
||
|
|
new_messages = []
|
||
|
|
|
||
|
|
# User manual tools mapping
|
||
|
|
tools_map = get_user_manual_tools_by_name()
|
||
|
|
|
||
|
|
async def execute_single_tool(tool_call):
|
||
|
|
"""Execute a single user manual tool call with enhanced error handling"""
|
||
|
|
# Get stream callback from context
|
||
|
|
from .graph import stream_callback_context
|
||
|
|
stream_callback = stream_callback_context.get()
|
||
|
|
|
||
|
|
# Apply error handling decorator
|
||
|
|
@handle_async_errors(
|
||
|
|
ErrorCategory.TOOL,
|
||
|
|
ErrorCode.TOOL_ERROR,
|
||
|
|
stream_callback,
|
||
|
|
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||
|
|
)
|
||
|
|
async def _execute():
|
||
|
|
# Validate tool_call format
|
||
|
|
if not isinstance(tool_call, dict):
|
||
|
|
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
|
||
|
|
|
||
|
|
tool_name = tool_call.get("name")
|
||
|
|
tool_args = tool_call.get("args", {})
|
||
|
|
tool_id = tool_call.get("id", "unknown")
|
||
|
|
|
||
|
|
if not tool_name:
|
||
|
|
raise ValueError("Tool call missing 'name' field")
|
||
|
|
|
||
|
|
if tool_name not in tools_map:
|
||
|
|
available_tools = list(tools_map.keys())
|
||
|
|
raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")
|
||
|
|
|
||
|
|
tool_func = tools_map[tool_name]
|
||
|
|
|
||
|
|
# Stream tool start event
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
|
||
|
|
|
||
|
|
import time
|
||
|
|
start_time = time.time()
|
||
|
|
|
||
|
|
try:
|
||
|
|
# Execute the user manual tool
|
||
|
|
result = await tool_func.ainvoke(tool_args)
|
||
|
|
|
||
|
|
# Calculate execution time
|
||
|
|
took_ms = int((time.time() - start_time) * 1000)
|
||
|
|
|
||
|
|
# Stream tool result event
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))
|
||
|
|
|
||
|
|
# Create tool message
|
||
|
|
tool_message = ToolMessage(
|
||
|
|
content=str(result),
|
||
|
|
tool_call_id=tool_id,
|
||
|
|
name=tool_name
|
||
|
|
)
|
||
|
|
|
||
|
|
return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
took_ms = int((time.time() - start_time) * 1000)
|
||
|
|
error_msg = get_user_message(ErrorCategory.TOOL)
|
||
|
|
|
||
|
|
# Stream tool error event
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
|
||
|
|
|
||
|
|
# Create error tool message
|
||
|
|
tool_message = ToolMessage(
|
||
|
|
content=f"Error executing {tool_name}: {error_msg}",
|
||
|
|
tool_call_id=tool_id,
|
||
|
|
name=tool_name
|
||
|
|
)
|
||
|
|
|
||
|
|
return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}
|
||
|
|
|
||
|
|
return await _execute()
|
||
|
|
|
||
|
|
# Execute user manual tools (typically just one for user manual retrieval)
|
||
|
|
import asyncio
|
||
|
|
tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
|
||
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
|
|
||
|
|
for i, result in enumerate(results):
|
||
|
|
if isinstance(result, Exception):
|
||
|
|
# Handle execution exception
|
||
|
|
tool_call = tool_calls[i]
|
||
|
|
tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
|
||
|
|
tool_name = tool_call.get("name", "unknown")
|
||
|
|
error_msg = get_user_message(ErrorCategory.TOOL)
|
||
|
|
|
||
|
|
if stream_callback:
|
||
|
|
await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))
|
||
|
|
|
||
|
|
error_message = ToolMessage(
|
||
|
|
content=f"Error executing {tool_name}: {error_msg}",
|
||
|
|
tool_call_id=tool_id,
|
||
|
|
name=tool_name
|
||
|
|
)
|
||
|
|
new_messages.append(error_message)
|
||
|
|
elif isinstance(result, tuple) and len(result) == 2:
|
||
|
|
# result is a tuple: (tool_message, tool_result)
|
||
|
|
tool_message, tool_result = result
|
||
|
|
new_messages.append(tool_message)
|
||
|
|
tool_results.append(tool_result)
|
||
|
|
else:
|
||
|
|
# Unexpected result format
|
||
|
|
logger.error(f"Unexpected tool execution result format: {type(result)}")
|
||
|
|
continue
|
||
|
|
|
||
|
|
return {"messages": new_messages, "tool_results": tool_results}
|
||
|
|
|
||
|
|
|
||
|
|
# Legacy function for backward compatibility
|
||
|
|
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||
|
|
"""
|
||
|
|
Legacy user manual RAG node - redirects to new agent-based implementation
|
||
|
|
"""
|
||
|
|
logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
|
||
|
|
return await user_manual_agent_node(state, config)
|