137 lines
5.7 KiB
Python
137 lines
5.7 KiB
Python
"""
|
|
Intent recognition functionality for the Agentic RAG system.
|
|
This module contains the intent classification logic.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Optional, Literal
|
|
from langchain_core.messages import SystemMessage
|
|
from langchain_core.runnables import RunnableConfig
|
|
from pydantic import BaseModel
|
|
|
|
from .state import AgentState
|
|
from ..llm_client import LLMClient
|
|
from ..config import get_config
|
|
from ..utils.error_handler import StructuredLogger
|
|
|
|
# Module-level logger built from the project's StructuredLogger; the functions
# in this module log through its info / warning / error methods.
logger = StructuredLogger(__name__)
|
|
|
|
|
|
# Intent Recognition Models
class Intent(BaseModel):
    """Intent classification result used to route a user query.

    A pydantic model holding the routing label and an optional confidence.
    NOTE(review): ``intent_recognition_node`` in this module returns the label
    as a plain string in a dict rather than constructing this model.
    """

    # Routing label; the Literal restricts it to the two supported RAG branches.
    label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]

    # Optional classifier confidence, defaulting to None. No range is enforced
    # here — presumably a 0..1 score; confirm with whoever populates it.
    confidence: Optional[float] = None
|
|
|
|
|
|
def get_last_user_message(messages) -> str:
    """Return the text of the most recent user (human) message.

    Messages are scanned newest-first. A message counts as a user message when
    its ``type`` attribute is ``"human"`` (as on LangChain's ``HumanMessage``)
    or its class name contains ``"Human"``. AI, system and tool messages are
    skipped, so a trailing assistant reply is never mistaken for the user's
    query.

    Args:
        messages: Conversation history, oldest first. Entries without a
            ``content`` attribute are ignored.

    Returns:
        String content is returned unchanged. List content is flattened by
        joining its plain-string parts and the ``"text"`` of dict content
        blocks with single spaces; other parts (images, tool calls, ...) are
        dropped. Returns ``""`` when no user message with string or list
        content exists.
    """
    for message in reversed(messages):
        is_user = (
            getattr(message, "type", None) == "human"
            or "Human" in type(message).__name__
        )
        if not is_user or not hasattr(message, "content"):
            continue

        content = message.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Keep plain strings and text blocks such as
            # {"type": "text", "text": "..."}; drop non-text parts.
            return " ".join(
                part if isinstance(part, str) else part["text"]
                for part in content
                if isinstance(part, str)
                or (isinstance(part, dict) and isinstance(part.get("text"), str))
            )
        # Any other content type (e.g. None): keep looking at older messages.
    return ""
|
|
|
|
|
|
def render_conversation_history(messages, max_messages: int = 10) -> str:
    """Render the tail of a conversation as tagged lines for an LLM prompt.

    Only the last ``max_messages`` entries are considered. User messages are
    rendered as ``<user>...</user>`` and AI messages as ``<ai>...</ai>``, one
    per line. Any other message kind (system, tool, ...) and entries without
    string or list ``content`` are omitted.

    A message is a user message when its ``type`` is ``"human"`` or its class
    name contains ``"Human"``. It is an AI message when its ``type`` is ``"ai"``
    or its class name contains ``"AI"``.

    Args:
        messages: Conversation history, oldest first.
        max_messages: Number of trailing messages to consider. Values <= 0
            render nothing.

    Returns:
        The rendered lines joined by newlines, or ``""`` if nothing renders.
    """
    if max_messages <= 0:
        # messages[-0:] is the *whole* list, so this must be guarded explicitly.
        return ""

    lines = []
    for msg in messages[-max_messages:]:
        content = getattr(msg, "content", None)
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            # Keep plain strings and text blocks such as
            # {"type": "text", "text": "..."}; drop non-text parts.
            text = " ".join(
                part if isinstance(part, str) else part["text"]
                for part in content
                if isinstance(part, str)
                or (isinstance(part, dict) and isinstance(part.get("text"), str))
            )
        else:
            continue

        msg_type = getattr(msg, "type", None)
        class_name = type(msg).__name__
        if msg_type == "human" or "Human" in class_name:
            lines.append(f"<user>{text}</user>")
        elif msg_type == "ai" or "AI" in class_name:
            lines.append(f"<ai>{text}</ai>")
    return "\n".join(lines)
|
|
|
|
|
|
# The two routing labels the classifier may emit, and the fallback label.
_INTENT_LABELS = ("Standard_Regulation_RAG", "User_Manual_RAG")
_DEFAULT_INTENT = "Standard_Regulation_RAG"


def _llm_response_text(result: Any) -> str:
    """Flatten an LLM result's ``content`` (string or list of parts) to stripped text."""
    content = getattr(result, "content", None)
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Some providers return content blocks such as
        # {"type": "text", "text": "..."} instead of plain strings.
        return " ".join(
            part if isinstance(part, str) else part["text"]
            for part in content
            if isinstance(part, str)
            or (isinstance(part, dict) and isinstance(part.get("text"), str))
        ).strip()
    return ""


def _parse_intent_label(response_text: str) -> Optional[str]:
    """Return the intent label named in ``response_text``, or None if none is found.

    An exact answer (optionally wrapped in quotes/backticks or ending in a
    period) wins outright. Otherwise the label mentioned first in the text is
    used, so a reply naming both labels is no longer biased toward whichever
    label happened to be checked first.
    """
    cleaned = response_text.strip().strip("'\"`.").strip()
    if cleaned in _INTENT_LABELS:
        return cleaned
    # Neither label is a substring of the other, so positions are distinct.
    found = [
        (response_text.find(label), label)
        for label in _INTENT_LABELS
        if label in response_text
    ]
    return min(found)[1] if found else None


async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Classify the latest user query into a RAG domain using the LLM.

    The system prompt is built from the configured ``intent_recognition_prompt``
    template, formatted with ``current_query`` and ``conversation_context``,
    plus an instruction to answer with a bare label. The LLM reply is then
    parsed into one of the two labels.

    Args:
        state: Graph state; only ``state["messages"]`` is read.
        config: Accepted for the node signature; not used in this function.

    Returns:
        ``{"intent": label}``, where ``label`` is ``"Standard_Regulation_RAG"``
        or ``"User_Manual_RAG"``. It falls back to ``"Standard_Regulation_RAG"``
        when the prompt is not configured, the reply names no label, or any
        exception is raised.
    """
    try:
        logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")

        app_config = get_config()
        llm_client = LLMClient()

        # Current user query plus recent history for context.
        current_query = get_last_user_message(state["messages"])
        conversation_context = render_conversation_history(state["messages"])

        rag_prompts = app_config.get_rag_prompts()
        intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
        if not intent_prompt_template:
            logger.error("Intent recognition prompt not found in configuration")
            return {"intent": _DEFAULT_INTENT}

        # str.format: a template with stray literal braces raises here and is
        # handled by the fallback below.
        system_prompt = intent_prompt_template.format(
            current_query=current_query,
            conversation_context=conversation_context,
        ) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."

        intent_result = await llm_client.llm.ainvoke([
            SystemMessage(content=system_prompt)
        ])
        response_text = _llm_response_text(intent_result)

        intent_label = _parse_intent_label(response_text)
        if intent_label is None:
            logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
            intent_label = _DEFAULT_INTENT

        logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
        return {"intent": intent_label}

    except Exception as e:
        # Classification is best-effort: never fail the graph, route to the default.
        logger.error(f"Intent recognition error: {e}")
        logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
        return {"intent": _DEFAULT_INTENT}
|
|
|
|
|
|
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
    """Choose the next graph branch from ``state["intent"]``.

    Returns the stored intent when it is one of the two supported labels.
    Otherwise (missing, ``None``, or an unrecognised value) it logs a warning
    and falls back to ``"Standard_Regulation_RAG"``, so the return value always
    matches the declared ``Literal`` type.
    """
    intent = state.get("intent")

    if intent is None:
        logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
        return "Standard_Regulation_RAG"

    if intent not in ("Standard_Regulation_RAG", "User_Manual_RAG"):
        # An unrecognised value has no matching branch; degrade to the default
        # instead of routing to an unknown destination.
        logger.warning(f"🎯 INTENT_ROUTER: Unknown intent '{intent}', defaulting to Standard_Regulation_RAG")
        return "Standard_Regulation_RAG"

    logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
    return intent
|