"""FastAPI entry point for the Agentic RAG service.

Wires together configuration, the PostgreSQL-backed memory manager and the
workflow graph, and exposes SSE-streaming chat endpoints plus health probes.
Run directly (``python -m service.main``) or point uvicorn at
``service.main:app``.
"""

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .config import load_config, get_config
from .schemas.messages import ChatRequest
from .memory.postgresql_memory import get_memory_manager
from .graph.state import TurnState, Message
from .graph.graph import build_graph
from .sse import create_error_event
from .utils.error_handler import (
    StructuredLogger,
    ErrorCategory,
    ErrorCode,
    handle_async_errors,
)
from .utils.middleware import ErrorMiddleware

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = StructuredLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    On startup: initializes the PostgreSQL memory manager and builds the
    workflow graph, stashing both on ``app.state`` for the request handlers.
    On shutdown: logs only — the memory manager exposes no explicit close
    here. Re-raises any startup failure so the server refuses to come up
    half-initialized.
    """
    # Startup
    try:
        logger.info("Starting application initialization...")

        # Initialize PostgreSQL memory manager. A failed connection test is
        # logged but does NOT abort startup — only an exception does.
        memory_manager = get_memory_manager()
        connection_ok = memory_manager.test_connection()
        logger.info(f"PostgreSQL memory manager initialized (connected: {connection_ok})")

        # Initialize global components shared by all requests.
        app.state.memory_manager = memory_manager
        app.state.graph = build_graph()

        logger.info("Application startup complete")
        yield
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        raise
    finally:
        # Shutdown
        logger.info("Application shutdown")


def create_app() -> FastAPI:
    """Application factory.

    Loads configuration, constructs the FastAPI app with error-handling and
    CORS middleware, and registers all routes.

    Returns:
        The configured :class:`FastAPI` instance.
    """
    # Load configuration first — get_config() elsewhere relies on this.
    config = load_config()
    logger.info(f"Loaded configuration for provider: {config.provider}")

    app = FastAPI(
        title="Agentic RAG API",
        description="Agentic RAG application for manufacturing standards and regulations",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Add error handling middleware
    app.add_middleware(ErrorMiddleware)

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=config.app.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Define routes
    @app.post("/api/chat")
    async def chat_endpoint(request: ChatRequest):
        """Main chat endpoint with SSE streaming.

        Note: this try/except only guards *construction* of the response;
        errors raised while the generator streams are handled inside
        stream_chat_response / run_workflow and surfaced as SSE error events.
        """
        try:
            return StreamingResponse(
                stream_chat_response(request),
                media_type="text/event-stream",
                # CORS headers are duplicated here because some proxies strip
                # middleware-added headers from streaming responses.
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "*",
                },
            )
        except Exception as e:
            logger.error(f"Chat endpoint error: {e}")
            # Chain the cause so the original traceback is preserved.
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/api/ai-sdk/chat")
    async def ai_sdk_chat_endpoint(request: ChatRequest):
        """AI SDK compatible chat endpoint."""
        try:
            # Import here to avoid circular imports
            from .ai_sdk_chat import handle_ai_sdk_chat
            return await handle_ai_sdk_chat(request, app.state)
        except Exception as e:
            logger.error(f"AI SDK chat endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy", "service": "agentic-rag"}

    @app.get("/")
    async def root():
        """Root endpoint."""
        return {"message": "Agentic RAG API for Manufacturing Standards & Regulations"}

    return app


# Create the global app instance for uvicorn
app = create_app()


@handle_async_errors(ErrorCategory.LLM, ErrorCode.LLM_ERROR)
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[str, None]:
    """Stream chat response with enhanced error handling.

    Builds a per-turn state from the request, runs the workflow graph in a
    background task, and yields SSE event strings as the workflow pushes them
    onto an internal queue. A ``None`` sentinel on the queue signals
    completion; workflow failures are converted into SSE error events rather
    than raised to the client.
    """
    # NOTE(review): `config` is not used below — kept because get_config()
    # fails fast if configuration was never loaded. Confirm before removing.
    config = get_config()
    memory_manager = app.state.memory_manager
    graph = app.state.graph

    # Create conversation state for this turn.
    state = TurnState(session_id=request.session_id)

    # Add the latest user message (if any) to the turn state. Non-user
    # trailing messages are deliberately ignored.
    if request.messages:
        last_message = request.messages[-1]
        if last_message.get("role") == "user":
            user_message = Message(
                role="user",
                content=last_message.get("content", ""),
            )
            state.messages.append(user_message)

    # Queue decouples workflow execution from SSE delivery.
    event_queue = asyncio.Queue()

    async def stream_callback(event_str: str):
        # Invoked by the graph for each event it wants streamed.
        await event_queue.put(event_str)

    # Execute workflow in a background task so we can stream concurrently.
    async def run_workflow():
        try:
            async for _ in graph.astream(state, stream_callback):
                pass
            await event_queue.put(None)  # Signal completion
        except Exception as e:
            logger.error(
                "Workflow execution failed",
                error=e,
                category=ErrorCategory.LLM,
                error_code=ErrorCode.LLM_ERROR,
            )
            # Surface a sanitized error to the client, then terminate the stream.
            await event_queue.put(create_error_event("Processing error: AI service is temporarily unavailable"))
            await event_queue.put(None)

    workflow_task = asyncio.create_task(run_workflow())

    # Stream events as they arrive until the completion sentinel.
    try:
        while True:
            event = await event_queue.get()
            if event is None:  # Completion signal
                break
            yield event
    finally:
        # Don't let the workflow outlive the client connection (e.g. on
        # disconnect the generator is closed and we cancel the task).
        if not workflow_task.done():
            workflow_task.cancel()


if __name__ == "__main__":
    config = load_config()  # Load configuration first
    uvicorn.run(
        "service.main:app",
        host=config.app.host,
        port=config.app.port,
        reload=True,
        log_level="info",
    )