init
This commit is contained in:
187
vw-agentic-rag/service/main.py
Normal file
187
vw-agentic-rag/service/main.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
from .config import load_config, get_config
|
||||
from .schemas.messages import ChatRequest
|
||||
from .memory.postgresql_memory import get_memory_manager
|
||||
from .graph.state import TurnState, Message
|
||||
from .graph.graph import build_graph
|
||||
from .sse import create_error_event
|
||||
from .utils.error_handler import StructuredLogger, ErrorCategory, ErrorCode, handle_async_errors
|
||||
from .utils.middleware import ErrorMiddleware
|
||||
|
||||
# Configure the root logger once at import time, then create this module's
# structured logger on top of it.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = StructuredLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    On startup this obtains the memory manager, probes its connection, builds
    the workflow graph, and stores both on ``app.state`` (as
    ``memory_manager`` and ``graph``). Control is then yielded to the running
    application. A shutdown message is logged whenever the context exits.

    Args:
        app: The FastAPI application whose ``state`` receives the shared
            components.

    Raises:
        Exception: Any error raised during initialization is logged and
            re-raised.
    """
    try:
        # Only the initialization steps are covered by the "failed to start"
        # handler. An exception thrown into this generator at ``yield`` (while
        # the app is running or shutting down) is not a startup failure and
        # must not be logged as one.
        try:
            logger.info("Starting application initialization...")

            # Initialize PostgreSQL memory manager
            memory_manager = get_memory_manager()
            # The probe result is only reported in the log line below; a
            # falsy result does not abort startup.
            connection_ok = memory_manager.test_connection()
            logger.info(f"PostgreSQL memory manager initialized (connected: {connection_ok})")

            # Initialize global components
            app.state.memory_manager = memory_manager
            app.state.graph = build_graph()

            logger.info("Application startup complete")
        except Exception as e:
            logger.error(f"Failed to start application: {e}")
            raise

        yield
    finally:
        # Shutdown (also reached when startup fails, as before)
        logger.info("Application shutdown")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Loads the configuration, creates the app with the ``lifespan`` manager,
    installs the error-handling and CORS middleware, and registers the routes:

    * ``POST /api/chat`` - chat endpoint streamed as server-sent events.
    * ``POST /api/ai-sdk/chat`` - AI SDK compatible chat endpoint.
    * ``GET /health`` - static health payload.
    * ``GET /`` - static service banner.

    Returns:
        The configured FastAPI application.
    """
    # Load configuration first
    config = load_config()
    logger.info(f"Loaded configuration for provider: {config.provider}")

    app = FastAPI(
        title="Agentic RAG API",
        description="Agentic RAG application for manufacturing standards and regulations",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Add error handling middleware
    app.add_middleware(ErrorMiddleware)

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=config.app.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Define routes
    @app.post("/api/chat")
    async def chat_endpoint(request: ChatRequest):
        """Main chat endpoint with SSE streaming."""
        try:
            return StreamingResponse(
                stream_chat_response(request),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # NOTE(review): these wildcard CORS headers are hard-coded
                    # on the response, independent of config.app.cors_origins
                    # passed to CORSMiddleware above - confirm they are still
                    # wanted.
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "*",
                },
            )
        except HTTPException:
            # Keep the status code and detail of deliberate HTTP errors
            # instead of rewriting them as a generic 500.
            raise
        except Exception as e:
            # This only covers failures while building the response object;
            # errors raised while the stream is consumed happen after this
            # handler has already returned.
            logger.error(f"Chat endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/api/ai-sdk/chat")
    async def ai_sdk_chat_endpoint(request: ChatRequest):
        """AI SDK compatible chat endpoint."""
        try:
            # Import here to avoid circular imports
            from .ai_sdk_chat import handle_ai_sdk_chat

            return await handle_ai_sdk_chat(request, app.state)
        except HTTPException:
            # Let HTTP errors raised by the handler reach the client with
            # their original status code.
            raise
        except Exception as e:
            logger.error(f"AI SDK chat endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy", "service": "agentic-rag"}

    @app.get("/")
    async def root():
        """Root endpoint."""
        return {"message": "Agentic RAG API for Manufacturing Standards & Regulations"}

    return app
|
||||
|
||||
|
||||
# Module-level application object; this is the target of the
# "service.main:app" import string handed to uvicorn.
app: FastAPI = create_app()
|
||||
|
||||
|
||||
# NOTE(review): handle_async_errors is applied to an async generator function
# here - confirm the decorator preserves async-generator semantics, since the
# result is handed to StreamingResponse for iteration.
@handle_async_errors(ErrorCategory.LLM, ErrorCode.LLM_ERROR)
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[str, None]:
    """Run the workflow graph for one chat turn and yield its streamed events.

    The graph stored on ``app.state`` is executed in a background task. Events
    it emits through the stream callback are placed on a queue and yielded to
    the caller in arrival order. If the workflow raises, a single error event
    is yielded before the stream ends.

    Args:
        request: Incoming chat request. ``session_id`` seeds the turn state,
            and the last entry of ``messages`` is forwarded when its role is
            ``"user"`` (entries are read with ``.get``, so presumably
            dict-like - verify against ``ChatRequest``).

    Yields:
        Event strings produced by the workflow, in the order they were queued.
    """
    graph = app.state.graph

    # Create conversation state
    state = TurnState(session_id=request.session_id)

    # Add user message: only the final message is forwarded, and only when it
    # comes from the user.
    if request.messages:
        last_message = request.messages[-1]
        if last_message.get("role") == "user":
            state.messages.append(
                Message(role="user", content=last_message.get("content", ""))
            )

    # Queue bridging the background workflow and this generator; ``None`` is
    # the end-of-stream sentinel.
    event_queue: asyncio.Queue = asyncio.Queue()

    async def stream_callback(event_str: str) -> None:
        await event_queue.put(event_str)

    # Execute workflow in background task
    async def run_workflow() -> None:
        try:
            async for _ in graph.astream(state, stream_callback):
                pass
        except Exception as e:
            logger.error("Workflow execution failed", error=e,
                         category=ErrorCategory.LLM, error_code=ErrorCode.LLM_ERROR)
            await event_queue.put(create_error_event("Processing error: AI service is temporarily unavailable"))
        # Signal completion after either a normal finish or a reported error.
        await event_queue.put(None)

    # Start workflow task
    workflow_task = asyncio.create_task(run_workflow())

    # Stream events as they come
    try:
        while True:
            event = await event_queue.get()
            if event is None:  # Completion signal
                break
            yield event
    finally:
        # If the consumer stops early, the workflow is still running; cancel
        # it so it does not keep working for a stream nobody reads.
        if not workflow_task.done():
            workflow_task.cancel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read host/port from the loaded configuration and
    # serve the module-level app with uvicorn's auto-reload turned on.
    config = load_config()
    uvicorn.run(
        "service.main:app",
        log_level="info",
        reload=True,
        port=config.app.port,
        host=config.app.host,
    )
|
||||
Reference in New Issue
Block a user