"""Pydantic schemas and LangGraph state definitions for the agent workflow."""
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph import MessagesState
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from typing_extensions import Annotated


class Message(BaseModel):
    """Base message class for conversation history"""

    # NOTE: the class docstring above doubles as the Pydantic JSON-schema
    # description, so it is intentionally kept verbatim.

    # Author of the message. Expected values are "user", "assistant" or
    # "tool", but any string passes validation.
    role: str
    # Message body text.
    content: str
    # Creation time, when the producer supplies one.
    timestamp: Optional[datetime] = None

    # Tool linkage; both stay None unless set by the producer
    # (presumably only for role == "tool" — confirm against callers).
    tool_call_id: Optional[str] = None
    tool_name: Optional[str] = None


class Citation(BaseModel):
    """Citation mapping between numbers and result IDs"""

    # Docstring kept verbatim: Pydantic exposes it as the schema description.

    # Citation number (e.g. the N in a "[N]" marker — assumed, confirm).
    number: int
    # Identifier of the cited result; presumably a ToolResult.id — verify.
    result_id: str
    # Optional link to the cited source.
    url: Optional[str] = None


class ToolResult(BaseModel):
    """Normalized tool result schema"""

    # Docstring kept verbatim: Pydantic exposes it as the schema description.

    # --- Core fields shared by every result -------------------------------
    id: str
    title: str
    url: Optional[str] = None
    score: Optional[float] = None
    # Free-form extra data; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Text body, populated for chunk-style results.
    content: Optional[str] = None

    # --- Standard/regulation specific fields (all optional) ---------------
    publisher: Optional[str] = None
    # Stored as a plain string; no date parsing or validation is applied.
    publish_date: Optional[str] = None
    document_code: Optional[str] = None
    document_category: Optional[str] = None


class TurnState(BaseModel):
    """State container for LangGraph workflow"""

    # Docstring kept verbatim: Pydantic exposes it as the schema description.

    session_id: str
    # Collections below use default_factory, so instances never share lists.
    messages: List[Message] = Field(default_factory=list)
    tool_results: List[ToolResult] = Field(default_factory=list)
    citations: List[Citation] = Field(default_factory=list)
    # Arbitrary per-turn metadata.
    meta: Dict[str, Any] = Field(default_factory=dict)

    # --- Tracking fields ---------------------------------------------------
    # Step counter and its limit. The bound is not validated here; presumably
    # the workflow compares current_step against max_steps — confirm.
    current_step: int = 0
    max_steps: int = 5
    # Final response text, None until one is produced.
    final_answer: Optional[str] = None


# TypedDict-based state for LangGraph (LangGraph's native state format).
# Its imports (TypedDict, MessagesState) live in the top-of-module import block.


def _merge_tool_results(
    existing: Optional[List[Dict[str, Any]]],
    new: Optional[List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Reducer for ``AgentState.tool_results``: append ``new`` to ``existing``.

    Either argument may be ``None`` and is then treated as an empty list.
    A new list is returned; neither input is mutated.
    """
    return (existing or []) + (new or [])


class AgentState(MessagesState):
    """LangGraph state with intent recognition support.

    Extends LangGraph's ``MessagesState`` (which supplies the ``messages``
    key) with routing, tool-result and loop-control fields.
    """

    # Session identifier.
    session_id: str
    # Result of intent recognition: one of the two RAG routes, or None.
    intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
    # Accumulating channel: LangGraph calls the reducer on each node update,
    # so results are appended (None-safe) rather than overwritten.
    tool_results: Annotated[List[Dict[str, Any]], _merge_tool_results]
    final_answer: str
    # Loop control; presumably tool_rounds counts rounds taken so far and the
    # max_* fields cap it (the *_user_manual variant for User_Manual_RAG)
    # — confirm against the graph nodes.
    tool_rounds: int
    max_tool_rounds: int
    max_tool_rounds_user_manual: int