# LangGraph Implementation Analysis and Improvements

## Official Example vs Current Implementation

### Key Differences Found
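#### 1. **Graph Structure**

**Official Example:**

```python
from langgraph.graph import END, StateGraph

# AgentState, call_model, run_tools and should_continue are defined elsewhere in the example
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", run_tools)
workflow.set_entry_point("agent")
# should_continue routes to "tools" while the model requests tool calls, otherwise to END
workflow.add_conditional_edges("agent", should_continue, ["tools", END])
workflow.add_edge("tools", "agent")  # tool results flow back to the model
graph = workflow.compile()
```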
**Current Implementation:**

```python
class AgentWorkflow:
    def __init__(self):
        self.agent_node = AgentNode()
        self.post_process_node = PostProcessNode()

    async def astream(self, state, stream_callback):
        # Nodes run in a fixed, hard-coded order; there are no edges or routing decisions
        state = await self.agent_node(state, stream_callback)
        state = await self.post_process_node(state, stream_callback)
```
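The practical difference is the loop. In the official graph, `should_continue` decides after every model call whether to run tools or stop, and the edge from `tools` back to `agent` returns tool results to the model; `AgentWorkflow.astream` instead runs the same two steps exactly once. Below is a minimal plain-Python sketch of that control flow, not LangGraph's actual executor: the state is a dict, `END` is a local stand-in for `langgraph.graph.END`, and `run_graph`, `fake_model` and `fake_tools` are hypothetical names used only for illustration.

```python
from typing import Callable

END = "__end__"  # local stand-in for langgraph.graph.END


def should_continue(state: dict) -> str:
    """Route to "tools" if the last message requests tool calls, otherwise finish."""
    last = state["messages"][-1]
    return "tools" if last.get("tool_calls") else END


def run_graph(state: dict, nodes: dict[str, Callable[[dict], dict]], max_steps: int = 10) -> dict:
    """Simplified agent/tools loop: run agent, route, and always return from tools to agent."""
    current = "agent"
    for _ in range(max_steps):
        update = nodes[current](state)
        # Simple append merge of the node's partial update (add_messages also merges by id)
        state = {**state, "messages": state["messages"] + update.get("messages", [])}
        if current == "agent":
            current = should_continue(state)
            if current == END:
                return state
        else:
            current = "agent"  # equivalent of workflow.add_edge("tools", "agent")
    raise RuntimeError("step limit reached")


# Hypothetical stubs: the model asks for one tool call, then answers
def fake_model(state: dict) -> dict:
    if any(m["role"] == "tool" for m in state["messages"]):
        return {"messages": [{"role": "assistant", "content": "done"}]}
    return {"messages": [{"role": "assistant", "content": "", "tool_calls": [{"name": "lookup"}]}]}


def fake_tools(state: dict) -> dict:
    return {"messages": [{"role": "tool", "content": "result"}]}


final = run_graph({"messages": [{"role": "user", "content": "hi"}]},
                  {"agent": fake_model, "tools": fake_tools})
print([m["role"] for m in final["messages"]])  # ['user', 'assistant', 'tool', 'assistant']
```

The `max_steps` guard plays a role similar to LangGraph's recursion limit: without it, a model that keeps requesting tools would loop forever.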
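#### 2. **State Management**

**Official Example:**

```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    messages: Annotated[list, add_messages]  # reducer: updates are merged in, not overwritten
```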
**Current Implementation:**

```python
from typing import List
from pydantic import BaseModel, Field

class TurnState(BaseModel):
    # Message, ToolResult and Citation are project-specific models
    session_id: str
    messages: List[Message] = Field(default_factory=list)
    tool_results: List[ToolResult] = Field(default_factory=list)
    citations: List[Citation] = Field(default_factory=list)
    # ... many more fields
```
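The `Annotated[list, add_messages]` annotation is what lets LangGraph nodes return only the keys they change: roughly, LangGraph reads the reducer from the annotation's metadata and uses it to combine a node's update with the existing value, while keys without a reducer are simply overwritten. A rough standard-library sketch of that merge step follows, assuming `operator.add` as a stand-in reducer (the real `add_messages` also handles message IDs); `SketchState`, `reducers_for` and `apply_update` are hypothetical helpers, not LangGraph functions.

```python
import operator
from typing import Annotated, TypedDict, get_type_hints


class SketchState(TypedDict):
    messages: Annotated[list, operator.add]  # reducer: concatenate lists
    session_id: str                          # no reducer: last write wins


def reducers_for(schema: type) -> dict:
    """Collect the reducer stored in each Annotated[...] field, if any."""
    hints = get_type_hints(schema, include_extras=True)
    return {key: hint.__metadata__[0] for key, hint in hints.items() if hasattr(hint, "__metadata__")}


def apply_update(state: dict, update: dict, schema: type) -> dict:
    """Merge a node's partial update into the state, using reducers where declared."""
    reducers = reducers_for(schema)
    merged = dict(state)
    for key, value in update.items():
        if key in reducers and key in merged:
            merged[key] = reducers[key](merged[key], value)
        else:
            merged[key] = value
    return merged


state = {"messages": ["hi"], "session_id": "s1"}
state = apply_update(state, {"messages": ["hello!"]}, SketchState)
print(state)  # {'messages': ['hi', 'hello!'], 'session_id': 's1'}
```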
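#### 3. **Tool Handling**

**Official Example:**

```python
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode

@tool
def get_stock_price(stock_symbol: str):
    """Look up the price of a stock symbol."""  # @tool takes the description from the docstring and errors without one
    return mock_stock_data[stock_symbol]  # mock_stock_data is defined elsewhere in the example

tools = [get_stock_price]
tool_node = ToolNode(tools)
```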
**Current Implementation:**

```python
async def _execute_tool_call(self, tool_call, state, stream_callback):
    tool_name = tool_call["name"]
    tool_args = tool_call["args"]
    async with RetrievalTools() as retrieval:
        if tool_name == "retrieve_standard_regulation":
            result = await retrieval.retrieve_standard_regulation(**tool_args)
        # Manual tool execution logic
```
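Even before adopting `ToolNode`, the `if tool_name == ...` chain in `_execute_tool_call` can be replaced with a name-to-function registry, which is similar in spirit to how `ToolNode` looks tools up by name. This sketch assumes tool calls are dicts with `name`, `args` and an optional `id` key; `TOOL_REGISTRY`, `register_tool` and `execute_tool_call` are hypothetical, and the `retrieve_standard_regulation` stub only stands in for the real `RetrievalTools` method.

```python
import asyncio
from typing import Any, Awaitable, Callable

ToolFn = Callable[..., Awaitable[Any]]
TOOL_REGISTRY: dict[str, ToolFn] = {}


def register_tool(fn: ToolFn) -> ToolFn:
    """Register an async tool under its function name."""
    TOOL_REGISTRY[fn.__name__] = fn
    return fn


@register_tool
async def retrieve_standard_regulation(query: str, conversation_history: str = "") -> str:
    # Hypothetical stub standing in for RetrievalTools.retrieve_standard_regulation
    return f"results for {query!r}"


async def execute_tool_call(tool_call: dict) -> dict:
    """Dispatch one tool call by name and wrap the result (or error) as a tool message."""
    fn = TOOL_REGISTRY.get(tool_call["name"])
    if fn is None:
        content = f"Unknown tool: {tool_call['name']}"
    else:
        try:
            content = str(await fn(**tool_call["args"]))
        except Exception as exc:  # report failures back to the model instead of crashing
            content = f"Tool error: {exc}"
    return {"role": "tool", "tool_call_id": tool_call.get("id"), "content": content}


msg = asyncio.run(execute_tool_call(
    {"name": "retrieve_standard_regulation", "args": {"query": "fire safety"}, "id": "call_1"}
))
print(msg["content"])  # results for 'fire safety'
```

Adding a tool then means registering one more function instead of extending the `if` chain.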
## Recommendations for Improvement

### 1. **Use Standard LangGraph Patterns**

- Adopt `StateGraph` with `add_node()`, `add_edge()`, and `add_conditional_edges()`
- Use `@tool` decorators for cleaner tool definitions
- Leverage `ToolNode` for automatic tool execution
### 2. **Simplify State Management**

- Reduce state complexity where possible
- Use LangGraph's `add_messages` reducer for message handling
- Keep only essential fields in the main state
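The `add_messages` reducer does more than append: a new message whose ID matches an existing one replaces it, so a node can revise a message in place. Below is a simplified re-implementation over plain dicts that shows just those two rules; the real reducer works on LangChain message objects, assigns IDs to messages that lack one, and supports removals, none of which is modelled here. `merge_messages` is a hypothetical name.

```python
def merge_messages(left: list[dict], right: list[dict]) -> list[dict]:
    """Append new messages; replace existing ones that share an "id" (simplified add_messages)."""
    merged = list(left)  # copy so the existing state list is not mutated
    index_by_id = {m["id"]: i for i, m in enumerate(merged) if "id" in m}
    for message in right:
        msg_id = message.get("id")
        if msg_id is not None and msg_id in index_by_id:
            merged[index_by_id[msg_id]] = message  # same id: update in place
        else:
            if msg_id is not None:
                index_by_id[msg_id] = len(merged)
            merged.append(message)  # new message: append
    return merged


history = [{"id": "1", "content": "hi"}, {"id": "2", "content": "draft answer"}]
update = [{"id": "2", "content": "final answer"}, {"id": "3", "content": "follow-up"}]
print([m["content"] for m in merge_messages(history, update)])
# ['hi', 'final answer', 'follow-up']
```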
### 3. **Improve Code Organization**

- Separate concerns: graph definition, tool definitions, and state
- Use factory functions for graph creation
- Follow LangGraph's recommended patterns
### 4. **Better Tool Integration**

- Use `@tool` decorators for automatic schema generation
- Leverage LangGraph's built-in tool execution (`ToolNode`)
- Reduce manual tool call handling
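The point of automatic schema generation is that a tool's name, description and argument types come from the function itself rather than from hand-written dispatch code. This sketch approximates that idea with the standard library (`inspect` and type hints). It is not how `@tool` is implemented (LangChain builds a Pydantic model for the arguments), and `describe_tool`, `JSON_TYPES` and the fallback to `"string"` are assumptions for illustration.

```python
import inspect
from typing import get_type_hints

# Minimal mapping from Python annotations to JSON Schema types (assumption: simple types only)
JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def describe_tool(fn) -> dict:
    """Build a JSON-Schema-style tool description from a function's signature and docstring."""
    hints = get_type_hints(fn)
    properties, required = {}, []
    for name, param in inspect.signature(fn).parameters.items():
        # Unannotated or unknown types fall back to "string"
        properties[name] = {"type": JSON_TYPES.get(hints.get(name), "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)
        else:
            properties[name]["default"] = param.default
    return {
        "name": fn.__name__,
        "description": inspect.getdoc(fn) or "",
        "parameters": {"type": "object", "properties": properties, "required": required},
    }


async def retrieve_standard_regulation(query: str, conversation_history: str = "") -> str:
    """Retrieve standard/regulation passages relevant to the query."""
    return ""


print(describe_tool(retrieve_standard_regulation)["parameters"]["required"])  # ['query']
```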
## Implementation Plan

### Phase 1: Create Simplified Graph (✅ Done)

- `service/graph/simplified_graph.py`: follows LangGraph patterns
- Uses `@tool` decorators
- Cleaner state management
- Reduced complexity
### Phase 2: Update Main Implementation

- Refactor the existing `graph.py` to use LangGraph patterns
- Keep existing functionality but improve structure
- Maintain backward compatibility
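To keep backward compatibility during Phase 2, the existing `astream(state, stream_callback)` entry point could become a thin wrapper over the compiled graph, so callers keep working while the internals move to `StateGraph`. This sketch assumes the compiled graph is driven through its async `astream(..., stream_mode="values")` method, which yields the full state after each step. `AgentWorkflowAdapter` and `FakeGraph` are hypothetical, and passing each intermediate state to `stream_callback` is an assumption about what existing callers expect.

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable

StreamCallback = Callable[[dict], Awaitable[None]]


class AgentWorkflowAdapter:
    """Keeps the legacy astream(state, stream_callback) signature on top of a compiled graph."""

    def __init__(self, graph: Any):
        self.graph = graph  # e.g. the result of create_agent_graph()

    async def astream(self, state: dict, stream_callback: StreamCallback) -> dict:
        final_state = state
        # stream_mode="values" yields the full state after each step; keep the last one
        async for snapshot in self.graph.astream(state, stream_mode="values"):
            await stream_callback(snapshot)
            final_state = snapshot
        return final_state


class FakeGraph:
    """Test double mimicking a compiled graph's astream(..., stream_mode="values")."""

    async def astream(self, state: dict, stream_mode: str = "values") -> AsyncIterator[dict]:
        yield state
        yield {**state, "messages": state["messages"] + ["answer"]}


async def main() -> tuple[dict, list]:
    seen: list = []

    async def callback(snapshot: dict) -> None:
        seen.append(snapshot)

    result = await AgentWorkflowAdapter(FakeGraph()).astream({"messages": ["question"]}, callback)
    return result, seen


result, seen = asyncio.run(main())
print(result["messages"], len(seen))  # ['question', 'answer'] 2
```

If callers rely on token-level streaming, a different `stream_mode` (for example `"messages"`) would be needed; the right mapping depends on what the current callback receives.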
### Phase 3: Testing and Migration

- Test the simplified implementation
- Migrate features gradually
- Compare performance against the current implementation
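For the testing and performance-comparison steps, a small harness can run the old and new entry points on the same inputs, check that the outputs agree, and record wall-clock time. This is a sketch under stated assumptions: both implementations are exposed as async callables that take a state dict, and the caller-supplied `extract` picks out the part of the result that must match, since intermediate fields may legitimately differ. `compare`, `old_impl` and `new_impl` are hypothetical.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable

Impl = Callable[[dict], Awaitable[dict]]


async def compare(old: Impl, new: Impl, inputs: list[dict],
                  extract: Callable[[dict], Any]) -> dict:
    """Run both implementations on each input; record mismatching indices and total time."""
    report = {"mismatches": [], "old_seconds": 0.0, "new_seconds": 0.0}
    for i, state in enumerate(inputs):
        start = time.perf_counter()
        old_out = await old(dict(state))  # shallow copy so runs do not share state
        report["old_seconds"] += time.perf_counter() - start

        start = time.perf_counter()
        new_out = await new(dict(state))
        report["new_seconds"] += time.perf_counter() - start

        if extract(old_out) != extract(new_out):
            report["mismatches"].append(i)
    return report


# Hypothetical stand-ins for the legacy workflow and the new graph
async def old_impl(state: dict) -> dict:
    return {**state, "answer": state["question"].upper(), "debug": "legacy"}


async def new_impl(state: dict) -> dict:
    return {**state, "answer": state["question"].upper()}


report = asyncio.run(compare(old_impl, new_impl,
                             [{"question": "a"}, {"question": "b"}],
                             extract=lambda out: out["answer"]))
print(report["mismatches"])  # []
```

Single runs are noisy; repeating each input several times and comparing medians gives a fairer performance picture.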
## Code Comparison

### Tool Definition

**Before:**

```python
async def _execute_tool_call(self, tool_call, state, stream_callback):
    tool_name = tool_call["name"]
    tool_args = tool_call["args"]
    async with RetrievalTools() as retrieval:
        if tool_name == "retrieve_standard_regulation":
            result = await retrieval.retrieve_standard_regulation(**tool_args)
        # 20+ lines of manual handling
```
**After:**

```python
from langchain_core.tools import tool

@tool
async def retrieve_standard_regulation(query: str, conversation_history: str = "") -> str:
    """Retrieve standard/regulation documents relevant to the query."""  # required by @tool
    async with RetrievalTools() as retrieval:  # existing project retrieval client
        result = await retrieval.retrieve_standard_regulation(query=query, conversation_history=conversation_history)
        return f"Found {len(result.results)} results"
```
### Graph Creation

**Before:**

```python
class AgentWorkflow:
    def __init__(self):
        self.agent_node = AgentNode()
        self.post_process_node = PostProcessNode()
```
**After:**

```python
from langgraph.graph import END, StateGraph

def create_agent_graph():
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", run_tools)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges("agent", should_continue, ["tools", END])
    workflow.add_edge("tools", "agent")  # loop back to the agent, as in the official example
    return workflow.compile()
```
## Benefits of LangGraph Patterns

1. **Declarative**: The graph structure (nodes, edges, routing) is explicit and easy to follow
2. **Modular**: Nodes and edges can be added, removed, or rewired independently
3. **Testable**: Each node is a function of the state and can be tested in isolation
4. **Standard**: Follows LangGraph community conventions
5. **Maintainable**: Less custom orchestration logic; more of it is handled by the framework
6. **Debuggable**: The compiled graph can be visualized (e.g. `graph.get_graph().draw_mermaid()`) and its per-node updates inspected via `stream(..., stream_mode="updates")`
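In practice, the "Testable" point means that a node is a function from state to a partial update, so it can be unit-tested without compiling a graph or calling a real model. In this sketch the model is injected into a hypothetical `make_call_model` factory (the node it returns plays the role of `call_model` in the official example), and a deterministic `stub_model` replaces the LLM.

```python
from typing import Callable

Message = dict
Model = Callable[[list[Message]], Message]


def make_call_model(model: Model) -> Callable[[dict], dict]:
    """Build an "agent" node; the model is injected so tests can pass a stub."""

    def call_model(state: dict) -> dict:
        response = model(state["messages"])
        # Return only the partial update; the graph's reducer merges it into the state
        return {"messages": [response]}

    return call_model


def stub_model(messages: list[Message]) -> Message:
    """Deterministic stand-in for an LLM: echoes the last message's content."""
    return {"role": "assistant", "content": f"echo: {messages[-1]['content']}"}


node = make_call_model(stub_model)
update = node({"messages": [{"role": "user", "content": "ping"}]})
assert update == {"messages": [{"role": "assistant", "content": "echo: ping"}]}
```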