init

vw-agentic-rag/service/graph/graph.py (new file, +746)
@@ -0,0 +1,746 @@
import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar
from pydantic import BaseModel

from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig

from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
    StructuredLogger, ErrorCategory, ErrorCode,
    handle_async_errors, get_user_message
)
from ..sse import (
    create_tool_start_event,
    create_tool_result_event,
    create_tool_error_event,
    create_token_event,
    create_error_event
)

logger = StructuredLogger(__name__)

# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None

def get_cached_config():
    """Get cached configuration, loading it if not already cached"""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config

# Context variable for streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)
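# Note: the callback is set per request in AgenticWorkflow.astream() and read inside the
# graph nodes via stream_callback_context.get(), so callables never have to live in graph state.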


# Agent node (autonomous function calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Agent node that autonomously uses tools and generates the final answer.
    Implements the "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with non-streaming detection to check for tool needs
    2. If tool_calls exist → return immediately for routing to tools
    3. If no tool_calls → temporarily disable tools and perform streaming final synthesis
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    # Get tool schemas and bind tools for planning phase
    tool_schemas = get_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()

    # Prepare messages with system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        system_prompt = rag_prompts.get("agent_system_prompt", "")
        if not system_prompt:
            raise ValueError("agent_system_prompt is missing from the RAG prompt configuration")

        messages = [SystemMessage(content=system_prompt)] + messages

    # Track tool rounds
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds from state, fallback to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds

    # Only apply trimming at the start of a new conversation turn (when tool_rounds = 0)
    # This prevents trimming current turn's tool results during multi-round tool calling
    if current_round == 0:
        # Trim conversation history to manage context length (only for previous conversation turns)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Check if this should be final synthesis (max rounds reached)
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
        # ✅ STEP 1: Final synthesis with tools disabled from the start
        # Disable tools to prevent any tool calling during synthesis
        try:
            original_tools = llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback, generate final response without tools
                draft = await llm_client.ainvoke(list(messages))
                return {"messages": [draft]}

            # ✅ STEP 2: Streaming final synthesis with improved HTML comment filtering
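            # HTML comments (e.g. the citations_map block consumed later by post_process_node)
            # are filtered out of the streamed tokens; response_content keeps the full text so
            # the comment is still available for citation post-processing.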
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in middle of comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 3: Restore tool binding for next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    else:
        logger.info(f"Tool calling round {current_round + 1}/{max_rounds}")

        # ✅ STEP 1: Non-streaming detection to check for tool needs
        draft = await llm_client.ainvoke_with_tools(list(messages))

        # ✅ STEP 2: If draft has tool_calls, return immediately (let routing handle it)
        if isinstance(draft, AIMessage) and hasattr(draft, 'tool_calls') and draft.tool_calls:
            # Increment tool round counter for next iteration
            new_tool_rounds = current_round + 1
            logger.info(f"Incremented tool_rounds to {new_tool_rounds}")
            return {"messages": [draft], "tool_rounds": new_tool_rounds}

        # ✅ STEP 3: No tool_calls needed → Enter final synthesis with streaming
        # Temporarily disable tools to prevent accidental tool calling during synthesis
        try:
            llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

            if not stream_callback:
                # No streaming callback, use the draft we already have
                return {"messages": [draft]}

            # ✅ STEP 4: Streaming final synthesis with improved HTML comment filtering
            response_content = ""
            accumulated_content = ""

            async for token in llm_client.astream(list(messages)):
                accumulated_content += token
                response_content += token

                # Check for complete HTML comments in accumulated content
                while "<!--" in accumulated_content and "-->" in accumulated_content:
                    comment_start = accumulated_content.find("<!--")
                    comment_end = accumulated_content.find("-->", comment_start)

                    if comment_start >= 0 and comment_end >= 0:
                        # Send content before comment
                        before_comment = accumulated_content[:comment_start]
                        if stream_callback and before_comment:
                            await stream_callback(create_token_event(before_comment))

                        # Skip the comment and continue with content after
                        accumulated_content = accumulated_content[comment_end + 3:]
                    else:
                        break

                # Send accumulated content if no pending comment
                if "<!--" not in accumulated_content:
                    if stream_callback and accumulated_content:
                        await stream_callback(create_token_event(accumulated_content))
                    accumulated_content = ""

            # Send any remaining content (if not in middle of comment)
            if accumulated_content and "<!--" not in accumulated_content:
                if stream_callback:
                    await stream_callback(create_token_event(accumulated_content))

            return {"messages": [AIMessage(content=response_content)]}

        finally:
            # ✅ STEP 5: Restore tool binding for next interaction
            llm_client.bind_tools(tool_schemas, force_tool_choice=True)


# Tools routing condition (simplified for "detect-first-then-stream" strategy)
def should_continue(state: AgentState) -> Literal["tools", "agent", "post_process"]:
    """
    Simplified routing logic for the "detect-first-then-stream" strategy:
    - last AI message has tool_calls → route to tools
    - last AI message has no tool_calls → route to post_process (final synthesis already completed)
    - last message is a ToolMessage → route back to agent for the next round or final synthesis
    """
    messages = state["messages"]
    if not messages:
        logger.info("should_continue: No messages, routing to post_process")
        return "post_process"

    last_message = messages[-1]
    current_round = state.get("tool_rounds", 0)
    # Get max_tool_rounds from state, fallback to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        app_config = get_cached_config()
        max_rounds = app_config.app.max_tool_rounds

    logger.info(f"should_continue: Last message type: {type(last_message)}, tool_rounds: {current_round}/{max_rounds}")

    # If last message is AI message with tool calls, route to tools
    if isinstance(last_message, AIMessage):
        has_tool_calls = hasattr(last_message, 'tool_calls') and last_message.tool_calls
        logger.info(f"should_continue: AI message has tool_calls: {has_tool_calls}")

        if has_tool_calls:
            logger.info("should_continue: Routing to tools")
            return "tools"
        else:
            # No tool calls = final synthesis already completed in call_model
            logger.info("should_continue: No tool calls, routing to post_process")
            return "post_process"

    # If last message is tool message(s), continue with agent for next round or final synthesis
    if isinstance(last_message, ToolMessage):
        logger.info("should_continue: Tool message completed, continuing to agent")
        return "agent"

    logger.info("should_continue: Routing to post_process")
    return "post_process"


# Custom tool node with streaming support and parallel execution
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute tools with streaming events - supports parallel execution"""
    messages = state["messages"]
    last_message = messages[-1]

    # Get stream callback from context variable
    stream_callback = stream_callback_context.get()

    if not isinstance(last_message, AIMessage) or not hasattr(last_message, 'tool_calls'):
        return {"messages": []}

    tool_calls = last_message.tool_calls or []
    tool_results = []
    new_messages = []

    # Tools mapping
    tools_map = get_tools_by_name()

    async def execute_single_tool(tool_call):
        """Execute a single tool call with enhanced error handling"""
        # Get stream callback from context
        stream_callback = stream_callback_context.get()

        # Apply error handling decorator
        @handle_async_errors(
            ErrorCategory.TOOL,
            ErrorCode.TOOL_ERROR,
            stream_callback,
            tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
        )
        async def _execute():
            # Validate tool_call format
            if not isinstance(tool_call, dict):
                raise ValueError(f"Tool call must be dict, got {type(tool_call)}")

            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            tool_id = tool_call.get("id", "unknown")

            if not tool_name or tool_name not in tools_map:
                raise ValueError(f"Tool '{tool_name}' not found")

            logger.info(f"Executing tool: {tool_name}", extra={
                "tool_id": tool_id, "tool_name": tool_name
            })

            # Send start event
            if stream_callback:
                await stream_callback(create_tool_start_event(tool_id, tool_name, tool_args))

            # Execute tool
            import time
            start_time = time.time()
            result = await tools_map[tool_name].ainvoke(tool_args)
            execution_time = int((time.time() - start_time) * 1000)

            # Process result
            if isinstance(result, dict):
                result["tool_call_id"] = tool_id
                if "results" in result and isinstance(result["results"], list):
                    for i, search_result in enumerate(result["results"]):
                        if isinstance(search_result, dict):
                            search_result["@tool_call_id"] = tool_id
                            search_result["@order_num"] = i
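
            # The @tool_call_id / @order_num annotations added above are consumed later by
            # _find_tool_result() when post_process_node resolves citations_map entries.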

            # Send result event
            if stream_callback:
                await stream_callback(create_tool_result_event(
                    tool_id, tool_name, result.get("results", []), execution_time
                ))

            # Create tool message
            tool_message = ToolMessage(
                content=json.dumps(result, ensure_ascii=False),
                tool_call_id=tool_id,
                name=tool_name
            )

            return {
                "message": tool_message,
                "results": result.get("results", []) if isinstance(result, dict) else [],
                "success": True
            }

        try:
            return await _execute()
        except Exception as e:
            # Handle any errors not caught by decorator
            tool_id = tool_call.get("id", "unknown") if isinstance(tool_call, dict) else "unknown"
            tool_name = tool_call.get("name", "unknown") if isinstance(tool_call, dict) else "unknown"

            error_message = ToolMessage(
                content=f"Error: {get_user_message(ErrorCategory.TOOL)}",
                tool_call_id=tool_id,
                name=tool_name
            )

            return {
                "message": error_message,
                "results": [],
                "success": False
            }

    # Execute all tool calls in parallel using asyncio.gather
    if tool_calls:
        logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
        tool_execution_results = await asyncio.gather(
            *[execute_single_tool(tool_call) for tool_call in tool_calls],
            return_exceptions=True
        )

        # Process results
        for execution_result in tool_execution_results:
            if execution_result is None:
                continue
            if isinstance(execution_result, Exception):
                logger.error(f"Tool execution exception: {execution_result}")
                continue
            if not isinstance(execution_result, dict):
                logger.error(f"Unexpected execution result type: {type(execution_result)}")
                continue

            new_messages.append(execution_result["message"])
            if execution_result["success"] and execution_result["results"]:
                tool_results.extend(execution_result["results"])

    logger.info(f"Parallel tool execution completed. {len(new_messages)} tools executed, {len(tool_results)} results collected")

    return {
        "messages": new_messages,
        "tool_results": tool_results
    }


# Helper functions for citation processing
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
    """Extract citations mapping CSV from agent response HTML comment"""
    try:
        # Look for citations_map comment
        pattern = r'<!-- citations_map\s*(.*?)\s*-->'
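        # Expected comment shape (illustrative values; each row is "citation_num,tool_call_id,order_num"):
        #   <!-- citations_map
        #   1,call_abc,0
        #   2,call_abc,2
        #   -->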
        match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)

        if not match:
            logger.warning("No citations_map comment found in agent response")
            return {}

        csv_content = match.group(1).strip()
        citations_mapping = {}

        for line in csv_content.split('\n'):
            line = line.strip()
            if not line:
                continue

            parts = line.split(',')
            if len(parts) >= 3:
                try:
                    citation_num = int(parts[0])
                    tool_call_id = parts[1].strip()
                    order_num = int(parts[2])

                    citations_mapping[citation_num] = {
                        'tool_call_id': tool_call_id,
                        'order_num': order_num
                    }
                except (ValueError, IndexError) as e:
                    logger.warning(f"Failed to parse citation line: {line}, error: {e}")
                    continue

        return citations_mapping

    except Exception as e:
        logger.error(f"Error extracting citations mapping: {e}")
        return {}


def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]], tool_results: List[Dict[str, Any]]) -> str:
    """Build citation markdown based on mapping and tool results, following build_citations.py logic"""
    if not citations_mapping:
        return ""

    # Get configuration for citation base URL
    config = get_cached_config()
    cat_base_url = config.citation.base_url

    # Collect citation lines first; only emit header if we have at least one valid citation
    entries: List[str] = []

    for citation_num in sorted(citations_mapping.keys()):
        mapping = citations_mapping[citation_num]
        tool_call_id = mapping['tool_call_id']
        order_num = mapping['order_num']

        # Find the corresponding tool result
        result = _find_tool_result(tool_results, tool_call_id, order_num)
        if not result:
            logger.warning(f"No tool result found for citation [{citation_num}]")
            continue

        # Extract citation information following build_citations.py logic
        full_headers = result.get('full_headers', '')
        lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
        header_display = f": {lowest_header}" if lowest_header else ""

        document_code = result.get('document_code', '')
        document_category = result.get('document_category', '')

        # Determine standard/regulation title (assuming English language)
        standard_regulation_title = ''
        if document_category == 'Standard':
            standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
        elif document_category == 'Regulation':
            standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')

        # Build link
        func_uuid = result.get('func_uuid', '')
        uuid = result.get('x_Standard_Regulation_Id', '')
        document_code_encoded = quote(document_code, safe='') if document_code else ''
        standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
        link_name = f"{document_code_encoded}({standard_regulation_title_encoded})" if (document_code_encoded or standard_regulation_title_encoded) else ''
        link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'

        # Format citation line
        title = result.get('title', '')
        entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")

    # If no valid citations were found, do not include the header
    if not entries:
        return ""

    # Build citations section with entries separated by a blank line (matching previous formatting)
    md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
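    # Illustrative rendered entry (hypothetical values), derived from the f-string above:
    #   [1] <title>: <lowest header> | [<standard/regulation title> | <document code>](<base_url>?funcUuid=...&uuid=...&name=...)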
    return md


def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str, order_num: int) -> Optional[Dict[str, Any]]:
    """Find tool result by tool_call_id and order_num"""
    matching_results = []

    for result in tool_results:
        if result.get('@tool_call_id') == tool_call_id:
            matching_results.append(result)

    # Sort by order and return the one at the specified position
    if matching_results and 0 <= order_num < len(matching_results):
        # If results have @order_num, use it; otherwise use position in list
        if '@order_num' in matching_results[0]:
            for result in matching_results:
                if result.get('@order_num') == order_num:
                    return result
        else:
            return matching_results[order_num]

    return None


def _remove_citations_comment(agent_response: str) -> str:
    """Remove citations mapping HTML comment from agent response"""
    pattern = r'<!-- citations_map\s*.*?\s*-->'
    return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()


# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Post-processing node that builds the citation list and links based on the agent's citations mapping
    and tool call results, following the logic from build_citations.py
    """
    # Initialized before the try block so the except handler can safely reference
    # agent_response even if an error occurs before the agent's reply is located
    agent_response = ""
    citations_mapping = {}

    try:
        logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")

        # Get stream callback from context variable
        stream_callback = stream_callback_context.get()

        # Get the last AI message (agent's response with citations mapping)
        for message in reversed(state["messages"]):
            if isinstance(message, AIMessage) and message.content:
                # Ensure content is a string
                if isinstance(message.content, str):
                    agent_response = message.content
                    break

        if not agent_response:
            logger.warning("POST_PROCESS_NODE: No agent response found")
            return {"messages": [], "final_answer": ""}

        # Extract citations mapping from agent response
        citations_mapping = _extract_citations_mapping(agent_response)
        logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")

        # Build citation markdown
        citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])

        # Combine agent response (without HTML comment) with citations
        clean_response = _remove_citations_comment(agent_response)
        final_content = clean_response + citation_markdown

        logger.info("POST_PROCESS_NODE: Built complete response with citations")

        # Send citation markdown as a single block instead of streaming
        stream_callback = stream_callback_context.get()
        if stream_callback and citation_markdown:
            logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
            await stream_callback(create_token_event(citation_markdown))

        # Create AI message with complete content
        final_ai_message = AIMessage(content=final_content)

        return {
            "messages": [final_ai_message],
            "final_answer": final_content
        }

    except Exception as e:
        logger.error(f"Post-processing error: {e}")
        error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."

        # Send error message as single block
        stream_callback = stream_callback_context.get()
        if stream_callback:
            await stream_callback(create_token_event(error_message))

        error_content = agent_response + error_message if agent_response else error_message
        error_ai_message = AIMessage(content=error_content)
        return {
            "messages": [error_ai_message],
            "final_answer": error_ai_message.content
        }


# Main workflow class
class AgenticWorkflow:
    """LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""

    def __init__(self):
        # Build StateGraph with TypedDict state
        workflow = StateGraph(AgentState)

        # Add nodes following best practices
        workflow.add_node("intent_recognition", intent_recognition_node)
        workflow.add_node("agent", call_model)
        workflow.add_node("user_manual_rag", user_manual_rag_node)
        workflow.add_node("tools", run_tools_with_streaming)
        workflow.add_node("post_process", post_process_node)

        # Set entry point to intent recognition
        workflow.set_entry_point("intent_recognition")

        # Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
        workflow.add_conditional_edges(
            "intent_recognition",
            intent_router,
            {
                "Standard_Regulation_RAG": "agent",
                "User_Manual_RAG": "user_manual_rag"
            }
        )

        # Standard RAG workflow (existing pattern)
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",  # Allow agent to continue for multi-round
                "post_process": "post_process"
            }
        )

        # Tools route back to should_continue for multi-round decision
        workflow.add_conditional_edges(
            "tools",
            should_continue,
            {
                "agent": "agent",  # Continue to agent for next round
                "post_process": "post_process"  # Or finish if max rounds reached
            }
        )

        # User Manual RAG directly goes to END (single turn)
        workflow.add_edge("user_manual_rag", END)

        # Post-process is terminal
        workflow.add_edge("post_process", END)

        # Compile graph with PostgreSQL checkpointer for session memory
        try:
            checkpointer = get_checkpointer()
            self.graph = workflow.compile(checkpointer=checkpointer)
            logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
        except Exception as e:
            logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
            self.graph = workflow.compile()

    async def astream(self, state: TurnState, stream_callback: Callable | None = None):
        """Stream agent execution using LangGraph with PostgreSQL session memory"""
        try:
            # Get configuration
            config = get_cached_config()

            # Prepare initial messages for the graph
            messages = []
            for msg in state.messages:
                if msg.role == "user":
                    messages.append(HumanMessage(content=msg.content))
                elif msg.role == "assistant":
                    messages.append(AIMessage(content=msg.content))

            # Create initial agent state (without stream_callback to avoid serialization issues)
            initial_state: AgentState = {
                "messages": messages,
                "session_id": state.session_id,
                "intent": None,  # Will be determined by intent recognition node
                "tool_results": [],
                "final_answer": "",
                "tool_rounds": 0,
                "max_tool_rounds": config.app.max_tool_rounds,  # Use configuration value
                "max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual  # Use configuration value for user manual agent
            }

            # Set stream callback in context variable (thread-safe)
            stream_callback_context.set(stream_callback)

            # Create proper RunnableConfig
            runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})
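            # thread_id scopes the checkpointer to this session, so prior turns are restored
            # from PostgreSQL when the same session_id is used again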

            # Stream graph execution with session memory
            async for step in self.graph.astream(initial_state, config=runnable_config):
                if "post_process" in step:
                    final_state = step["post_process"]

                    # Extract the tool summary message and update state
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the summary as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                elif "user_manual_rag" in step:
                    # Handle user manual RAG completion
                    final_state = step["user_manual_rag"]

                    # Extract the response from user manual RAG
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the response as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                else:
                    # Process regular steps (intent_recognition, agent, tools)
                    yield step

        except Exception as e:
            logger.error(f"AgentWorkflow error: {e}")
            state.final_answer = "I apologize, but I encountered an error while processing your request."
            yield {"final": state}


def build_graph() -> AgenticWorkflow:
    """Build and return the autonomous agent workflow"""
    return AgenticWorkflow()
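

# Usage sketch (illustrative; assumes a populated TurnState and an async callback that
# forwards SSE events to the client):
#
#   workflow = build_graph()
#   async for step in workflow.astream(turn_state, stream_callback=send_event):
#       if "final" in step:
#           updated_state = step["final"]  # TurnState with final_answer populated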