import json
import logging
import re
import asyncio
from typing import Dict, Any, List, Callable, Annotated, Literal, TypedDict, Optional, Union, cast
from datetime import datetime
from urllib.parse import quote
from contextvars import ContextVar

from pydantic import BaseModel
from langgraph.graph import StateGraph, END, add_messages, MessagesState
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.runnables import RunnableConfig

from .state import TurnState, Message, ToolResult, AgentState
from .message_trimmer import create_conversation_trimmer
from .tools import get_tool_schemas, get_tools_by_name
from .user_manual_tools import get_user_manual_tools_by_name
from .intent_recognition import intent_recognition_node, intent_router
from .user_manual_rag import user_manual_rag_node
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.templates import render_prompt_template
from ..memory.postgresql_memory import get_checkpointer
from ..utils.error_handler import (
    StructuredLogger,
    ErrorCategory,
    ErrorCode,
    handle_async_errors,
    get_user_message
)
from ..sse import (
    create_tool_start_event,
    create_tool_result_event,
    create_tool_error_event,
    create_token_event,
    create_error_event
)

logger = StructuredLogger(__name__)

# Cache configuration at module level to avoid repeated get_config() calls
_cached_config = None


def get_cached_config():
    """Get the cached configuration, loading it if not already cached."""
    global _cached_config
    if _cached_config is None:
        _cached_config = get_config()
    return _cached_config


# Context variable for the streaming callback (thread-safe)
stream_callback_context: ContextVar[Optional[Callable]] = ContextVar('stream_callback', default=None)


# Agent node (autonomous function-calling agent)
async def call_model(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Agent node that autonomously uses tools and generates the final answer.

    Implements a "detect-first-then-stream" strategy for optimal multi-round behavior:
    1. Always start with a non-streaming detection call to check for tool needs.
    2. If tool_calls exist → return immediately for routing to tools.
    3. If no tool_calls → temporarily disable tools and perform a streaming final synthesis.
    """
    app_config = get_cached_config()
    llm_client = LLMClient()

    # Get the stream callback from the context variable
    stream_callback = stream_callback_context.get()

    # Get tool schemas and bind tools for the planning phase
    tool_schemas = get_tool_schemas()
    llm_client.bind_tools(tool_schemas, force_tool_choice=True)

    # Create a conversation trimmer for managing context length
    trimmer = create_conversation_trimmer()

    # Prepare messages with the system prompt
    messages = state["messages"].copy()
    if not messages or not isinstance(messages[0], SystemMessage):
        rag_prompts = app_config.get_rag_prompts()
        system_prompt = rag_prompts.get("agent_system_prompt", "")
        if not system_prompt:
            raise ValueError("agent_system_prompt is missing from the RAG prompts configuration")
        messages = [SystemMessage(content=system_prompt)] + messages

    # Track tool rounds
    current_round = state.get("tool_rounds", 0)

    # Get max_tool_rounds from state, falling back to config if not set
    max_rounds = state.get("max_tool_rounds", None)
    if max_rounds is None:
        max_rounds = app_config.app.max_tool_rounds

    # Only apply trimming at the start of a new conversation turn (tool_rounds == 0).
    # This prevents trimming the current turn's tool results during multi-round tool calling.
    if current_round == 0:
        # Trim conversation history to manage context length (previous turns only)
        if trimmer.should_trim(messages):
            messages = trimmer.trim_conversation_history(messages)
            logger.info("Applied conversation history trimming for context management (new conversation turn)")
    else:
        logger.info(f"Skipping trimming during tool round {current_round} to preserve current turn's context")

    logger.info(f"Agent node: tool_rounds={current_round}, max_tool_rounds={max_rounds}")

    # Check whether this must be the final synthesis (max rounds reached)
    has_tool_messages = any(isinstance(msg, ToolMessage) for msg in messages)
    is_final_synthesis = has_tool_messages and current_round >= max_rounds

    if is_final_synthesis:
        logger.info("Starting final synthesis phase - no more tool calls allowed")
    else:
        # Detect first: a non-streaming call with tools bound checks for tool needs
        detection = await llm_client.ainvoke(list(messages))
        if getattr(detection, "tool_calls", None):
            # Tool calls requested → return immediately; should_continue routes to the tools node
            return {"messages": [detection]}
        # No tool calls → fall through to the streaming final synthesis below

    # ✅ STEP 1: Final synthesis with tools disabled from the start
    # Disable tools to prevent any tool calling during synthesis
    try:
        original_tools = llm_client.bind_tools([], force_tool_choice=False)  # Disable tools

        if not stream_callback:
            # No streaming callback: generate the final response without tools
            draft = await llm_client.ainvoke(list(messages))
            return {"messages": [draft]}

        # ✅ STEP 2: Streaming final synthesis with HTML comment filtering
        response_content = ""
        accumulated_content = ""

        async for token in llm_client.astream(list(messages)):
            accumulated_content += token
            response_content += token

            # Strip complete HTML comments from the pending buffer
            while "<!--" in accumulated_content:
                comment_start = accumulated_content.find("<!--")
                comment_end = accumulated_content.find("-->", comment_start)
                if comment_start >= 0 and comment_end >= 0:
                    # Send the content before the comment
                    before_comment = accumulated_content[:comment_start]
                    if stream_callback and before_comment:
                        await stream_callback(create_token_event(before_comment))
                    # Skip the comment and continue with the content after it
                    accumulated_content = accumulated_content[comment_end + 3:]
                else:
                    break

            # Send the accumulated content if no comment is pending
            if "<!--" not in accumulated_content:
                if stream_callback and accumulated_content:
                    await stream_callback(create_token_event(accumulated_content))
                accumulated_content = ""

        # If the stream ended with an unterminated comment, still send the text before it
        if accumulated_content:
            remainder = accumulated_content.split("<!--", 1)[0]
            if stream_callback and remainder:
                await stream_callback(create_token_event(remainder))

        # Keep the full content (including any citations_map comment) for post-processing
        return {"messages": [AIMessage(content=response_content)]}
    finally:
        # Restore the tool bindings for subsequent rounds
        llm_client.bind_tools(tool_schemas, force_tool_choice=True)
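# The streaming filter above buffers tokens until any "<!--" opener it has seen
# is matched by "-->", so comments never reach the client even when they arrive
# split across token boundaries. A minimal standalone sketch of the same
# buffering technique (illustrative chunks, independent of this module):
#
#     chunks = ["Answer. <!--citations", "_map\n1,abc,0\n-->", " Done."]
#     buf, visible = "", ""
#     for chunk in chunks:
#         buf += chunk
#         while "<!--" in buf:
#             start = buf.find("<!--")
#             end = buf.find("-->", start)
#             if end < 0:
#                 break                      # comment still open: keep buffering
#             visible += buf[:start]         # emit the text before the comment
#             buf = buf[end + 3:]            # drop the comment itself
#         if "<!--" not in buf:
#             visible += buf                 # nothing pending: flush
#             buf = ""
#     # visible == "Answer.  Done."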
# Router: decide the next node after the agent or tools have run
def should_continue(state: AgentState) -> str:
    """Route to tools, back to the agent, or on to post-processing."""
    messages = state["messages"]
    last_message = messages[-1] if messages else None

    # Tool results just arrived → hand control back to the agent
    if isinstance(last_message, ToolMessage):
        return "agent"

    # The agent requested tools → execute them
    if isinstance(last_message, AIMessage) and getattr(last_message, "tool_calls", None):
        return "tools"

    # Otherwise the agent produced its final answer → build citations
    return "post_process"


# Tools node: execute the requested tool calls and stream progress events
async def run_tools_with_streaming(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute the tool calls from the last AI message, emitting SSE events."""
    stream_callback = stream_callback_context.get()
    tools_by_name = get_tools_by_name()

    last_message = state["messages"][-1]
    tool_messages: List[ToolMessage] = []
    new_tool_results: List[Dict[str, Any]] = []

    for tool_call in getattr(last_message, "tool_calls", None) or []:
        tool_name = tool_call["name"]
        tool_args = tool_call.get("args", {})
        tool_call_id = tool_call["id"]

        if stream_callback:
            await stream_callback(create_tool_start_event(tool_name, tool_args))

        try:
            tool = tools_by_name[tool_name]
            result = await tool.ainvoke(tool_args)

            # Annotate each result item so post-processing can resolve citations
            items = result if isinstance(result, list) else [result]
            for order_num, item in enumerate(items):
                if isinstance(item, dict):
                    item['@tool_call_id'] = tool_call_id
                    item['@order_num'] = order_num
                    new_tool_results.append(item)

            tool_messages.append(ToolMessage(
                content=json.dumps(result, ensure_ascii=False, default=str),
                tool_call_id=tool_call_id
            ))
            if stream_callback:
                await stream_callback(create_tool_result_event(tool_name, result))
        except Exception as e:
            logger.error(f"Tool {tool_name} failed: {e}")
            tool_messages.append(ToolMessage(content=f"Tool error: {e}", tool_call_id=tool_call_id))
            if stream_callback:
                await stream_callback(create_tool_error_event(tool_name, str(e)))

    return {
        "messages": tool_messages,
        "tool_results": state.get("tool_results", []) + new_tool_results,
        "tool_rounds": state.get("tool_rounds", 0) + 1,
    }


# Citation helpers (following the logic from build_citations.py)
def _extract_citations_mapping(agent_response: str) -> Dict[int, Dict[str, Any]]:
    """Extract the citations_map HTML comment (CSV of citation_num,tool_call_id,order_num)"""
    try:
        pattern = r'<!--\s*citations_map\s*(.*?)\s*-->'
        match = re.search(pattern, agent_response, re.DOTALL | re.IGNORECASE)
        if not match:
            logger.warning("No citations_map comment found in agent response")
            return {}

        csv_content = match.group(1).strip()
        citations_mapping = {}
        for line in csv_content.split('\n'):
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            if len(parts) >= 3:
                try:
                    citation_num = int(parts[0])
                    tool_call_id = parts[1].strip()
                    order_num = int(parts[2])
                    citations_mapping[citation_num] = {
                        'tool_call_id': tool_call_id,
                        'order_num': order_num
                    }
                except (ValueError, IndexError) as e:
                    logger.warning(f"Failed to parse citation line: {line}, error: {e}")
                    continue
        return citations_mapping
    except Exception as e:
        logger.error(f"Error extracting citations mapping: {e}")
        return {}


def _build_citation_markdown(citations_mapping: Dict[int, Dict[str, Any]],
                             tool_results: List[Dict[str, Any]]) -> str:
    """Build citation markdown from the mapping and tool results, following build_citations.py logic"""
    if not citations_mapping:
        return ""

    # Get configuration for the citation base URL
    config = get_cached_config()
    cat_base_url = config.citation.base_url

    # Collect citation lines first; only emit the header if we have at least one valid citation
    entries: List[str] = []
    for citation_num in sorted(citations_mapping.keys()):
        mapping = citations_mapping[citation_num]
        tool_call_id = mapping['tool_call_id']
        order_num = mapping['order_num']

        # Find the corresponding tool result
        result = _find_tool_result(tool_results, tool_call_id, order_num)
        if not result:
            logger.warning(f"No tool result found for citation [{citation_num}]")
            continue

        # Extract citation information following build_citations.py logic
        full_headers = result.get('full_headers', '')
        lowest_header = full_headers.split("||", 1)[0] if full_headers else ""
        header_display = f": {lowest_header}" if lowest_header else ""

        document_code = result.get('document_code', '')
        document_category = result.get('document_category', '')

        # Determine the standard/regulation title (assuming English language)
        standard_regulation_title = ''
        if document_category == 'Standard':
            standard_regulation_title = result.get('x_Standard_Title_EN', '') or result.get('x_Standard_Title_CN', '')
        elif document_category == 'Regulation':
            standard_regulation_title = result.get('x_Regulation_Title_EN', '') or result.get('x_Regulation_Title_CN', '')

        # Build the link
        func_uuid = result.get('func_uuid', '')
        uuid = result.get('x_Standard_Regulation_Id', '')
        document_code_encoded = quote(document_code, safe='') if document_code else ''
        standard_regulation_title_encoded = quote(standard_regulation_title, safe='') if standard_regulation_title else ''
        link_name = (f"{document_code_encoded}({standard_regulation_title_encoded})"
                     if (document_code_encoded or standard_regulation_title_encoded) else '')
        link = f'{cat_base_url}?funcUuid={func_uuid}&uuid={uuid}&name={link_name}'

        # Format the citation line
        title = result.get('title', '')
        entries.append(f"[{citation_num}] {title}{header_display} | [{standard_regulation_title} | {document_code}]({link})")

    # If no valid citations were found, do not include the header
    if not entries:
        return ""

    # Build the citations section with entries separated by a blank line (matching previous formatting)
    md = "\n\n### 📘 Citations:\n" + "\n\n".join(entries) + "\n\n"
    return md


def _find_tool_result(tool_results: List[Dict[str, Any]], tool_call_id: str,
                      order_num: int) -> Optional[Dict[str, Any]]:
    """Find a tool result by tool_call_id and order_num"""
    matching_results = []
    for result in tool_results:
        if result.get('@tool_call_id') == tool_call_id:
            matching_results.append(result)

    # Select the result at the requested position
    if matching_results and 0 <= order_num < len(matching_results):
        # If results carry @order_num, match on it; otherwise use the position in the list
        if '@order_num' in matching_results[0]:
            for result in matching_results:
                if result.get('@order_num') == order_num:
                    return result
        else:
            return matching_results[order_num]
    return None
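# Illustrative example of the contract between the helpers above (all values
# are made up). The agent is prompted to append a citations_map comment:
#
#     <!--citations_map
#     1,call_abc123,0
#     2,call_abc123,1
#     -->
#
# _extract_citations_mapping parses it into
#     {1: {'tool_call_id': 'call_abc123', 'order_num': 0},
#      2: {'tool_call_id': 'call_abc123', 'order_num': 1}}
# and _build_citation_markdown resolves each entry against the stored tool
# results (matched via their '@tool_call_id' / '@order_num' annotations) to
# emit lines such as
#     [1] Scope: General | [Road vehicles | ISO 26262-1](<base_url>?funcUuid=...&uuid=...&name=...)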
def _remove_citations_comment(agent_response: str) -> str:
    """Remove the citations_map HTML comment from the agent response"""
    pattern = r'<!--\s*citations_map.*?-->'
    return re.sub(pattern, '', agent_response, flags=re.DOTALL | re.IGNORECASE).strip()


# Post-processing node with citation list and link building
async def post_process_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Post-processing node that builds the citation list and links from the agent's
    citations mapping and tool call results, following the logic from build_citations.py
    """
    agent_response = ""
    try:
        logger.info("🔧 POST_PROCESS_NODE: Starting citation processing")

        # Get the stream callback from the context variable
        stream_callback = stream_callback_context.get()

        # Get the last AI message (the agent's response with the citations mapping)
        for message in reversed(state["messages"]):
            if isinstance(message, AIMessage) and message.content:
                # Ensure the content is a string
                if isinstance(message.content, str):
                    agent_response = message.content
                    break

        if not agent_response:
            logger.warning("POST_PROCESS_NODE: No agent response found")
            return {"messages": [], "final_answer": ""}

        # Extract the citations mapping from the agent response
        citations_mapping = _extract_citations_mapping(agent_response)
        logger.info(f"POST_PROCESS_NODE: Extracted {len(citations_mapping)} citations")

        # Build the citation markdown
        citation_markdown = _build_citation_markdown(citations_mapping, state["tool_results"])

        # Combine the agent response (without the HTML comment) with the citations
        clean_response = _remove_citations_comment(agent_response)
        final_content = clean_response + citation_markdown
        logger.info("POST_PROCESS_NODE: Built complete response with citations")

        # Send the citation markdown as a single block instead of streaming it
        if stream_callback and citation_markdown:
            logger.info("POST_PROCESS_NODE: Sending citation markdown as single block to client")
            await stream_callback(create_token_event(citation_markdown))

        # Create the AI message with the complete content
        final_ai_message = AIMessage(content=final_content)
        return {
            "messages": [final_ai_message],
            "final_answer": final_content
        }
    except Exception as e:
        logger.error(f"Post-processing error: {e}")
        error_message = "\n\n❌ **Error generating citations**\n\nPlease check the search results above."

        # Send the error message as a single block
        stream_callback = stream_callback_context.get()
        if stream_callback:
            await stream_callback(create_token_event(error_message))

        error_content = agent_response + error_message if agent_response else error_message
        error_ai_message = AIMessage(content=error_content)
        return {
            "messages": [error_ai_message],
            "final_answer": error_ai_message.content
        }
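# Graph topology wired below:
#
#   intent_recognition → agent | user_manual_rag      (intent_router)
#   agent → tools | agent | post_process              (should_continue)
#   tools → agent | post_process                      (should_continue)
#   user_manual_rag → END
#   post_process → END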
# Main workflow class
class AgenticWorkflow:
    """LangGraph-based autonomous agent workflow following v0.6.0+ best practices"""

    def __init__(self):
        # Build the StateGraph with the TypedDict state
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("intent_recognition", intent_recognition_node)
        workflow.add_node("agent", call_model)
        workflow.add_node("user_manual_rag", user_manual_rag_node)
        workflow.add_node("tools", run_tools_with_streaming)
        workflow.add_node("post_process", post_process_node)

        # Set the entry point to intent recognition
        workflow.set_entry_point("intent_recognition")

        # Intent recognition routes to either Standard_Regulation_RAG or User_Manual_RAG
        workflow.add_conditional_edges(
            "intent_recognition",
            intent_router,
            {
                "Standard_Regulation_RAG": "agent",
                "User_Manual_RAG": "user_manual_rag"
            }
        )

        # Standard RAG workflow (existing pattern)
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",  # Allow the agent to continue for multi-round
                "post_process": "post_process"
            }
        )

        # Tools route back through should_continue for the multi-round decision
        workflow.add_conditional_edges(
            "tools",
            should_continue,
            {
                "agent": "agent",              # Continue to the agent for the next round
                "post_process": "post_process"  # Or finish if max rounds reached
            }
        )

        # User Manual RAG goes directly to END (single turn)
        workflow.add_edge("user_manual_rag", END)

        # Post-process is terminal
        workflow.add_edge("post_process", END)

        # Compile the graph with a PostgreSQL checkpointer for session memory
        try:
            checkpointer = get_checkpointer()
            self.graph = workflow.compile(checkpointer=checkpointer)
            logger.info("Graph compiled with PostgreSQL checkpointer for session memory")
        except Exception as e:
            logger.warning(f"Failed to initialize PostgreSQL checkpointer, using memory-only graph: {e}")
            self.graph = workflow.compile()

    async def astream(self, state: TurnState, stream_callback: Callable | None = None):
        """Stream agent execution through LangGraph with PostgreSQL session memory"""
        try:
            # Get configuration
            config = get_cached_config()

            # Prepare the initial messages for the graph
            messages = []
            for msg in state.messages:
                if msg.role == "user":
                    messages.append(HumanMessage(content=msg.content))
                elif msg.role == "assistant":
                    messages.append(AIMessage(content=msg.content))

            # Create the initial agent state (without the stream callback, to avoid serialization issues)
            initial_state: AgentState = {
                "messages": messages,
                "session_id": state.session_id,
                "intent": None,  # Determined by the intent recognition node
                "tool_results": [],
                "final_answer": "",
                "tool_rounds": 0,
                "max_tool_rounds": config.app.max_tool_rounds,
                "max_tool_rounds_user_manual": config.app.max_tool_rounds_user_manual
            }

            # Set the stream callback in the context variable (thread-safe)
            stream_callback_context.set(stream_callback)

            # Create a proper RunnableConfig
            runnable_config = RunnableConfig(configurable={"thread_id": state.session_id})

            # Stream graph execution with session memory
            async for step in self.graph.astream(initial_state, config=runnable_config):
                if "post_process" in step:
                    final_state = step["post_process"]

                    # Extract the final answer and update the turn state
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the answer as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                elif "user_manual_rag" in step:
                    # Handle user manual RAG completion
                    final_state = step["user_manual_rag"]

                    # Extract the response from user manual RAG
                    state.final_answer = final_state.get("final_answer", "")

                    # Add the response as a regular assistant message
                    if state.final_answer:
                        state.messages.append(Message(
                            role="assistant",
                            content=state.final_answer,
                            timestamp=datetime.now()
                        ))

                    yield {"final": state}
                    break
                else:
                    # Pass through intermediate steps (intent_recognition, agent, tools)
                    yield step
        except Exception as e:
            logger.error(f"AgenticWorkflow error: {e}")
            state.final_answer = "I apologize, but I encountered an error while processing your request."
            yield {"final": state}
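# Session memory note: astream passes state.session_id as the LangGraph
# thread_id, so when the PostgreSQL checkpointer is active, repeated calls
# with the same session_id resume that session's checkpointed graph state.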
def build_graph() -> AgenticWorkflow:
    """Build and return the autonomous agent workflow"""
    return AgenticWorkflow()
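# Usage sketch (illustrative only, not part of the module API). It assumes
# TurnState and Message can be constructed as shown; adjust to the actual
# models in .state. The stream callback receives the SSE event payloads.
#
#     workflow = build_graph()
#
#     async def print_event(event):
#         print(event)
#
#     async def demo():
#         turn = TurnState(
#             session_id="demo-session",
#             messages=[Message(role="user",
#                               content="Which clauses of ISO 26262 cover ASIL decomposition?",
#                               timestamp=datetime.now())],
#         )
#         async for step in workflow.astream(turn, stream_callback=print_event):
#             if "final" in step:
#                 print(step["final"].final_answer)
#
#     asyncio.run(demo())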