init
1
vw-agentic-rag/service/graph/__init__.py
Normal file
@@ -0,0 +1 @@
# Empty __init__.py files to make packages
746
vw-agentic-rag/service/graph/graph.py
Normal file
@@ -0,0 +1,746 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from contextvars import ContextVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langgraph.graph import StateGraph, END, add_messages, MessagesState
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from .state import TurnState, Message, ToolResult, AgentState
|
||||
from .message_trimmer import create_conversation_trimmer
|
||||
from .tools import get_tool_schemas, get_tools_by_name
|
||||
from .user_manual_tools import get_user_manual_tools_by_name
|
||||
from .intent_recognition import intent_recognition_node, intent_router
|
||||
from .user_manual_rag import user_manual_rag_node
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..utils.templates import render_prompt_template
|
||||
from ..memory.postgresql_memory import get_checkpointer
|
||||
from ..utils.error_handler import (
|
||||
StructuredLogger, ErrorCategory, ErrorCode,
|
||||
handle_async_errors, get_user_message
|
||||
)
|
||||
from ..sse import (
|
||||
create_tool_start_event,
|
||||
create_tool_result_event,
|
||||
create_tool_error_event,
|
||||
create_token_event,
|
||||
create_error_event
|
||||
)
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
# Cache configuration at module level to avoid repeated get_config() calls
|
||||
_cached_config = None
|
||||
|
||||
def get_cached_config():
|
||||
"""Get cached configuration, loading it if not already cached"""
|
||||
global _cached_config
|
||||
if _cached_config is None:
|
||||
_cached_config = get_config()
|
||||
return _cached_config
|
||||
|
||||
# Context variable for streaming callback (thread-safe)
|
||||
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
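# Illustrative sketch (not part of this module's API): a request handler sets the callback
# once per request and graph nodes read it with stream_callback_context.get(). The
# `sse_sink` coroutine and `response_queue` below are hypothetical names, not project symbols.
#
#     async def sse_sink(event: dict) -> None:
#         await response_queue.put(event)
#
#     ctx_token = stream_callback_context.set(sse_sink)
#     try:
#         ...  # run the graph; call_model / run_tools_with_streaming emit events via the callback
#     finally:
#         stream_callback_context.reset(ctx_token)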
|
||||
|
||||
|
||||
# Agent node (autonomous function calling agent)
|
||||
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Agent node that autonomously uses tools and generates final answer.
|
||||
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
|
||||
1. Always start with non-streaming detection to check for tool needs
|
||||
2. If tool_calls exist → return immediately for routing to tools
|
||||
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
|
||||
"""
|
||||
app_config = get_cached_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get tool schemas and bind tools for planning phase
|
||||
tool_schemas = get_tool_schemas()
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
# Create conversation trimmer for managing context length
|
||||
trimmer = create_conversation_trimmer()
|
||||
|
||||
# Prepare messages with system prompt
|
||||
messages = state["messages"].copy()
|
||||
if not messages or not isinstance(messages[0], SystemMessage):
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
system_prompt = rag_prompts.get("agent_system_prompt", "")
|
||||
if not system_prompt:
|
||||
raise ValueError("system_prompt is null")
|
||||
|
||||
messages = [SystemMessage(content=system_prompt)] + messages
|
||||
|
||||
# Track tool rounds
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds", None)
|
||||
if max_rounds is None:
|
||||
max_rounds = app_config.app.max_tool_rounds
|
||||
|
||||
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
|
||||
# This prevents trimming current turn's tool results during multi-round tool calling
|
||||
if current_round == 0:
|
||||
# Trim conversation history to manage context length (only for previous conversation turns)
|
||||
if trimmer.should_trim(messages):
|
||||
messages = trimmer.trim_conversation_history(messages)
|
||||
logger.info("Applied conversation history trimming for context management (new conversation turn)")
|
||||
else:
|
||||
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
|
||||
|
||||
logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
|
||||
|
||||
# Check if this should be final synthesis (max rounds reached)
|
||||
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
|
||||
is_final_synthesis = has_tool_messages and current_round >= max_rounds
|
||||
|
||||
if is_final_synthesis:
|
||||
logger.info("Starting final synthesis phase - no more tool calls allowed")
|
||||
# ✅ STEP 1: Final synthesis with tools disabled from the start
|
||||
# Disable tools to prevent any tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False)  # Disable tools during final synthesis (return value not needed)
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, generate final response without tools
|
||||
draft = await llm_client.ainvoke(list(messages))
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 3: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
else:
|
||||
logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")
|
||||
|
||||
# ✅ STEP 1: Non-streaming detection to check for tool needs
|
||||
draft = await llm_client.ainvoke_with_tools(list(messages))
|
||||
|
||||
# ✅ STEP 2: If draft has tool_calls, return immediately (let routing handle it)
|
||||
if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
|
||||
# Increment tool round counter for next iteration
|
||||
new_tool_rounds = current_round + 1
|
||||
logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
|
||||
return {"messages": [draft], "tool_rounds": new_tool_rounds}
|
||||
|
||||
# ✅ STEP 3: No tool_calls needed → Enter final synthesis with streaming
|
||||
# Temporarily disable tools to prevent accidental tool calling during synthesis
|
||||
try:
|
||||
llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, use the draft we already have
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
|
||||
|
||||
async for token in llm_client.astream(list(messages)):
|
||||
accumulated_content += token
|
||||
response_content += token
|
||||
|
||||
# Check for complete HTML comments in accumulated content
|
||||
while "<!--" in accumulated_content and "-->" in accumulated_content:
|
||||
comment_start = accumulated_content.find("<!--")
|
||||
comment_end = accumulated_content.find("-->", comment_start)
|
||||
|
||||
if comment_start >= 0 and comment_end >= 0:
|
||||
# Send content before comment
|
||||
before_comment = accumulated_content[:comment_start]
|
||||
if stream_callback and before_comment:
|
||||
await stream_callback(create_token_event(before_comment))
|
||||
|
||||
# Skip the comment and continue with content after
|
||||
accumulated_content = accumulated_content[comment_end + 3:]
|
||||
else:
|
||||
break
|
||||
|
||||
# Send accumulated content if no pending comment
|
||||
if "<!--" not in accumulated_content:
|
||||
if stream_callback and accumulated_content:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
accumulated_content = ""
|
||||
|
||||
# Send any remaining content (if not in middle of comment)
|
||||
if accumulated_content and "<!--" not in accumulated_content:
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(accumulated_content))
|
||||
|
||||
return {"messages": [AIMessage(content=response_content)]}
|
||||
|
||||
finally:
|
||||
# ✅ STEP 5: Restore tool binding for next interaction
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
|
||||
# Tools routing condition (simplified for "detect-first-then-stream" strategy)
|
||||
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
|
||||
"""
|
||||
Simplified routing logic for "detect-first-then-stream" strategy:
|
||||
- has tool_calls → route to tools
|
||||
- no tool_calls → route to post_process (final synthesis already completed)
|
||||
"""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
logger.info("should_continue: No messages, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
last_message = messages[-1]
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds", None)
|
||||
if max_rounds is None:
|
||||
app_config = get_cached_config()
|
||||
max_rounds = app_config.app.max_tool_rounds
|
||||
|
||||
logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")
|
||||
|
||||
# If last message is AI message with tool calls, route to tools
|
||||
if isinstance(last_message, AIMessage):
|
||||
has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
|
||||
logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info("should_continue: Routing to tools")
|
||||
return "tools"
|
||||
else:
|
||||
# No tool calls = final synthesis already completed in call_model
|
||||
logger.info("should_continue: No tool calls, routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
# If last message is tool message(s), continue with agent for next round or final synthesis
|
||||
if isinstance(last_message, ToolMessage):
|
||||
logger.info("should_continue: Tool message completed, continuing to agent")
|
||||
return "agent"
|
||||
|
||||
logger.info("should_continue: Routing to post_process")
|
||||
return "post_process"
|
||||
|
||||
|
||||
# Custom tool node with streaming support and parallel execution
|
||||
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""Execute tools with streaming events - supports parallel execution"""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
|
||||
return {"messages": []}
|
||||
|
||||
tool_calls = last_message.tool_calls or []
|
||||
tool_results = []
|
||||
new_messages = []
|
||||
|
||||
# Tools mapping
|
||||
tools_map = get_tools_by_name()
|
||||
|
||||
async def execute_single_tool(tool_call):
|
||||
"""Execute a single tool call with enhanced error handling"""
|
||||
# Get stream callback from context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Apply error handling decorator
|
||||
@handle_async_errors(
|
||||
ErrorCategory.TOOL,
|
||||
ErrorCode.TOOL_ERROR,
|
||||
stream_callback,
|
||||
tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
)
|
||||
async def _execute():
|
||||
# Validate tool_call format
|
||||
if not isinstance(tool_call, dict):
|
||||
raise ValueError(f"Tool call must be dict, got {type(tool_call)}")
|
||||
|
||||
tool_name = tool_call.get("name")
|
||||
tool_args = tool_call.get("args", {})
|
||||
tool_id = tool_call.get("id", "unknown")
|
||||
|
||||
if not tool_name or tool_name not in tools_map:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
logger.info(f"Executing tool: {tool_name}", extra={
|
||||
"tool_id": tool_id, "tool_name": tool_name
|
||||
})
|
||||
|
||||
# Send start event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))
|
||||
|
||||
# Execute tool
|
||||
import time
|
||||
start_time = time.time()
|
||||
result = await tools_map[tool_name].ainvoke(tool_args)
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Process result
|
||||
if isinstance(result, dict):
|
||||
result["tool_call_id"] = tool_id
|
||||
if "results" in result and isinstance(result["results"], list):
|
||||
for i, search_result in enumerate(result["results"]):
|
||||
if isinstance(search_result, dict):
|
||||
search_result["@tool_call_id"] = tool_id
|
||||
search_result["@order_num"] = i
|
||||
|
||||
# Send result event
|
||||
if stream_callback:
|
||||
await stream_callback(create_tool_result_event(
|
||||
tool_id, tool_name, result.get("results", []), execution_time
|
||||
))
|
||||
|
||||
# Create tool message
|
||||
tool_message = ToolMessage(
|
||||
content=json.dumps(result, ensure_ascii=False),
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return {
|
||||
"message": tool_message,
|
||||
"results": result.get("results", []) if isinstance(result, dict) else [],
|
||||
"success": True
|
||||
}
|
||||
|
||||
try:
|
||||
return await _execute()
|
||||
except Exception as e:
|
||||
# Handle any errors not caught by decorator
|
||||
tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"
|
||||
|
||||
error_message = ToolMessage(
|
||||
content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
|
||||
tool_call_id=tool_id,
|
||||
name=tool_name
|
||||
)
|
||||
|
||||
return {
|
||||
"message": error_message,
|
||||
"results": [],
|
||||
"success": False
|
||||
}
|
||||
|
||||
# Execute all tool calls in parallel using asyncio.gather
|
||||
if tool_calls:
|
||||
logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
|
||||
tool_execution_results = await asyncio.gather(
|
||||
*[execute_single_tool(tool_call) for tool_call in tool_calls],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Process results
|
||||
for execution_result in tool_execution_results:
|
||||
if execution_result is None:
|
||||
continue
|
||||
if isinstance(execution_result, Exception):
|
||||
logger.error(f"Tool execution exception: {execution_result}")
|
||||
continue
|
||||
if not isinstance(execution_result, dict):
|
||||
logger.error(f"Unexpected execution result type: {type(execution_result)}")
|
||||
continue
|
||||
|
||||
new_messages.append(execution_result["message"])
|
||||
if execution_result["success"] and execution_result["results"]:
|
||||
tool_results.extend(execution_result["results"])
|
||||
|
||||
logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")
|
||||
|
||||
return {
|
||||
"messages": new_messages,
|
||||
"tool_results": tool_results
|
||||
}
|
||||
|
||||
|
||||
# Helper functions for citation processing
|
||||
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
|
||||
"""Extract citations mapping CSV from agent response HTML comment"""
|
||||
try:
|
||||
# Look for citations_map comment
|
||||
pattern = r'<!-- citations_map\s*(.*?)\s*-->'
|
||||
match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
if not match:
|
||||
logger.warning("No citations_map comment found in agent response")
|
||||
return {}
|
||||
|
||||
csv_content = match.group(1).strip()
|
||||
citations_mapping = {}
|
||||
|
||||
for line in csv_content.split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
citation_num = int(parts[0])
|
||||
tool_call_id = parts[1].strip()
|
||||
order_num = int(parts[2])
|
||||
|
||||
citations_mapping[citation_num] = {
|
||||
'tool_call_id': tool_call_id,
|
||||
'order_num': order_num
|
||||
}
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.warning(f"Failed to parse citation line: {line}, error: {e}")
|
||||
continue
|
||||
|
||||
return citations_mapping
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting citations mapping: {e}")
|
||||
return {}
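# Example (illustrative values): the agent is expected to append a trailing comment such as
#
#   <!-- citations_map
#   1,call_abc123,0
#   2,call_abc123,2
#   -->
#
# which this function parses into
#   {1: {"tool_call_id": "call_abc123", "order_num": 0},
#    2: {"tool_call_id": "call_abc123", "order_num": 2}}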
|
||||
|
||||
|
||||
def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
|
||||
"""Build citation markdown based on mapping and tool results, following build_citations.py logic"""
|
||||
if not citations_mapping:
|
||||
return ""
|
||||
|
||||
# Get configuration for citation base URL
|
||||
config = get_cached_config()
|
||||
cat_base_url = config.citation.base_url
|
||||
|
||||
# Collect citation lines first; only emit header if we have at least one valid citation
|
||||
entries: List[str] = []
|
||||
|
||||
for citation_num in sorted(citations_mapping.keys()):
|
||||
mapping = citations_mapping[citation_num]
|
||||
tool_call_id = mapping['tool_call_id']
|
||||
order_num = mapping['order_num']
|
||||
|
||||
# Find the corresponding tool result
|
||||
result = _find_tool_result(tool_results, tool_call_id, order_num)
|
||||
if not result:
|
||||
logger.warning(f"No tool result found for citation [{citation_num}]")
|
||||
continue
|
||||
|
||||
# Extract citation information following build_citations.py logic
|
||||
full_headers = result.get('full_headers', '')
|
||||
lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
|
||||
header_display = f": {lowest_header}" if lowest_header else ""
|
||||
|
||||
document_code = result.get('document_code', '')
|
||||
document_category = result.get('document_category', '')
|
||||
|
||||
# Determine standard/regulation title (assuming English language)
|
||||
standard_regulation_title = ''
|
||||
if document_category == 'Standard':
|
||||
standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
|
||||
elif document_category == 'Regulation':
|
||||
standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')
|
||||
|
||||
# Build link
|
||||
func_uuid = result.get('func_uuid', '')
|
||||
uuid = result.get('x_Standard_Regulation_Id', '')
|
||||
document_code_encoded = quote(document_code, safe='') if document_code else ''
|
||||
standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
|
||||
link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
|
||||
link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'
|
||||
|
||||
# Format citation line
|
||||
title = result.get('title', '')
|
||||
entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")
|
||||
|
||||
# If no valid citations were found, do not include the header
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
# Build citations section with entries separated by a blank line (matching previous formatting)
|
||||
md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
|
||||
return md
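# Shape of the rendered section (placeholders, not real data):
#
#   ### 📘 Citations:
#   [1] <result title>: <lowest header> | [<standard/regulation title> | <document code>](<base_url>?funcUuid=<func_uuid>&uuid=<id>&name=<encoded code>(<encoded title>))
#
# Entries are separated by a blank line; the header is omitted when no citation resolves.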
|
||||
|
||||
|
||||
def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
|
||||
"""Find tool result by tool_call_id and order_num"""
|
||||
matching_results = []
|
||||
|
||||
for result in tool_results:
|
||||
if result.get('@tool_call_id') == tool_call_id:
|
||||
matching_results.append(result)
|
||||
|
||||
# Sort by order and return the one at the specified position
|
||||
if matching_results and 0 <= order_num < len(matching_results):
|
||||
# If results have @order_num, use it; otherwise use position in list
|
||||
if '@order_num' in matching_results[0]:
|
||||
for result in matching_results:
|
||||
if result.get('@order_num') == order_num:
|
||||
return result
|
||||
else:
|
||||
return matching_results[order_num]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _remove_citations_comment(agent_response: str) -> str:
|
||||
"""Remove citations mapping HTML comment from agent response"""
|
||||
pattern = r'<!-- citations_map\s*.*?\s*-->'
|
||||
return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()
|
||||
|
||||
|
||||
# Post-processing node with citation list and link building
|
||||
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Post-processing node that builds citation list and links based on agent's citations mapping
|
||||
and tool call results, following the logic from build_citations.py
|
||||
"""
|
||||
agent_response = ""  # defined before the try block so the except handler can reference it safely
try:
|
||||
logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")
|
||||
|
||||
# Get stream callback from context variable
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get the last AI message (agent's response with citations mapping)
|
||||
agent_response = ""
|
||||
citations_mapping = {}
|
||||
|
||||
for message in reversed(state["messages"]):
|
||||
if isinstance(message, AIMessage) and message.content:
|
||||
# Ensure content is a string
|
||||
if isinstance(message.content, str):
|
||||
agent_response = message.content
|
||||
break
|
||||
|
||||
if not agent_response:
|
||||
logger.warning("POST_PROCESS_NODE: No agent response found")
|
||||
return {"messages": [], "final_answer": ""}
|
||||
|
||||
# Extract citations mapping from agent response
|
||||
citations_mapping = _extract_citations_mapping(agent_response)
|
||||
logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")
|
||||
|
||||
# Build citation markdown
|
||||
citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])
|
||||
|
||||
# Combine agent response (without HTML comment) with citations
|
||||
clean_response = _remove_citations_comment(agent_response)
|
||||
final_content = clean_response + citation_markdown
|
||||
|
||||
logger.info("POST_PROCESS_NODE: Built complete response with citations")
|
||||
|
||||
# Send citation markdown as a single block instead of streaming (stream_callback was already fetched above)
|
||||
if stream_callback and citation_markdown:
|
||||
logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
|
||||
await stream_callback(create_token_event(citation_markdown))
|
||||
|
||||
# Create AI message with complete content
|
||||
final_ai_message = AIMessage(content=final_content)
|
||||
|
||||
return {
|
||||
"messages": [final_ai_message],
|
||||
"final_answer": final_content
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Post-processing error: {e}")
|
||||
error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."
|
||||
|
||||
# Send error message as single block
|
||||
stream_callback = stream_callback_context.get()
|
||||
if stream_callback:
|
||||
await stream_callback(create_token_event(error_message))
|
||||
|
||||
error_content = agent_response + error_message if agent_response else error_message
|
||||
error_ai_message = AIMessage(content=error_content)
|
||||
return {
|
||||
"messages": [error_ai_message],
|
||||
"final_answer": error_ai_message.content
|
||||
}
|
||||
|
||||
|
||||
# Main workflow class
|
||||
class AgenticWorkflow:
|
||||
"""LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""
|
||||
|
||||
def __init__(self):
|
||||
# Build StateGraph with TypedDict state
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Add nodes following best practices
|
||||
workflow.add_node("intent_recognition", intent_recognition_node)
|
||||
workflow.add_node("agent", call_model)
|
||||
workflow.add_node("user_manual_rag", user_manual_rag_node)
|
||||
workflow.add_node("tools", run_tools_with_streaming)
|
||||
workflow.add_node("post_process", post_process_node)
|
||||
|
||||
# Set entry point to intent recognition
|
||||
workflow.set_entry_point("intent_recognition")
|
||||
|
||||
# Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
|
||||
workflow.add_conditional_edges(
|
||||
"intent_recognition",
|
||||
intent_router,
|
||||
{
|
||||
"Standard_Regulation_RAG": "agent",
|
||||
"User_Manual_RAG": "user_manual_rag"
|
||||
}
|
||||
)
|
||||
|
||||
# Standard RAG workflow (existing pattern)
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
should_continue,
|
||||
{
|
||||
"tools": "tools",
|
||||
"agent": "agent", # Allow agent to continue for multi-round
|
||||
"post_process": "post_process"
|
||||
}
|
||||
)
|
||||
|
||||
# Tools route back to should_continue for multi-round decision
|
||||
workflow.add_conditional_edges(
|
||||
"tools",
|
||||
should_continue,
|
||||
{
|
||||
"agent": "agent", # Continue to agent for next round
|
||||
"post_process": "post_process" # Or finish if max rounds reached
|
||||
}
|
||||
)
|
||||
|
||||
# User Manual RAG directly goes to END (single turn)
|
||||
workflow.add_edge("user_manual_rag", END)
|
||||
|
||||
# Post-process is terminal
|
||||
workflow.add_edge("post_process", END)
|
||||
|
||||
# Compile graph with PostgreSQL checkpointer for session memory
|
||||
try:
|
||||
checkpointer = get_checkpointer()
|
||||
self.graph = workflow.compile(checkpointer=checkpointer)
|
||||
logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
|
||||
self.graph = workflow.compile()
|
||||
|
||||
async def astream(self, state: TurnState, stream_callback: Callable | None = None):
|
||||
"""Stream agent execution using LangGraph with PostgreSQL session memory"""
|
||||
try:
|
||||
# Get configuration
|
||||
config = get_cached_config()
|
||||
|
||||
# Prepare initial messages for the graph
|
||||
messages = []
|
||||
for msg in state.messages:
|
||||
if msg.role == "user":
|
||||
messages.append(HumanMessage(content=msg.content))
|
||||
elif msg.role == "assistant":
|
||||
messages.append(AIMessage(content=msg.content))
|
||||
|
||||
# Create initial agent state (without stream_callback to avoid serialization issues)
|
||||
initial_state: AgentState = {
|
||||
"messages": messages,
|
||||
"session_id": state.session_id,
|
||||
"intent": None, # Will be determined by intent recognition node
|
||||
"tool_results": [],
|
||||
"final_answer": "",
|
||||
"tool_rounds": 0,
|
||||
"max_tool_rounds": config.app.max_tool_rounds, # Use configuration value
|
||||
"max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual # Use configuration value for user manual agent
|
||||
}
|
||||
|
||||
# Set stream callback in context variable (thread-safe)
|
||||
stream_callback_context.set(stream_callback)
|
||||
|
||||
# Create proper RunnableConfig
|
||||
runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})
|
||||
|
||||
# Stream graph execution with session memory
|
||||
async for step in self.graph.astream(initial_state, config=runnable_config):
|
||||
if "post_process" in step:
|
||||
final_state = step["post_process"]
|
||||
|
||||
# Extract the tool summary message and update state
|
||||
state.final_answer = final_state.get("final_answer", "")
|
||||
|
||||
# Add the summary as a regular assistant message
|
||||
if state.final_answer:
|
||||
state.messages.append(Message(
|
||||
role="assistant",
|
||||
content=state.final_answer,
|
||||
timestamp=datetime.now()
|
||||
))
|
||||
|
||||
yield {"final": state}
|
||||
break
|
||||
elif "user_manual_rag" in step:
|
||||
# Handle user manual RAG completion
|
||||
final_state = step["user_manual_rag"]
|
||||
|
||||
# Extract the response from user manual RAG
|
||||
state.final_answer = final_state.get("final_answer", "")
|
||||
|
||||
# Add the response as a regular assistant message
|
||||
if state.final_answer:
|
||||
state.messages.append(Message(
|
||||
role="assistant",
|
||||
content=state.final_answer,
|
||||
timestamp=datetime.now()
|
||||
))
|
||||
|
||||
yield {"final": state}
|
||||
break
|
||||
else:
|
||||
# Process regular steps (intent_recognition, agent, tools)
|
||||
yield step
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentWorkflow error: {e}")
|
||||
state.final_answer = "I apologize, but I encountered an error while processing your request."
|
||||
yield {"final": state}
|
||||
|
||||
|
||||
def build_graph() -> AgenticWorkflow:
|
||||
"""Build and return the autonomous agent workflow"""
|
||||
return AgenticWorkflow()
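# Usage sketch (assumes a populated TurnState; `on_event` and `send_sse` are illustrative
# names, not part of this module):
#
#     workflow = build_graph()
#
#     async def on_event(event: dict) -> None:
#         await send_sse(event)
#
#     async for step in workflow.astream(turn_state, stream_callback=on_event):
#         if "final" in step:
#             answer = step["final"].final_answer
#             break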
|
||||
136
vw-agentic-rag/service/graph/intent_recognition.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Intent recognition functionality for the Agentic RAG system.
|
||||
This module contains the intent classification logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .state import AgentState
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..utils.error_handler import StructuredLogger
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
|
||||
# Intent Recognition Models
|
||||
class Intent(BaseModel):
|
||||
"""Intent classification model for routing user queries"""
|
||||
label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
|
||||
confidence: Optional[float] = None
|
||||
|
||||
|
||||
def get_last_user_message(messages) -> str:
|
||||
"""Extract the last user message from conversation history"""
|
||||
for message in reversed(messages):
|
||||
if hasattr(message, 'content'):
|
||||
content = message.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract string content from list
|
||||
return " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
return ""
|
||||
|
||||
|
||||
def render_conversation_history(messages, max_messages: int = 10) -> str:
|
||||
"""Render conversation history for context"""
|
||||
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
||||
lines = []
|
||||
for msg in recent_messages:
|
||||
if hasattr(msg, 'content'):
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
# Determine message type by class name or other attributes
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content}</ai>")
|
||||
elif isinstance(content, list):
|
||||
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content_str}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content_str}</ai>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Intent recognition node that uses LLM to classify user queries into specific domains
|
||||
"""
|
||||
try:
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
|
||||
|
||||
app_config = get_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get current user query and conversation history
|
||||
current_query = get_last_user_message(state["messages"])
|
||||
conversation_context = render_conversation_history(state["messages"])
|
||||
|
||||
# Get intent classification prompt from configuration
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
|
||||
|
||||
if not intent_prompt_template:
|
||||
logger.error("Intent recognition prompt not found in configuration")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
# Format the prompt with instruction to return only the label
|
||||
system_prompt = intent_prompt_template.format(
|
||||
current_query=current_query,
|
||||
conversation_context=conversation_context
|
||||
) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."
|
||||
|
||||
# Classify intent using regular LLM call
|
||||
intent_result = await llm_client.llm.ainvoke([
|
||||
SystemMessage(content=system_prompt)
|
||||
])
|
||||
|
||||
# Parse the response to extract the intent label
|
||||
response_text = ""
|
||||
if hasattr(intent_result, 'content') and intent_result.content:
|
||||
if isinstance(intent_result.content, str):
|
||||
response_text = intent_result.content.strip()
|
||||
elif isinstance(intent_result.content, list):
|
||||
# Handle list content by joining string elements
|
||||
response_text = " ".join([str(item) for item in intent_result.content if isinstance(item, str)]).strip()
|
||||
|
||||
# Extract intent label from response
|
||||
if "User_Manual_RAG" in response_text:
|
||||
intent_label = "User_Manual_RAG"
|
||||
elif "Standard_Regulation_RAG" in response_text:
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
else:
|
||||
# Default fallback
|
||||
logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
|
||||
|
||||
return {"intent": intent_label}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Intent recognition error: {e}")
|
||||
# Default to Standard_Regulation_RAG if classification fails
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
|
||||
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
|
||||
"""
|
||||
Route based on intent classification result
|
||||
"""
|
||||
intent = state.get("intent")
|
||||
if intent is None:
|
||||
logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
|
||||
return "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
|
||||
return intent
|
||||
270
vw-agentic-rag/service/graph/message_trimmer.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Conversation history trimming utilities for managing context length.
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage, AnyMessage
|
||||
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationTrimmer:
|
||||
"""
|
||||
Manages conversation history to prevent exceeding LLM context limits.
|
||||
"""
|
||||
|
||||
def __init__(self, max_context_length: int = 96000, preserve_system: bool = True):
|
||||
"""
|
||||
Initialize the conversation trimmer.
|
||||
|
||||
Args:
|
||||
max_context_length: Maximum context length for conversation history (in tokens)
|
||||
preserve_system: Whether to always preserve system messages
|
||||
"""
|
||||
self.max_context_length = max_context_length
|
||||
self.preserve_system = preserve_system
|
||||
# Reserve tokens for response generation (use 85% for history, 15% for response)
|
||||
self.history_token_limit = int(max_context_length * 0.85)
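# e.g. with the default max_context_length of 96,000 tokens this reserves
# int(96000 * 0.85) == 81,600 tokens for conversation history.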
|
||||
|
||||
def trim_conversation_history(self, messages: Sequence[AnyMessage]) -> List[BaseMessage]:
|
||||
"""
|
||||
Trim conversation history to fit within token limits.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Trimmed list of messages
|
||||
"""
|
||||
if not messages:
|
||||
return list(messages)
|
||||
|
||||
try:
|
||||
# Convert to list for processing
|
||||
message_list = list(messages)
|
||||
|
||||
# First, try multi-round tool call optimization
|
||||
optimized_messages = self._optimize_multi_round_tool_calls(message_list)
|
||||
|
||||
# Check if optimization is sufficient
|
||||
try:
|
||||
token_count = count_tokens_approximately(optimized_messages)
|
||||
if token_count <= self.history_token_limit:
|
||||
original_count = len(message_list)
|
||||
optimized_count = len(optimized_messages)
|
||||
if optimized_count < original_count:
|
||||
logger.info(f"Multi-round tool optimization: {original_count} -> {optimized_count} messages")
|
||||
return optimized_messages
|
||||
except Exception:
|
||||
# If token counting fails, continue with LangChain trimming
|
||||
pass
|
||||
|
||||
# If still too long, use LangChain's trim_messages utility
|
||||
trimmed_messages = trim_messages(
|
||||
optimized_messages,
|
||||
strategy="last", # Keep most recent messages
|
||||
token_counter=count_tokens_approximately,
|
||||
max_tokens=self.history_token_limit,
|
||||
start_on="human", # Ensure valid conversation start
|
||||
end_on=("human", "tool", "ai"), # Allow ending on human, tool, or AI messages
|
||||
include_system=self.preserve_system, # Preserve system messages
|
||||
allow_partial=False # Don't split individual messages
|
||||
)
|
||||
|
||||
original_count = len(messages)
|
||||
trimmed_count = len(trimmed_messages)
|
||||
|
||||
if trimmed_count < original_count:
|
||||
logger.info(f"Trimmed conversation history: {original_count} -> {trimmed_count} messages")
|
||||
|
||||
return trimmed_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error trimming conversation history: {e}")
|
||||
# Fallback: keep last N messages
|
||||
return self._fallback_trim(list(messages))
|
||||
|
||||
def _optimize_multi_round_tool_calls(self, messages: List[AnyMessage]) -> List[BaseMessage]:
|
||||
"""
|
||||
Optimize conversation history by removing older tool call results in multi-round scenarios.
|
||||
This reduces token usage while preserving conversation context.
|
||||
|
||||
Strategy:
|
||||
1. Always preserve system messages
|
||||
2. Always preserve the original user query
|
||||
3. Keep the most recent AI-Tool message pairs (for context continuity)
|
||||
4. Remove older ToolMessage content which typically contains large JSON responses
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Optimized list of messages
|
||||
"""
|
||||
if len(messages) <= 4: # Too short to optimize
|
||||
return [msg for msg in messages]
|
||||
|
||||
# Identify message patterns
|
||||
tool_rounds = self._identify_tool_rounds(messages)
|
||||
|
||||
if len(tool_rounds) <= 1: # Single or no tool round, no optimization needed
|
||||
return [msg for msg in messages]
|
||||
|
||||
logger.info(f"Multi-round tool optimization: Found {len(tool_rounds)} tool rounds")
|
||||
|
||||
# Build optimized message list
|
||||
optimized = []
|
||||
|
||||
# Always preserve system messages
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
optimized.append(msg)
|
||||
|
||||
# Preserve initial user query (first human message after system)
|
||||
first_human_added = False
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage) and not first_human_added:
|
||||
optimized.append(msg)
|
||||
first_human_added = True
|
||||
break
|
||||
|
||||
# Keep only the most recent tool round (preserve context for next round)
|
||||
if tool_rounds:
|
||||
latest_round_start, latest_round_end = tool_rounds[-1]
|
||||
|
||||
# Add messages from the latest tool round
|
||||
for i in range(latest_round_start, min(latest_round_end + 1, len(messages))):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg, SystemMessage) and not (isinstance(msg, HumanMessage) and not first_human_added):
|
||||
optimized.append(msg)
|
||||
|
||||
logger.info(f"Multi-round optimization: {len(messages)} -> {len(optimized)} messages (removed {len(tool_rounds)-1} older tool rounds)")
|
||||
return optimized
|
||||
|
||||
def _identify_tool_rounds(self, messages: List[AnyMessage]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Identify tool calling rounds in the message sequence.
|
||||
|
||||
A tool round typically consists of:
|
||||
- AI message with tool_calls
|
||||
- One or more ToolMessage responses
|
||||
|
||||
Returns:
|
||||
List of (start_index, end_index) tuples for each tool round
|
||||
"""
|
||||
rounds = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
# Look for AI message with tool calls
|
||||
if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||
round_start = i
|
||||
round_end = i
|
||||
|
||||
# Find the end of this tool round (look for consecutive ToolMessages)
|
||||
j = i + 1
|
||||
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
||||
round_end = j
|
||||
j += 1
|
||||
|
||||
# Only consider it a tool round if we found at least one ToolMessage
|
||||
if round_end > round_start:
|
||||
rounds.append((round_start, round_end))
|
||||
i = round_end + 1
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return rounds
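# Example: for [System, Human, AI(tool_calls), Tool, Tool, AI(tool_calls), Tool, AI]
# the detected rounds are [(2, 4), (5, 6)]; each tuple spans the AI message that
# requested tools and its consecutive ToolMessage responses.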
|
||||
|
||||
def _fallback_trim(self, messages: List[AnyMessage], max_messages: int = 20) -> List[BaseMessage]:
|
||||
"""
|
||||
Fallback trimming based on message count.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
max_messages: Maximum number of messages to keep
|
||||
|
||||
Returns:
|
||||
Trimmed list of messages
|
||||
"""
|
||||
if len(messages) <= max_messages:
|
||||
return [msg for msg in messages] # Convert to BaseMessage
|
||||
|
||||
# Preserve system message if it exists
|
||||
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
|
||||
other_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
|
||||
|
||||
# Keep the most recent messages
|
||||
recent_messages = other_messages[-(max_messages - len(system_messages)):]
|
||||
|
||||
result = system_messages + recent_messages
|
||||
logger.info(f"Fallback trimming: {len(messages)} -> {len(result)} messages")
|
||||
|
||||
return [msg for msg in result] # Ensure BaseMessage type
|
||||
|
||||
def should_trim(self, messages: Sequence[AnyMessage]) -> bool:
|
||||
"""
|
||||
Check if conversation history should be trimmed.
|
||||
|
||||
Strategy:
|
||||
1. Always trim if there are multiple tool rounds from previous conversation turns
|
||||
2. Also trim if approaching token limit
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
True if trimming is needed
|
||||
"""
|
||||
try:
|
||||
# Convert to list for processing
|
||||
message_list = list(messages)
|
||||
|
||||
# Check for multiple tool rounds - if found, always trim to remove old tool results
|
||||
tool_rounds = self._identify_tool_rounds(message_list)
|
||||
if len(tool_rounds) > 1:
|
||||
logger.info(f"Found {len(tool_rounds)} tool rounds - trimming to remove old tool results")
|
||||
return True
|
||||
|
||||
# Also check token count for traditional trimming
|
||||
token_count = count_tokens_approximately(message_list)
|
||||
return token_count > self.history_token_limit
|
||||
except Exception:
|
||||
# Fallback to message count
|
||||
return len(messages) > 30
|
||||
|
||||
|
||||
def create_conversation_trimmer(max_context_length: Optional[int] = None) -> ConversationTrimmer:
|
||||
"""
|
||||
Create a conversation trimmer with config-based settings.
|
||||
|
||||
Args:
|
||||
max_context_length: Override for maximum context length
|
||||
|
||||
Returns:
|
||||
ConversationTrimmer instance
|
||||
"""
|
||||
# If max_context_length is provided, use it directly
|
||||
if max_context_length is not None:
|
||||
return ConversationTrimmer(
|
||||
max_context_length=max_context_length,
|
||||
preserve_system=True
|
||||
)
|
||||
|
||||
# Try to get from config, fallback to default if config not available
|
||||
try:
|
||||
from ..config import get_config
|
||||
config = get_config()
|
||||
effective_max_context_length = config.get_max_context_length()
|
||||
except (RuntimeError, AttributeError):
|
||||
effective_max_context_length = 96000
|
||||
|
||||
return ConversationTrimmer(
|
||||
max_context_length=effective_max_context_length,
|
||||
preserve_system=True
|
||||
)
|
||||
66
vw-agentic-rag/service/graph/state.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
from typing_extensions import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base message class for conversation history"""
|
||||
role: str # "user", "assistant", "tool"
|
||||
content: str
|
||||
timestamp: Optional[datetime] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
"""Citation mapping between numbers and result IDs"""
|
||||
number: int
|
||||
result_id: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Normalized tool result schema"""
|
||||
id: str
|
||||
title: str
|
||||
url: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
content: Optional[str] = None # For chunk results
|
||||
# Standard/regulation specific fields
|
||||
publisher: Optional[str] = None
|
||||
publish_date: Optional[str] = None
|
||||
document_code: Optional[str] = None
|
||||
document_category: Optional[str] = None
|
||||
|
||||
|
||||
class TurnState(BaseModel):
|
||||
"""State container for LangGraph workflow"""
|
||||
session_id: str
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
tool_results: List[ToolResult] = Field(default_factory=list)
|
||||
citations: List[Citation] = Field(default_factory=list)
|
||||
meta: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Additional fields for tracking
|
||||
current_step: int = 0
|
||||
max_steps: int = 5
|
||||
final_answer: Optional[str] = None
|
||||
|
||||
|
||||
# TypedDict for LangGraph AgentState (LangGraph native format)
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
class AgentState(MessagesState):
|
||||
"""LangGraph state with intent recognition support"""
|
||||
session_id: str
|
||||
intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
|
||||
tool_results: Annotated[List[Dict[str, Any]], lambda x, y: (x or []) + (y or [])]
|
||||
final_answer: str
|
||||
tool_rounds: int
|
||||
max_tool_rounds: int
|
||||
max_tool_rounds_user_manual: int
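# Note on reducers (illustrative): `messages` is merged by MessagesState's add_messages
# reducer, while the Annotated lambda above concatenates tool_results across node updates,
# e.g. one node returning {"tool_results": [r1]} and a later one returning
# {"tool_results": [r2, r3]} yields tool_results == [r1, r2, r3] in the merged state.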
|
||||
98
vw-agentic-rag/service/graph/tools.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Tool definitions and schemas for the Agentic RAG system.
|
||||
This module contains all tool implementations and their corresponding schemas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from ..retrieval.retrieval import AgenticRetrieval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool Definitions using @tool decorator (following LangGraph best practices)
|
||||
@tool
|
||||
async def retrieve_standard_regulation(query: str) -> Dict[str, Any]:
|
||||
"""Search for attributes/metadata of China standards and regulations in automobile/manufacturing industry"""
|
||||
async with AgenticRetrieval() as retrieval:
|
||||
try:
|
||||
result = await retrieval.retrieve_standard_regulation(
|
||||
query=query
|
||||
)
|
||||
return {
|
||||
"tool_name": "retrieve_standard_regulation",
|
||||
"results_count": len(result.results),
|
||||
"results": result.results, # Already dict objects, no need for model_dump()
|
||||
"took_ms": result.took_ms
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Retrieval error: {e}")
|
||||
return {"error": str(e), "results_count": 0, "results": []}
|
||||
|
||||
|
||||
@tool
|
||||
async def retrieve_doc_chunk_standard_regulation(query: str) -> Dict[str, Any]:
|
||||
"""Search for detailed document content chunks of China standards and regulations in automobile/manufacturing industry"""
|
||||
async with AgenticRetrieval() as retrieval:
|
||||
try:
|
||||
result = await retrieval.retrieve_doc_chunk_standard_regulation(
|
||||
query=query
|
||||
)
|
||||
return {
|
||||
"tool_name": "retrieve_doc_chunk_standard_regulation",
|
||||
"results_count": len(result.results),
|
||||
"results": result.results, # Already dict objects, no need for model_dump()
|
||||
"took_ms": result.took_ms
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Doc chunk retrieval error: {e}")
|
||||
return {"error": str(e), "results_count": 0, "results": []}
|
||||
|
||||
|
||||
# Available tools list
|
||||
tools = [retrieve_standard_regulation, retrieve_doc_chunk_standard_regulation]
|
||||
|
||||
|
||||
def get_tool_schemas() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate tool schemas for LLM function calling.
|
||||
|
||||
Returns:
|
||||
List of tool schemas in OpenAI function calling format
|
||||
"""
|
||||
|
||||
|
||||
tool_schemas = []
|
||||
for tool in tools:
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for retrieving relevant information"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_schemas.append(schema)
|
||||
|
||||
return tool_schemas
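# Sketch of one generated schema (follows the template above; the description comes from
# the tool's docstring):
#
#   {"type": "function",
#    "function": {"name": "retrieve_standard_regulation",
#                 "description": "Search for attributes/metadata of China standards ...",
#                 "parameters": {"type": "object",
#                                "properties": {"query": {"type": "string",
#                                               "description": "Search query for retrieving relevant information"}},
#                                "required": ["query"]}}}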
|
||||
|
||||
|
||||
def get_tools_by_name() -> Dict[str, Any]:
|
||||
"""
|
||||
Create a mapping of tool names to tool functions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to tool functions
|
||||
"""
|
||||
return {tool.name: tool for tool in tools}
|
||||
464
vw-agentic-rag/service/graph/user_manual_rag.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
User Manual Agent node for the Agentic RAG system.
|
||||
This module contains the autonomous user manual agent that can use tools and generate responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Callable, Literal
|
||||
from contextvars import ContextVar
|
||||
from langchain_core.messages import AIMessage, SystemMessage, BaseMessage, ToolMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from .state import AgentState
|
||||
from .user_manual_tools import get_user_manual_tool_schemas, get_user_manual_tools_by_name
|
||||
from .message_trimmer import create_conversation_trimmer
from .intent_recognition import render_conversation_history  # used below to format prior turns
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..sse import (
|
||||
create_tool_start_event,
|
||||
create_tool_result_event,
|
||||
create_tool_error_event,
|
||||
create_token_event,
|
||||
create_error_event
|
||||
)
|
||||
from ..utils.error_handler import (
|
||||
StructuredLogger, ErrorCategory, ErrorCode,
|
||||
handle_async_errors, get_user_message
|
||||
)
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
# Cache configuration at module level to avoid repeated get_config() calls
|
||||
_cached_config = None
|
||||
|
||||
def get_cached_config():
|
||||
"""Get cached configuration, loading it if not already cached"""
|
||||
global _cached_config
|
||||
if _cached_config is None:
|
||||
_cached_config = get_config()
|
||||
return _cached_config
|
||||
|
||||
|
||||
# User Manual Agent node (autonomous function calling agent)
|
||||
async def user_manual_agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
User Manual Agent node that autonomously uses user manual tools and generates final answer.
|
||||
Implements "detect-first-then-stream" strategy for optimal multi-round behavior:
|
||||
1. Always start with non-streaming detection to check for tool needs
|
||||
2. If tool_calls exist → return immediately for routing to tools
|
||||
3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
|
||||
"""
|
||||
app_config = get_cached_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get stream callback from context variable
|
||||
from .graph import stream_callback_context
|
||||
stream_callback = stream_callback_context.get()
|
||||
|
||||
# Get user manual tool schemas and bind tools for planning phase
|
||||
tool_schemas = get_user_manual_tool_schemas()
|
||||
llm_client.bind_tools(tool_schemas, force_tool_choice=True)
|
||||
|
||||
# Create conversation trimmer for managing context length
|
||||
trimmer = create_conversation_trimmer()
|
||||
|
||||
# Prepare messages with user manual system prompt
|
||||
messages = state["messages"].copy()
|
||||
if not messages or not isinstance(messages[0], SystemMessage):
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
user_manual_prompt = rag_prompts.get("user_manual_prompt", "")
|
||||
if not user_manual_prompt:
|
||||
raise ValueError("user_manual_prompt is null")
|
||||
|
||||
# For user manual agent, we need to format the prompt with placeholders
|
||||
# Extract current query and conversation history
|
||||
current_query = ""
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
current_query = message.content
|
||||
break
|
||||
|
||||
conversation_history = ""
|
||||
if len(messages) > 1:
|
||||
conversation_history = render_conversation_history(messages[:-1]) # Exclude current query
|
||||
|
||||
# Format system prompt (initially with empty context, tools will provide it)
|
||||
formatted_system_prompt = user_manual_prompt.format(
|
||||
conversation_history=conversation_history,
|
||||
context_content="", # Will be filled by tools
|
||||
current_query=current_query
|
||||
)
|
||||
|
||||
messages = [SystemMessage(content=formatted_system_prompt)] + messages
|
||||
|
||||
# Track tool rounds
|
||||
current_round = state.get("tool_rounds", 0)
|
||||
# Get max_tool_rounds_user_manual from state, fallback to config if not set
|
||||
max_rounds = state.get("max_tool_rounds_user_manual", None)
|
||||
if max_rounds is None:
|
||||
max_rounds = app_config.app.max_tool_rounds_user_manual
|
||||
|
||||
# Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
|
||||
# This prevents trimming current turn's tool results during multi-round tool calling
|
||||
if current_round == 0:
|
||||
# Trim conversation history to manage context length (only for previous conversation turns)
|
||||
if trimmer.should_trim(messages):
|
||||
messages = trimmer.trim_conversation_history(messages)
|
||||
logger.info("Applied conversation history trimming for context management (new conversation turn)")
|
||||
else:
|
||||
logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")
|
||||
|
||||
logger.info(f"User Manual Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")
|
||||
|
||||
# Check if this should be final synthesis (max rounds reached)
|
||||
has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
|
||||
is_final_synthesis = has_tool_messages and current_round >= max_rounds
|
||||
|
||||
if is_final_synthesis:
|
||||
logger.info("Starting final synthesis phase - no more tool calls allowed")
|
||||
# ✅ STEP 1: Final synthesis with tools disabled from the start
|
||||
# Disable tools to prevent any tool calling during synthesis
|
||||
try:
|
||||
original_tools = llm_client.bind_tools([], force_tool_choice=False) # Disable tools
|
||||
|
||||
if not stream_callback:
|
||||
# No streaming callback, generate final response without tools
|
||||
draft = await llm_client.ainvoke(list(messages))
|
||||
return {"messages": [draft]}
|
||||
|
||||
# ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
|
||||
response_content = ""
|
||||
accumulated_content = ""
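
            # The loop below filters complete <!-- ... --> HTML comments out of the stream:
            # text before a finished comment is forwarded, the comment itself is dropped, and
            # anything after an unterminated "<!--" is held back until the comment closes.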
            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in middle of comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 3: Restore tool binding for next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    else:
        logger.info(f"User Manual tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If draft has tool_calls, execute them within this node
        if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
            logger.info(f"Detected {len(draft.tool_calls)} tool calls, executing within user manual agent")

            # Create a new state with the tool call message added
            tool_call_state = state.copy()
            updated_messages = state["messages"].copy()
            updated_messages.append(draft)
            tool_call_state["messages"] = updated_messages

            # Execute the tools using the existing streaming tool execution function
            tool_results = await run_user_manual_tools_with_streaming(tool_call_state)
            tool_messages = tool_results.get("messages", [])

            # Increment tool round counter for next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented user manual tool_rounds to {new_tool_rounds}")

            # Continue with another round if under max rounds
            if new_tool_rounds < max_rounds:
                # Recursive call for next round with all messages
                final_messages = updated_messages + tool_messages
                recursive_state = state.copy()
                recursive_state["messages"] = final_messages
                recursive_state["tool_rounds"] = new_tool_rounds
                return await user_manual_agent_node(recursive_state)
            else:
                # Max rounds reached, force final synthesis
                logger.info("Max tool rounds reached, forcing final synthesis")
                # Update messages for final synthesis
                messages = updated_messages + tool_messages
                # Continue to final synthesis below

        # ✅ STEP 3: No tool_calls needed or max rounds reached → Enter final synthesis with streaming
        # Temporarily disable tools to prevent accidental tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback, use the draft we already have
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in middle of comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 5: Restore tool binding for next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)


def render_conversation_history(messages, max_messages: int = 10) -> str:
    """Render conversation history for context"""
    recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
    lines = []
    for msg in recent_messages:
        if hasattr(msg, 'content'):
            content = msg.content
            if isinstance(content, str):
                # Determine message type by class name or other attributes
                if 'Human' in str(type(msg)):
                    lines.append(f"<user>{content}</user>")
                elif 'AI' in str(type(msg)):
                    lines.append(f"<ai>{content}</ai>")
            elif isinstance(content, list):
                content_str = " ".join([str(item) for item in content if isinstance(item, str)])
                if 'Human' in str(type(msg)):
                    lines.append(f"<user>{content_str}</user>")
                elif 'AI' in str(type(msg)):
                    lines.append(f"<ai>{content_str}</ai>")
    return "\n".join(lines)
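
# Example output of render_conversation_history (illustrative message contents only):
#   <user>How do I export a report?</user>
#   <ai>Open the report page and click "Export".</ai>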


# User Manual Tools routing condition
def user_manual_should_continue(state: AgentState) -> Literal["user_manual_tools", "user_manual_agent", "post_process"]:
    """
    Routing logic for user manual agent:
    - has tool_calls → route to user_manual_tools
    - no tool_calls → route to post_process (final synthesis already completed)
    """
    messages = state["messages"]
    if not messages:
        logger.info("user_manual_should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds_user_manual from state, fallback to config if not set
    max_rounds = state.get("max_tool_rounds_user_manual", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds_user_manual

    logger.info(f"user_manual_should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If last message is AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
        logger.info(f"user_manual_should_continue: AI message has tool_calls: {has_tool_calls}")

        if has_tool_calls:
            logger.info("user_manual_should_continue: Routing to user_manual_tools")
            return "user_manual_tools"
        else:
            # No tool calls = final synthesis already completed in user_manual_agent_node
            logger.info("user_manual_should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If last message is tool message(s), continue with agent for next round or final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("user_manual_should_continue: Tool message completed, continuing to user_manual_agent")
        return "user_manual_agent"

    logger.info("user_manual_should_continue: Routing to post_process")
    return "post_process"
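
# For example: an AIMessage carrying tool_calls routes to "user_manual_tools", a ToolMessage
# routes back to "user_manual_agent", and a plain AIMessage (no tool_calls) routes to "post_process".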


# User Manual Tools node with streaming support
async def run_user_manual_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute user manual tools with streaming events - supports parallel execution"""
    messages = state["messages"]
    last_message = messages[-1]

    # Get stream callback from context variable
    from .graph import stream_callback_context
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # User manual tools mapping
    tools_map = get_user_manual_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single user manual tool call with enhanced error handling"""
        # Get stream callback from context
        from .graph import stream_callback_context
        stream_callback = stream_callback_context.get()

        # Apply error handling decorator
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name:
                raise ValueError("Tool call missing 'name' field")

            if tool_name not in tools_map:
                available_tools = list(tools_map.keys())
                raise ValueError(f"Tool '{tool_name}' not found. Available user manual tools: {available_tools}")

            tool_func = tools_map[tool_name]

            # Stream tool start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            import time
            start_time = time.time()

            try:
                # Execute the user manual tool
                result = await tool_func.ainvoke(tool_args)

                # Calculate execution time
                took_ms = int((time.time() - start_time) * 1000)

                # Stream tool result event
                if stream_callback:
                    await stream_callback(create_tool_result_event(tool_id, tool_name, result, took_ms))

                # Create tool message
                tool_message = ToolMessage(
                    content=str(result),
                    tool_call_id=tool_id,
                    name=tool_name
                )

                return tool_message, {"name": tool_name, "result": result, "took_ms": took_ms}

            except Exception as e:
                took_ms = int((time.time() - start_time) * 1000)
                error_msg = get_user_message(ErrorCategory.TOOL)

                # Stream tool error event
                if stream_callback:
                    await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))

                # Create error tool message
                tool_message = ToolMessage(
                    content=f"Error executing {tool_name}: {error_msg}",
                    tool_call_id=tool_id,
                    name=tool_name
                )

                return tool_message, {"name": tool_name, "error": error_msg, "took_ms": took_ms}

        return await _execute()

    # Execute user manual tools (typically just one for user manual retrieval)
    import asyncio
    tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # Handle execution exception
            tool_call = tool_calls[i]
            tool_id = tool_call.get("id", f"error_{i}") or f"error_{i}"
            tool_name = tool_call.get("name", "unknown")
            error_msg = get_user_message(ErrorCategory.TOOL)

            if stream_callback:
                await stream_callback(create_tool_error_event(tool_id, tool_name, error_msg))

            error_message = ToolMessage(
                content=f"Error executing {tool_name}: {error_msg}",
                tool_call_id=tool_id,
                name=tool_name
            )
            new_messages.append(error_message)
        elif isinstance(result, tuple) and len(result) == 2:
            # result is a tuple: (tool_message, tool_result)
            tool_message, tool_result = result
            new_messages.append(tool_message)
            tool_results.append(tool_result)
        else:
            # Unexpected result format
            logger.error(f"Unexpected tool execution result format: {type(result)}")
            continue

    return {"messages": new_messages, "tool_results": tool_results}


# Legacy function for backward compatibility
async def user_manual_rag_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Legacy user manual RAG node - redirects to new agent-based implementation
    """
    logger.info("📚 USER_MANUAL_RAG_NODE: Redirecting to user_manual_agent_node")
    return await user_manual_agent_node(state, config)
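

# --- Illustrative sketch only (not part of the original commit) ---------------------
# A minimal example of how the nodes and routing condition above could be wired into a
# LangGraph subgraph. The real wiring lives in graph.py; mapping "post_process" to END
# here is an assumption made purely for illustration.
def _example_user_manual_subgraph():
    from langgraph.graph import StateGraph, END

    workflow = StateGraph(AgentState)
    workflow.add_node("user_manual_agent", user_manual_agent_node)
    workflow.add_node("user_manual_tools", run_user_manual_tools_with_streaming)
    workflow.set_entry_point("user_manual_agent")
    workflow.add_conditional_edges(
        "user_manual_agent",
        user_manual_should_continue,
        {
            "user_manual_tools": "user_manual_tools",
            "user_manual_agent": "user_manual_agent",
            "post_process": END,  # assumption: post-processing is handled by the parent graph
        },
    )
    workflow.add_edge("user_manual_tools", "user_manual_agent")
    return workflow.compile()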
77
vw-agentic-rag/service/graph/user_manual_tools.py
Normal file
@@ -0,0 +1,77 @@
"""
User manual specific tools for the Agentic RAG system.
This module contains tools specifically for user manual retrieval and processing.
"""

import logging
from typing import Dict, Any, List
from langchain_core.tools import tool

from ..retrieval.retrieval import AgenticRetrieval

logger = logging.getLogger(__name__)


# User Manual Tools
@tool
async def retrieve_system_usermanual(query: str) -> Dict[str, Any]:
    """Search for document content chunks of the user manual of this system (CATOnline)"""
    async with AgenticRetrieval() as retrieval:
        try:
            result = await retrieval.retrieve_doc_chunk_user_manual(
                query=query
            )
            return {
                "tool_name": "retrieve_system_usermanual",
                "results_count": len(result.results),
                "results": result.results,  # Already dict objects, no need for model_dump()
                "took_ms": result.took_ms
            }
        except Exception as e:
            logger.error(f"User manual retrieval error: {e}")
            return {"error": str(e), "results_count": 0, "results": []}


# User manual tools list
user_manual_tools = [retrieve_system_usermanual]


def get_user_manual_tool_schemas() -> List[Dict[str, Any]]:
    """
    Generate tool schemas for user manual tools.

    Returns:
        List of tool schemas in OpenAI function calling format
    """
    tool_schemas = []
    for tool in user_manual_tools:
        schema = {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query for retrieving relevant information"
                        }
                    },
                    "required": ["query"]
                }
            }
        }
        tool_schemas.append(schema)

    return tool_schemas


def get_user_manual_tools_by_name() -> Dict[str, Any]:
    """
    Create a mapping of user manual tool names to tool functions.

    Returns:
        Dictionary mapping tool names to tool functions
    """
    return {tool.name: tool for tool in user_manual_tools}
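

# --- Illustrative sketch only (not part of the original commit) ---------------------
# How the helpers above are typically consumed together: the schemas are bound to the
# LLM client for function calling, and the name->tool map is used to dispatch a call.
# Here the tool is invoked directly for demonstration, bypassing the LLM entirely.
async def _example_direct_dispatch(query: str) -> Dict[str, Any]:
    tools_map = get_user_manual_tools_by_name()
    # LangChain tools are invoked with a dict of their arguments
    return await tools_map["retrieve_system_usermanual"].ainvoke({"query": query})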