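"""FastAPI entry point for the Agentic RAG service.

Wires together configuration, the PostgreSQL-backed memory manager, the
conversation workflow graph, and the SSE streaming chat endpoints for
manufacturing standards and regulations.
"""
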
import asyncio
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from .config import load_config, get_config
from .schemas.messages import ChatRequest
from .memory.postgresql_memory import get_memory_manager
from .graph.state import TurnState, Message
from .graph.graph import build_graph
from .sse import create_error_event
from .utils.error_handler import StructuredLogger, ErrorCategory, ErrorCode, handle_async_errors
from .utils.middleware import ErrorMiddleware

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = StructuredLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager"""
    # Startup
    try:
        logger.info("Starting application initialization...")

        # Initialize PostgreSQL memory manager
        memory_manager = get_memory_manager()
        connection_ok = memory_manager.test_connection()
        logger.info(f"PostgreSQL memory manager initialized (connected: {connection_ok})")

        # Initialize global components
        app.state.memory_manager = memory_manager
        app.state.graph = build_graph()

        logger.info("Application startup complete")
        yield
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        raise
    finally:
        # Shutdown
        logger.info("Application shutdown")


def create_app() -> FastAPI:
    """Application factory"""
    # Load configuration first
    config = load_config()
    logger.info(f"Loaded configuration for provider: {config.provider}")

    app = FastAPI(
        title="Agentic RAG API",
        description="Agentic RAG application for manufacturing standards and regulations",
        version="0.1.0",
        lifespan=lifespan
    )

    # Add error handling middleware
    app.add_middleware(ErrorMiddleware)

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=config.app.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Define routes
    @app.post("/api/chat")
    async def chat_endpoint(request: ChatRequest):
        """Main chat endpoint with SSE streaming"""
        try:
            return StreamingResponse(
                stream_chat_response(request),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "*",
                }
            )
        except Exception as e:
            logger.error(f"Chat endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ai-sdk/chat")
|
|
async def ai_sdk_chat_endpoint(request: ChatRequest):
|
|
"""AI SDK compatible chat endpoint"""
|
|
try:
|
|
# Import here to avoid circular imports
|
|
from .ai_sdk_chat import handle_ai_sdk_chat
|
|
return await handle_ai_sdk_chat(request, app.state)
|
|
except Exception as e:
|
|
logger.error(f"AI SDK chat endpoint error: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""Health check endpoint"""
|
|
return {"status": "healthy", "service": "agentic-rag"}
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint"""
|
|
return {"message": "Agentic RAG API for Manufacturing Standards & Regulations"}
|
|
|
|
return app
|
|
|
|
|
|
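
# NOTE: chat_endpoint above closes over stream_chat_response, which is defined
# further below; the name is resolved at request time, after module import has
# completed, so the forward reference is safe.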
# Create the global app instance for uvicorn
app = create_app()


@handle_async_errors(ErrorCategory.LLM, ErrorCode.LLM_ERROR)
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[str, None]:
    """Stream chat response with enhanced error handling"""
    config = get_config()
    memory_manager = app.state.memory_manager
    graph = app.state.graph

    # Create conversation state
    state = TurnState(session_id=request.session_id)

    # Add user message
    if request.messages:
        last_message = request.messages[-1]
        if last_message.get("role") == "user":
            user_message = Message(
                role="user",
                content=last_message.get("content", "")
            )
            state.messages.append(user_message)

    # Create event queue for streaming
    event_queue = asyncio.Queue()
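
    # Producer/consumer bridge: the workflow pushes SSE event strings into the
    # queue via stream_callback, and the generator below drains them; a None
    # sentinel signals completion.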
    async def stream_callback(event_str: str):
        await event_queue.put(event_str)

    # Execute workflow in background task
    async def run_workflow():
        try:
            async for _ in graph.astream(state, stream_callback):
                pass
            await event_queue.put(None)  # Signal completion
        except Exception as e:
            logger.error("Workflow execution failed", error=e,
                         category=ErrorCategory.LLM, error_code=ErrorCode.LLM_ERROR)
            await event_queue.put(create_error_event("Processing error: AI service is temporarily unavailable"))
            await event_queue.put(None)

    # Start workflow task
    workflow_task = asyncio.create_task(run_workflow())

    # Stream events as they come
    try:
        while True:
            event = await event_queue.get()
            if event is None:  # Completion signal
                break
            yield event
    finally:
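        # Runs on normal completion and when the client disconnects mid-stream
        # (the generator is closed); cancel the workflow if it is still running.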
        if not workflow_task.done():
            workflow_task.cancel()

if __name__ == "__main__":
|
|
config = load_config() # Load configuration first
|
|
uvicorn.run(
|
|
"service.main:app",
|
|
host=config.app.host,
|
|
port=config.app.port,
|
|
reload=True,
|
|
log_level="info"
|
|
)
|
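
# Example (hypothetical host/port, assuming the configured defaults point at
# localhost:8000): stream a chat turn over SSE with curl. The -N flag disables
# output buffering so events print as they arrive.
#
#   curl -N -X POST http://localhost:8000/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "demo-session", "messages": [{"role": "user", "content": "Hello"}]}'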