init

vw-agentic-rag/service/graph/user_manual_rag.py · 464 lines · new file

@@ -0,0 +1,464 @@
"""
|
||||
User Manual Agent node for the Agentic RAG system.
|
||||
This module contains the autonomous user manual agent that can use tools and generate responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Callable, Literal
|
||||
from contextvars import ContextVar
|
||||
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from .state import AgentState
|
||||
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||||
from .message_trimmer import create_conversation_trimmer
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..sse import (
|
||||
create_tool_start_event,
|
||||
create_tool_result_event,
|
||||
create_tool_error_event,
|
||||
create_token_event,
|
||||
create_error_event
|
||||
)
|
||||
from ..utils.error_handler import (
|
||||
StructuredLogger, ErrorCategory, ErrorCode,
|
||||
handle_async_errors, get_user_message
|
||||
)
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
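
# Keys this module reads from / writes to AgentState (as used below):
#   "messages"                      conversation message list
#   "tool_rounds"                   tool-calling rounds completed this turn
#   "max_tool_rounds_user_manual"   optional per-request override of the config cap
#   "tool_results"                  structured results emitted by the tools node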


# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None


def get_cached_config():
    """Get cached configuration, loading it if not already cached"""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config
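

# A minimal alternative sketch: the same one-time load could be expressed with
# functools.lru_cache. _get_cached_config_lru is illustrative only and is not
# used by the rest of this module.
from functools import lru_cache


@lru_cache(maxsize=1)
def _get_cached_config_lru():
    """Hypothetical lru_cache-based equivalent of get_cached_config()."""
    return get_config()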


# User Manual Agent node (autonomous function-calling agent)
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    User Manual Agent node that autonomously uses user manual tools and generates the final answer.

    Implements a "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with non-streaming detection to check for tool needs
    2. If tool_calls exist → execute them within this node and start another round
    3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get the stream callback from the context variable
    # (imported locally to avoid a circular import)
    from .graph import stream_callback_context
    stream_callback = stream_callback_context.get()

    # Get user manual tool schemas and bind tools for the planning phase
    tool_schemas = get_user_manual_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create a conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()

    # Prepare messages with the user manual system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
        if not user_manual_prompt:
            raise ValueError("user_manual_prompt is missing from the RAG prompt configuration")

        # The user manual prompt is a template with placeholders.
        # Extract the current query and the conversation history to fill it.
        current_query = ""
        for message in reversed(messages):
            if isinstance(message, HumanMessage):
                current_query = message.content
                break

        conversation_history = ""
        if len(messages) > 1:
            conversation_history = render_conversation_history(messages[:-1])  # Exclude the current query

        # Format the system prompt (initially with empty context; tools will provide it)
        formatted_system_prompt = user_manual_prompt.format(
            conversation_history=conversation_history,
            context_content="",  # Will be filled by tools
            current_query=current_query,
        )

        messages = [SystemMessage(content=formatted_system_prompt)] + messages
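
    # Illustrative only: the concrete user_manual_prompt lives in the RAG
    # prompt configuration; the .format() call above assumes a template of
    # roughly this shape (only the three placeholder names are grounded here):
    #
    #   Conversation so far:
    #   {conversation_history}
    #
    #   Manual excerpts retrieved so far:
    #   {context_content}
    #
    #   Current question: {current_query}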

    # Track tool rounds
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds_user_manual from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds_user_manual", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds_user_manual

    # Only apply trimming at the start of a new conversation turn (tool_rounds == 0).
    # This prevents trimming the current turn's tool results during multi-round tool calling.
    if current_round == 0:
        # Trim conversation history to manage context length (previous turns only)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Check whether this should be the final synthesis (max rounds reached)
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
        # ✅ STEP 1: Final synthesis with tools disabled from the start
        # Disable tools to prevent any tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)

            if not stream_callback:
                # No streaming callback: generate the final response without tools
                draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 2: Streaming final synthesis with HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Strip complete HTML comments from the accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send the content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send the accumulated content if no comment is pending
                if "<!--" not in accumulated_content:
                    if accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless we ended mid-comment)
            if accumulated_content and "<!--" not in accumulated_content:
                await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 3: Restore the tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    else:
        logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If the draft has tool_calls, execute them within this node
        if isinstance(draft, AIMessage) and getattr(draft, "tool_calls", None):
            logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")

            # Create a new state with the tool-call message appended
            tool_call_state = state.copy()
            updated_messages = state["messages"].copy()
            updated_messages.append(draft)
            tool_call_state["messages"] = updated_messages

            # Execute the tools using the streaming tool-execution function
            tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
            tool_messages = tool_results.get("messages", [])

            # Increment the tool-round counter for the next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")

            # Continue with another round if under max rounds
            if new_tool_rounds < max_rounds:
                # Recursive call for the next round with all messages
                final_messages = updated_messages + tool_messages
                recursive_state = state.copy()
                recursive_state["messages"] = final_messages
                recursive_state["tool_rounds"] = new_tool_rounds
                return await user_manual_agent_node(recursive_state)
            else:
                # Max rounds reached, force final synthesis
                logger.info("Max tool rounds reached, forcing final synthesis")
                # Rebuild the message list for final synthesis. Use `messages`
                # (not `updated_messages`) so the system prompt prepended above
                # is preserved.
                messages = messages + [draft] + tool_messages
                # Fall through to the final synthesis below

        # ✅ STEP 3: No tool_calls needed or max rounds reached → final synthesis with streaming
        # Temporarily disable tools to prevent accidental tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)

            if not stream_callback:
                # No streaming callback. If the draft still carries tool calls
                # (the max-rounds fall-through above), regenerate a final answer
                # without tools; otherwise reuse the draft we already have.
                if getattr(draft, "tool_calls", None):
                    draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Strip complete HTML comments from the accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send the content before the comment
                        before_comment = accumulated_content[:comment_start]
                        if before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with the content after it
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send the accumulated content if no comment is pending
                if "<!--" not in accumulated_content:
                    if accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (unless we ended mid-comment)
            if accumulated_content and "<!--" not in accumulated_content:
                await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 5: Restore the tool binding for the next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)
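

# The HTML-comment filtering loop above appears verbatim in both synthesis
# branches. A sketch of how it could be factored into a single reusable
# helper (illustrative only; the node above does not call it):
async def _stream_with_comment_filtering(token_iter, stream_callback):
    """Stream tokens to the callback while stripping complete <!-- ... --> comments."""
    response_content = ""
    accumulated = ""
    async for token in token_iter:
        response_content += token
        accumulated += token
        # Remove every complete comment currently buffered
        while "<!--" in accumulated and "-->" in accumulated:
            start = accumulated.find("<!--")
            end = accumulated.find("-->", start)
            if end < 0:
                break
            if stream_callback and accumulated[:start]:
                await stream_callback(create_token_event(accumulated[:start]))
            accumulated = accumulated[end + 3:]
        # Flush the buffer once no partial comment remains
        if "<!--" not in accumulated:
            if stream_callback and accumulated:
                await stream_callback(create_token_event(accumulated))
            accumulated = ""
    # Flush any tail content unless we ended mid-comment
    if accumulated and "<!--" not in accumulated and stream_callback:
        await stream_callback(create_token_event(accumulated))
    return response_content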


def render_conversation_history(messages, max_messages: int = 10) -> str:
    """Render recent conversation history as tagged lines for prompt context."""
    recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
    lines = []
    for msg in recent_messages:
        content = getattr(msg, "content", None)
        if isinstance(content, list):
            # Flatten multi-part content into a single string
            content = " ".join(item for item in content if isinstance(item, str))
        if not isinstance(content, str):
            continue
        # Use the message classes imported above instead of fragile string checks
        if isinstance(msg, HumanMessage):
            lines.append(f"<user>{content}</user>")
        elif isinstance(msg, AIMessage):
            lines.append(f"<ai>{content}</ai>")
    return "\n".join(lines)
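
# Example (illustrative): for
#   [HumanMessage(content="How do I pair my phone?"),
#    AIMessage(content="Open Settings > Bluetooth ...")]
# render_conversation_history returns:
#   <user>How do I pair my phone?</user>
#   <ai>Open Settings > Bluetooth ...</ai>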


# User Manual Tools routing condition
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
    """
    Routing logic for the user manual agent:
    - has tool_calls → route to user_manual_tools
    - no tool_calls → route to post_process (final synthesis already completed)
    """
    messages = state["messages"]
    if not messages:
        logger.info("user_manual_should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds_user_manual from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds_user_manual", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds_user_manual

    logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If the last message is an AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = bool(getattr(last_message, "tool_calls", None))
        logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")

        if has_tool_calls:
            logger.info("user_manual_should_continue: Routing to user_manual_tools")
            return "user_manual_tools"
        else:
            # No tool calls: final synthesis already completed in user_manual_agent_node
            logger.info("user_manual_should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If the last message is a tool message, continue with the agent for the
    # next round or the final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
        return "user_manual_agent"

    logger.info("user_manual_should_continue: Routing to post_process")
    return "post_process"
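
# Sketch of how this condition is typically wired into a LangGraph StateGraph
# (node names are assumed from the Literal return values; the actual wiring
# presumably lives in graph.py):
#
#   graph.add_conditional_edges(
#       "user_manual_agent",
#       user_manual_should_continue,
#       {
#           "user_manual_tools": "user_manual_tools",
#           "user_manual_agent": "user_manual_agent",
#           "post_process": "post_process",
#       },
#   )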


# User Manual Tools node with streaming support
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute user manual tools with streaming events - supports parallel execution"""
    messages = state["messages"]
    last_message = messages[-1]

    # Get the stream callback from the context variable
    # (imported locally to avoid a circular import)
    from .graph import stream_callback_context
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, "tool_calls"):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # User manual tools mapping
    tools_map = get_user_manual_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single user manual tool call with enhanced error handling"""
        # Apply the error-handling decorator; stream_callback is captured
        # from the enclosing scope
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate the tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be a dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name:
                raise ValueError("Tool call missing 'name' field")

            if tool_name not in tools_map:
                available_tools = list(tools_map.keys())
                raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")

            tool_func = tools_map[tool_name]

            # Stream the tool-start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            start_time = time.time()

            try:
                # Execute the user manual tool
                result = await tool_func.ainvoke(tool_args)

                # Calculate the execution time
                took_ms = int((time.time() - start_time) * 1000)

                # Stream the tool-result event
                if stream_callback:
                    await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))

                # Create the tool message
                tool_message = ToolMessage(
                    content=str(result),
                    tool_call_id=tool_id,
                    name=tool_name
                )

                return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}

            except Exception:
                took_ms = int((time.time() - start_time) * 1000)
                error_msg = get_user_message(ErrorCategory.TOOL)

                # Stream the tool-error event
                if stream_callback:
                    await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))

                # Create an error tool message
                tool_message = ToolMessage(
                    content=f"Error executing {tool_name}: {error_msg}",
                    tool_call_id=tool_id,
                    name=tool_name
                )

                return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}

        return await _execute()

    # Execute the user manual tools in parallel (typically just one retrieval call)
    tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # Handle an execution exception
            tool_call = tool_calls[i]
            tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
            tool_name = tool_call.get("name", "unknown")
            error_msg = get_user_message(ErrorCategory.TOOL)

            if stream_callback:
                await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))

            error_message = ToolMessage(
                content=f"Error executing {tool_name}: {error_msg}",
                tool_call_id=tool_id,
                name=tool_name
            )
            new_messages.append(error_message)
        elif isinstance(result, tuple) and len(result) == 2:
            # result is a tuple: (tool_message, tool_result)
            tool_message, tool_result = result
            new_messages.append(tool_message)
            tool_results.append(tool_result)
        else:
            # Unexpected result format
            logger.error(f"Unexpected tool execution result format: {type(result)}")
            continue

    return {"messages": new_messages, "tool_results": tool_results}
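
# For reference, each entry in AIMessage.tool_calls is expected to follow the
# LangChain ToolCall dict shape consumed above; the tool name and args here
# are hypothetical examples:
#   {"name": "search_user_manual", "args": {"query": "pair a phone"}, "id": "call_1"}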


# Legacy function for backward compatibility
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Legacy user manual RAG node - redirects to the new agent-based implementation.
    """
    logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
    return await user_manual_agent_node(state, config)