init
This commit is contained in:
136
vw-agentic-rag/service/graph/intent_recognition.py
Normal file
136
vw-agentic-rag/service/graph/intent_recognition.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Intent recognition functionality for the Agentic RAG system.
|
||||
This module contains the intent classification logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .state import AgentState
|
||||
from ..llm_client import LLMClient
|
||||
from ..config import get_config
|
||||
from ..utils.error_handler import StructuredLogger
|
||||
|
||||
logger = StructuredLogger(__name__)
|
||||
|
||||
|
||||
# Intent Recognition Models
|
||||
class Intent(BaseModel):
|
||||
"""Intent classification model for routing user queries"""
|
||||
label: Literal["Standard_Regulation_RAG", "User_Manual_RAG"]
|
||||
confidence: Optional[float] = None
|
||||
|
||||
|
||||
def get_last_user_message(messages) -> str:
|
||||
"""Extract the last user message from conversation history"""
|
||||
for message in reversed(messages):
|
||||
if hasattr(message, 'content'):
|
||||
content = message.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract string content from list
|
||||
return " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
return ""
|
||||
|
||||
|
||||
def render_conversation_history(messages, max_messages: int = 10) -> str:
|
||||
"""Render conversation history for context"""
|
||||
recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
||||
lines = []
|
||||
for msg in recent_messages:
|
||||
if hasattr(msg, 'content'):
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
# Determine message type by class name or other attributes
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content}</ai>")
|
||||
elif isinstance(content, list):
|
||||
content_str = " ".join([str(item) for item in content if isinstance(item, str)])
|
||||
if 'Human' in str(type(msg)):
|
||||
lines.append(f"<user>{content_str}</user>")
|
||||
elif 'AI' in str(type(msg)):
|
||||
lines.append(f"<ai>{content_str}</ai>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def intent_recognition_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Intent recognition node that uses LLM to classify user queries into specific domains
|
||||
"""
|
||||
try:
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Starting intent classification")
|
||||
|
||||
app_config = get_config()
|
||||
llm_client = LLMClient()
|
||||
|
||||
# Get current user query and conversation history
|
||||
current_query = get_last_user_message(state["messages"])
|
||||
conversation_context = render_conversation_history(state["messages"])
|
||||
|
||||
# Get intent classification prompt from configuration
|
||||
rag_prompts = app_config.get_rag_prompts()
|
||||
intent_prompt_template = rag_prompts.get("intent_recognition_prompt")
|
||||
|
||||
if not intent_prompt_template:
|
||||
logger.error("Intent recognition prompt not found in configuration")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
# Format the prompt with instruction to return only the label
|
||||
system_prompt = intent_prompt_template.format(
|
||||
current_query=current_query,
|
||||
conversation_context=conversation_context
|
||||
) + "\n\nIMPORTANT: You must respond with ONLY one of these two exact labels: 'Standard_Regulation_RAG' or 'User_Manual_RAG'. Do not include any other text or explanation."
|
||||
|
||||
# Classify intent using regular LLM call
|
||||
intent_result = await llm_client.llm.ainvoke([
|
||||
SystemMessage(content=system_prompt)
|
||||
])
|
||||
|
||||
# Parse the response to extract the intent label
|
||||
response_text = ""
|
||||
if hasattr(intent_result, 'content') and intent_result.content:
|
||||
if isinstance(intent_result.content, str):
|
||||
response_text = intent_result.content.strip()
|
||||
elif isinstance(intent_result.content, list):
|
||||
# Handle list content by joining string elements
|
||||
response_text = " ".join([str(item) for item in intent_result.content if isinstance(item, str)]).strip()
|
||||
|
||||
# Extract intent label from response
|
||||
if "User_Manual_RAG" in response_text:
|
||||
intent_label = "User_Manual_RAG"
|
||||
elif "Standard_Regulation_RAG" in response_text:
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
else:
|
||||
# Default fallback
|
||||
logger.warning(f"Could not parse intent from response: {response_text}, defaulting to Standard_Regulation_RAG")
|
||||
intent_label = "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_RECOGNITION_NODE: Classified intent as '{intent_label}'")
|
||||
|
||||
return {"intent": intent_label}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Intent recognition error: {e}")
|
||||
# Default to Standard_Regulation_RAG if classification fails
|
||||
logger.info("🎯 INTENT_RECOGNITION_NODE: Defaulting to Standard_Regulation_RAG due to error")
|
||||
return {"intent": "Standard_Regulation_RAG"}
|
||||
|
||||
|
||||
def intent_router(state: AgentState) -> Literal["Standard_Regulation_RAG", "User_Manual_RAG"]:
|
||||
"""
|
||||
Route based on intent classification result
|
||||
"""
|
||||
intent = state.get("intent")
|
||||
if intent is None:
|
||||
logger.warning("🎯 INTENT_ROUTER: No intent found, defaulting to Standard_Regulation_RAG")
|
||||
return "Standard_Regulation_RAG"
|
||||
|
||||
logger.info(f"🎯 INTENT_ROUTER: Routing to {intent}")
|
||||
return intent
|
||||
Reference in New Issue
Block a user