init
This commit is contained in:
66
vw-agentic-rag/service/graph/state.py
Normal file
66
vw-agentic-rag/service/graph/state.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
from typing_extensions import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base message class for conversation history"""
|
||||
role: str # "user", "assistant", "tool"
|
||||
content: str
|
||||
timestamp: Optional[datetime] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
"""Citation mapping between numbers and result IDs"""
|
||||
number: int
|
||||
result_id: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Normalized tool result schema"""
|
||||
id: str
|
||||
title: str
|
||||
url: Optional[str] = None
|
||||
score: Optional[float] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
content: Optional[str] = None # For chunk results
|
||||
# Standard/regulation specific fields
|
||||
publisher: Optional[str] = None
|
||||
publish_date: Optional[str] = None
|
||||
document_code: Optional[str] = None
|
||||
document_category: Optional[str] = None
|
||||
|
||||
|
||||
class TurnState(BaseModel):
|
||||
"""State container for LangGraph workflow"""
|
||||
session_id: str
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
tool_results: List[ToolResult] = Field(default_factory=list)
|
||||
citations: List[Citation] = Field(default_factory=list)
|
||||
meta: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Additional fields for tracking
|
||||
current_step: int = 0
|
||||
max_steps: int = 5
|
||||
final_answer: Optional[str] = None
|
||||
|
||||
|
||||
# TypedDict for LangGraph AgentState (LangGraph native format)
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
class AgentState(MessagesState):
|
||||
"""LangGraph state with intent recognition support"""
|
||||
session_id: str
|
||||
intent: Optional[Literal["Standard_Regulation_RAG", "User_Manual_RAG"]]
|
||||
tool_results: Annotated[List[Dict[str, Any]], lambda x, y: (x or []) + (y or [])]
|
||||
final_answer: str
|
||||
tool_rounds: int
|
||||
max_tool_rounds: int
|
||||
max_tool_rounds_user_manual: int
|
||||
Reference in New Issue
Block a user