# LangGraph Implementation Analysis and Improvements

## Official Example vs Current Implementation

### Key Differences Found

#### 1. **Graph Structure**
**Official Example:**
```python
from langgraph.graph import StateGraph, END

workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)   # LLM call
workflow.add_node("tools", run_tools)    # tool execution
workflow.set_entry_point("agent")
# Route to "tools" or finish, depending on what should_continue returns
workflow.add_conditional_edges("agent", should_continue, ["tools", END])
workflow.add_edge("tools", "agent")      # loop back to the agent after tools run
graph = workflow.compile()
```

**Current Implementation:**
```python
class AgentWorkflow:
    def __init__(self):
        self.agent_node = AgentNode()
        self.post_process_node = PostProcessNode()

    async def astream(self, state, stream_callback):
        # Nodes are awaited in a fixed sequence; there is no graph or conditional routing
        state = await self.agent_node(state, stream_callback)
        state = await self.post_process_node(state, stream_callback)
```
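The official example relies on a `should_continue` router that this document never shows. A minimal sketch of what such a router might look like follows. It assumes messages are plain dicts with an optional `tool_calls` list, and it defines a local `END` stand-in so it runs without LangGraph installed; in a real graph the messages would be LangChain message objects and `END` would be imported from `langgraph.graph`.

```python
from typing import Any

# Local stand-in for LangGraph's END sentinel so the sketch runs without the library.
END = "__end__"


def should_continue(state: dict[str, Any]) -> str:
    """Route to "tools" if the last message requests a tool call, otherwise finish."""
    messages = state.get("messages", [])
    if not messages:
        return END
    last = messages[-1]
    # Assumption: each message is a dict with an optional "tool_calls" list.
    tool_calls = last.get("tool_calls") or []
    return "tools" if tool_calls else END


# A turn where the model asked for a tool, and one where it answered directly.
pending = {
    "messages": [
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"name": "retrieve_standard_regulation", "args": {"query": "example query"}}
            ],
        }
    ]
}
finished = {"messages": [{"role": "assistant", "content": "Done."}]}
```

The current `AgentWorkflow.astream` has no equivalent decision point, which is why the agent and post-processing steps always run in the same order.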

#### 2. **State Management**
**Official Example:**
```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    messages: Annotated[list, add_messages]
```

**Current Implementation:**
```python
class TurnState(BaseModel):
    session_id: str
    messages: List[Message] = Field(default_factory=list)
    tool_results: List[ToolResult] = Field(default_factory=list)
    citations: List[Citation] = Field(default_factory=list)
    # ... many more fields
```
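The practical difference between the two state definitions is the reducer: with `Annotated[list, add_messages]`, a node returns only its new messages and the framework merges them into the existing list. The sketch below imitates that idea in plain Python so the behavior is easy to see. `append_messages` and `apply_update` are hypothetical helpers written for this illustration, not LangGraph functions, and the real `add_messages` handles message objects and more cases than shown here.

```python
from typing import Any

Message = dict[str, Any]


def append_messages(existing: list[Message], new: list[Message]) -> list[Message]:
    """Simplified reducer: append new messages, replacing any with a matching "id"."""
    merged = list(existing)
    index_by_id = {m["id"]: i for i, m in enumerate(merged) if "id" in m}
    for message in new:
        message_id = message.get("id")
        if message_id is not None and message_id in index_by_id:
            merged[index_by_id[message_id]] = message  # same id: replace in place
        else:
            if message_id is not None:
                index_by_id[message_id] = len(merged)
            merged.append(message)
    return merged


def apply_update(state: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    """Merge a node's partial update into the state, using the reducer for "messages"."""
    merged = dict(state)
    for key, value in update.items():
        if key == "messages":
            merged[key] = append_messages(state.get("messages", []), value)
        else:
            merged[key] = value  # other keys are simply overwritten
    return merged


state = {"session_id": "s1", "messages": [{"id": "1", "role": "user", "content": "hi"}]}
state = apply_update(state, {"messages": [{"id": "2", "role": "assistant", "content": "hello"}]})
state_after_edit = apply_update(
    state, {"messages": [{"id": "2", "role": "assistant", "content": "hello again"}]}
)
```

Nodes written this way return small partial updates instead of mutating a large `TurnState` object, which is one way to keep the main state limited to essential fields.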

#### 3. **Tool Handling**
**Official Example:**
```python
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode

@tool
def get_stock_price(stock_symbol: str):
    """Return the mock price for a stock symbol."""  # @tool needs a docstring or description
    return mock_stock_data[stock_symbol]

tools = [get_stock_price]
tool_node = ToolNode(tools)
```

**Current Implementation:**
```python
async def _execute_tool_call(self, tool_call, state, stream_callback):
    tool_name = tool_call["name"]
    tool_args = tool_call["args"]
    async with RetrievalTools() as retrieval:
        if tool_name == "retrieve_standard_regulation":
            result = await retrieval.retrieve_standard_regulation(**tool_args)
        # Manual tool execution logic
```
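To illustrate what is gained by moving away from the `if tool_name == ...` chain, the following sketch replaces the branching with a name-to-function lookup. It is a plain-Python analogy for what `@tool` plus `ToolNode` provide, not their implementation: `register_tool`, `TOOL_REGISTRY`, and `execute_tool_call` are hypothetical names, and the tool body is a placeholder rather than a real `RetrievalTools` call.

```python
import asyncio
from typing import Any, Awaitable, Callable

ToolFn = Callable[..., Awaitable[str]]
TOOL_REGISTRY: dict[str, ToolFn] = {}


def register_tool(fn: ToolFn) -> ToolFn:
    """Hypothetical decorator: register a coroutine under its function name."""
    TOOL_REGISTRY[fn.__name__] = fn
    return fn


@register_tool
async def retrieve_standard_regulation(query: str, conversation_history: str = "") -> str:
    # Placeholder body; the real tool would call RetrievalTools here.
    return f"results for {query!r}"


async def execute_tool_call(tool_call: dict[str, Any]) -> str:
    """Look the tool up by name instead of branching on each name by hand."""
    tool_name = tool_call["name"]
    tool_fn = TOOL_REGISTRY.get(tool_name)
    if tool_fn is None:
        return f"Unknown tool: {tool_name}"
    return await tool_fn(**tool_call.get("args", {}))


result = asyncio.run(
    execute_tool_call({"name": "retrieve_standard_regulation", "args": {"query": "brake lights"}})
)
missing = asyncio.run(execute_tool_call({"name": "no_such_tool", "args": {}}))
```

Adding a second tool then means writing one decorated function, with no change to the dispatch code.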

## Recommendations for Improvement

### 1. **Use Standard LangGraph Patterns**
- Adopt `StateGraph` with `add_node()` and `add_edge()`
- Use `@tool` decorators for cleaner tool definitions
- Leverage `ToolNode` for automatic tool execution

### 2. **Simplify State Management**
- Reduce state complexity where possible
- Use LangGraph's `add_messages` helper for message handling
- Keep only essential fields in the main state

### 3. **Improve Code Organization**
- Separate concerns: graph definition, tool definitions, state
- Use factory functions for graph creation
- Follow LangGraph's recommended patterns

### 4. **Better Tool Integration**
- Use `@tool` decorators for automatic schema generation
- Leverage LangGraph's built-in tool execution
- Reduce manual tool call handling

## Implementation Plan

### Phase 1: Create Simplified Graph (✅ Done)
- `service/graph/simplified_graph.py` - follows LangGraph patterns
- Uses `@tool` decorators
- Cleaner state management
- Reduced complexity

### Phase 2: Update Main Implementation
- Refactor existing `graph.py` to use LangGraph patterns
- Keep existing functionality but improve structure
- Maintain backward compatibility
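One way to maintain backward compatibility during Phase 2 is to keep the `AgentWorkflow.astream(state, stream_callback)` signature and delegate to the compiled graph internally. The sketch below shows the idea under several assumptions: `FakeCompiledGraph` is a stand-in that only mimics an async `astream` generator yielding per-node update dicts, the state is a plain dict, and the adapter returns the final state (the original return value is not shown in this document).

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable

StreamCallback = Callable[[dict[str, Any]], Awaitable[None]]


class FakeCompiledGraph:
    """Stand-in for a compiled graph; only mimics an async `astream` generator."""

    async def astream(self, state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
        yield {"agent": {"messages": ["draft answer"]}}
        yield {"tools": {"messages": ["tool output"]}}


class AgentWorkflow:
    """Keeps the existing `astream(state, stream_callback)` call signature."""

    def __init__(self, graph: Any) -> None:
        self._graph = graph

    async def astream(self, state: dict[str, Any], stream_callback: StreamCallback) -> dict[str, Any]:
        messages = list(state.get("messages", []))
        async for update in self._graph.astream(state):
            await stream_callback(update)  # forward each node update to the caller
            for node_output in update.values():
                messages.extend(node_output.get("messages", []))
        # Assumption: callers expect the accumulated state back.
        return {**state, "messages": messages}


async def _demo() -> tuple[list[dict[str, Any]], dict[str, Any]]:
    seen: list[dict[str, Any]] = []

    async def collect(update: dict[str, Any]) -> None:
        seen.append(update)

    workflow = AgentWorkflow(FakeCompiledGraph())
    final = await workflow.astream({"messages": ["question"]}, collect)
    return seen, final


seen_updates, final_state = asyncio.run(_demo())
```

Existing callers keep working while the internals move to the graph, which allows the migration to proceed feature by feature.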

### Phase 3: Testing and Migration
- Test simplified implementation
- Gradual migration of features
- Performance comparison
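For the performance comparison, a small timing harness is usually enough to spot regressions between the two implementations. The following is a minimal sketch: `legacy_turn` and `simplified_turn` are placeholders that only sleep, and would be replaced by one full turn through the current workflow and through the simplified graph respectively. The run count and the reported statistics are illustrative choices.

```python
import asyncio
import statistics
import time
from typing import Any, Awaitable, Callable


async def time_async(fn: Callable[[], Awaitable[Any]], runs: int = 5) -> dict[str, float]:
    """Run `fn` sequentially `runs` times and summarize wall-clock latency in seconds."""
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        await fn()
        samples.append(time.perf_counter() - start)
    return {"min": min(samples), "mean": statistics.mean(samples), "max": max(samples)}


async def legacy_turn() -> None:
    # Placeholder for one turn through the current AgentWorkflow.
    await asyncio.sleep(0.01)


async def simplified_turn() -> None:
    # Placeholder for one turn through the simplified graph.
    await asyncio.sleep(0.01)


async def compare() -> dict[str, dict[str, float]]:
    return {
        "legacy": await time_async(legacy_turn),
        "simplified": await time_async(simplified_turn),
    }


report = asyncio.run(compare())
```

Because most of the latency in a real turn comes from the model and retrieval calls, both variants should be measured against the same inputs and backends for the numbers to be comparable.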

## Code Comparison

### Tool Definition
**Before:**
```python
async def _execute_tool_call(self, tool_call, state, stream_callback):
    tool_name = tool_call["name"]
    tool_args = tool_call["args"]
    async with RetrievalTools() as retrieval:
        if tool_name == "retrieve_standard_regulation":
            result = await retrieval.retrieve_standard_regulation(**tool_args)
        # 20+ lines of manual handling
```

**After:**
```python
from langchain_core.tools import tool

@tool
async def retrieve_standard_regulation(query: str, conversation_history: str = "") -> str:
    """Retrieve standards and regulations for a query."""  # placeholder text; @tool needs a docstring or description
    async with RetrievalTools() as retrieval:
        result = await retrieval.retrieve_standard_regulation(query=query, conversation_history=conversation_history)
        return f"Found {len(result.results)} results"
```

### Graph Creation
**Before:**
```python
class AgentWorkflow:
    def __init__(self):
        self.agent_node = AgentNode()
        self.post_process_node = PostProcessNode()
```

**After:**
```python
def create_agent_graph():
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", run_tools)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges("agent", should_continue, ["tools", END])
    # Return to the agent after tools run, as in the official example
    workflow.add_edge("tools", "agent")
    return workflow.compile()
```

## Benefits of LangGraph Patterns

1. **Declarative**: Graph structure is explicit and easy to understand
2. **Modular**: Nodes and edges can be easily modified
3. **Testable**: Individual nodes can be tested in isolation
4. **Standard**: Follows LangGraph community conventions
5. **Maintainable**: Less custom logic, more framework features
6. **Debuggable**: LangGraph provides built-in debugging tools
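The testability point can be made concrete: when a node is a function from state to a partial update, it can be exercised with a fake model and no graph at all. The sketch below assumes a node built by a hypothetical `make_call_model` factory and a `FakeModel` test double with an `ainvoke` coroutine; neither name comes from the existing codebase.

```python
import asyncio
from typing import Any


class FakeModel:
    """Test double that returns a canned reply instead of calling an LLM."""

    def __init__(self, reply: dict[str, Any]) -> None:
        self.reply = reply
        self.calls: list[list[dict[str, Any]]] = []

    async def ainvoke(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
        self.calls.append(messages)
        return self.reply


def make_call_model(model: Any):
    """Build the node function around an injected model (hypothetical factory)."""

    async def call_model(state: dict[str, Any]) -> dict[str, Any]:
        response = await model.ainvoke(state["messages"])
        return {"messages": [response]}  # partial update only; merging happens elsewhere

    return call_model


def test_call_model_returns_model_reply() -> None:
    reply = {"role": "assistant", "content": "42"}
    model = FakeModel(reply)
    node = make_call_model(model)
    update = asyncio.run(node({"messages": [{"role": "user", "content": "?"}]}))
    assert update == {"messages": [reply]}
    assert len(model.calls) == 1


test_call_model_returns_model_reply()
```

Injecting the model through a factory is a design choice made for this sketch; it keeps the node free of global state so the same test runs without network access.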