Files
catonline_ai/vw-agentic-rag/service/graph/intent_recognition.py
2025-09-26 17:15:54 +08:00

137 lines
5.7 KiB
Python

"""
Intent recognition functionality for the Agentic RAG system.
This module contains the intent classification logic.
"""
import logging
from typing import Dict, Any, Optional, Literal
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
from .state import AgentState
from ..llm_client import LLMClient
from ..config import get_config
from ..utils.error_handler import StructuredLogger
# Module-level structured logger used by the intent node and router in this module.
logger = StructuredLogger(__name__)
# Intent Recognition Models
class Intent(BaseModel):
    """Intent classification model for routing user queries.

    NOTE(review): intent_recognition_node in this module parses the label from
    free-form LLM text and does not construct this model; confirm whether it is
    used elsewhere (e.g. structured output) before changing its schema.
    """
    # Routing target; pydantic validation rejects any string outside these two labels.
    label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
    # Optional classifier confidence. No range is enforced here; presumably
    # 0.0-1.0 — TODO confirm with whoever populates it.
    confidence: Optional[float] = None
def get_last_user_message(messages) -> str:
    """Extract the text of the most recent user message from conversation history.

    Messages are scanned newest to oldest. A message counts as a user message
    when its class name contains ``"Human"`` (e.g. ``HumanMessage``,
    ``HumanMessageChunk``). Only messages with a ``content`` attribute holding
    a ``str`` or a ``list`` are considered.

    For list content, plain string items and ``{"type": "text", "text": ...}``
    content blocks are joined with single spaces; other parts (such as image
    blocks) are ignored.

    If no user message is present, the text of the most recent message with
    usable content is returned instead. This keeps callers that pass untyped
    message objects working as before.

    Args:
        messages: Conversation history, oldest first (must support ``reversed``).

    Returns:
        The extracted text, or ``""`` when no message has usable content.
    """

    def _content_text(content):
        # None signals "content shape we cannot read", so the message is skipped.
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text")
                    if isinstance(text, str):
                        parts.append(text)
            return " ".join(parts)
        return None

    fallback = None
    for message in reversed(messages):
        if not hasattr(message, "content"):
            continue
        text = _content_text(message.content)
        if text is None:
            continue
        if "Human" in type(message).__name__:
            return text
        # Remember the newest non-user message in case no user message exists.
        if fallback is None:
            fallback = text
    return fallback if fallback is not None else ""
def render_conversation_history(messages, max_messages: int = 10) -> str:
    """Render the tail of the conversation as tagged lines for prompt context.

    Only the last ``max_messages`` messages are considered. A message whose
    class name contains ``"Human"`` is rendered as ``<user>...</user>``. A
    message whose class name contains ``"AI"`` is rendered as ``<ai>...</ai>``.
    All other messages (e.g. system or tool messages) are skipped, as are
    messages without ``str``/``list`` content.

    List content contributes plain string items and
    ``{"type": "text", "text": ...}`` blocks, joined with single spaces.

    Args:
        messages: Conversation history, oldest first. It must support ``len``,
            and slicing when it is longer than ``max_messages``.
        max_messages: Number of trailing messages to include. ``0`` or a
            negative value yields an empty string.

    Returns:
        The rendered lines joined by newlines, or ``""`` if nothing qualifies.
    """
    # Guard explicitly: messages[-0:] would otherwise return the whole history.
    if max_messages <= 0:
        return ""

    def _content_text(content):
        # None signals "content shape we cannot read", so the message is skipped.
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text")
                    if isinstance(text, str):
                        parts.append(text)
            return " ".join(parts)
        return None

    # Only slice when needed, so short non-sliceable sequences still work.
    recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages

    lines = []
    for msg in recent_messages:
        if not hasattr(msg, "content"):
            continue
        text = _content_text(msg.content)
        if text is None:
            continue
        # Match on the class name only; the full type repr also contains the
        # module path, which could spuriously contain "Human"/"AI".
        class_name = type(msg).__name__
        if "Human" in class_name:
            lines.append(f"<user>{text}</user>")
        elif "AI" in class_name:
            lines.append(f"<ai>{text}</ai>")
    return "\n".join(lines)
def _intent_response_text(result: Any) -> str:
    """Return the stripped text of an LLM reply, or ``""`` if it has none.

    String content is used as-is. List content (multi-part replies) contributes
    plain string items and ``{"type": "text", "text": ...}`` blocks, joined with
    single spaces. Any other parts (e.g. tool-call or image blocks) are ignored.
    """
    content = getattr(result, "content", None)
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict) and item.get("type") == "text":
                text = item.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return " ".join(parts).strip()
    return ""


def _parse_intent_label(response_text: str) -> Optional[str]:
    """Return the intent label the reply mentions first, or None if it names neither.

    Choosing the earliest mention (rather than a fixed precedence) means an
    explanation such as "Standard_Regulation_RAG, not User_Manual_RAG"
    resolves to the label the model actually led with.
    """
    positions = {
        label: response_text.find(label)
        for label in ("Standard_Regulation_RAG", "User_Manual_RAG")
    }
    mentioned = {label: pos for label, pos in positions.items() if pos >= 0}
    if not mentioned:
        return None
    return min(mentioned, key=mentioned.get)


async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """
    Intent recognition node that uses LLM to classify user queries into specific domains.

    The most recent user query and a rendered view of recent history are
    substituted into the configured ``intent_recognition_prompt`` template. The
    LLM is then asked to answer with exactly one label.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation history.
        config: Accepted for node-signature compatibility; not used here.

    Returns:
        ``{"intent": label}`` where label is ``"Standard_Regulation_RAG"`` or
        ``"User_Manual_RAG"``. A missing prompt, an unparseable reply, or any
        ``Exception`` raised in the body yields ``"Standard_Regulation_RAG"``.
    """
    try:
        logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
        app_config = get_config()
        llm_client = LLMClient()

        # Get current user query and conversation history
        current_query = get_last_user_message(state["messages"])
        conversation_context = render_conversation_history(state["messages"])

        # Get intent classification prompt from configuration
        rag_prompts = app_config.get_rag_prompts()
        intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
        if not intent_prompt_template:
            logger.error("Intent recognition prompt not found in configuration")
            return {"intent": "Standard_Regulation_RAG"}

        # str.format inserts the query/context values verbatim, so braces typed by
        # the user are safe. Unescaped braces in the *template* itself may raise
        # (KeyError/ValueError) and fall through to the default below.
        system_prompt = intent_prompt_template.format(
            current_query=current_query,
            conversation_context=conversation_context
        ) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."

        # NOTE(review): the request carries only a SystemMessage; some chat
        # providers require at least one user message — confirm for the
        # configured model.
        intent_result = await llm_client.llm.ainvoke([
            SystemMessage(content=system_prompt)
        ])

        response_text = _intent_response_text(intent_result)
        intent_label = _parse_intent_label(response_text)
        if intent_label is None:
            # Default fallback
            logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
            intent_label = "Standard_Regulation_RAG"

        logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
        return {"intent": intent_label}
    except Exception as e:
        logger.error(f"Intent recognition error: {e}")
        # Default to Standard_Regulation_RAG if classification fails
        logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
        return {"intent": "Standard_Regulation_RAG"}
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
    """
    Route based on intent classification result.

    Returns ``state["intent"]`` when it is one of the two labels declared in the
    return annotation. A missing (``None``) or unrecognized intent is logged as
    a warning and routed to ``"Standard_Regulation_RAG"``. This guarantees the
    returned value is always a declared destination, instead of passing an
    arbitrary value to the graph's routing map.
    """
    intent = state.get("intent")
    if intent is None:
        logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
        return "Standard_Regulation_RAG"
    if intent not in ("Standard_Regulation_RAG", "User_Manual_RAG"):
        logger.warning(f"🎯 INTENT_ROUTER: Unknown intent '{intent}', defaulting to Standard_Regulation_RAG")
        return "Standard_Regulation_RAG"
    logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
    return intent