"""Knowledge-base API routes: workspaces, file upload, task status, QA and retrieval."""

import asyncio
import logging
import uuid
from pathlib import Path

from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from ..core.deps import get_db
from ..models.db import Workspace, File as FileRecord, Task
from ..services.rag import hybrid_search, rerank, generate_answer
from ..worker import process_file_task

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/kb", tags=["知识库"])

# Uploaded files are stored here under a generated "<uuid><suffix>" name.
UPLOAD_DIR = Path("/app/data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)


class WorkspaceCreate(BaseModel):
    """Request body for creating a workspace."""

    name: str
    description: str = ""
    domain: str = "general"


class QARequest(BaseModel):
    """Request body shared by the QA and retrieval endpoints."""

    query: str
    # Passed through to hybrid_search; None is forwarded as-is.
    workspace_id: str | None = None
    # Number of chunks kept after reranking (QA) or retrieved (retrieval).
    top_k: int = 5
    # When False, the "sources" key is stripped from the QA response.
    return_sources: bool = True


def _parse_uuid(value: str, field_name: str) -> uuid.UUID:
    """Parse *value* as a UUID, turning a malformed value into an HTTP 400.

    Without this, ``uuid.UUID`` raises ``ValueError`` on client-supplied
    input and the request ends in an unhandled 500.
    """
    try:
        return uuid.UUID(value)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"{field_name} 不是合法的 UUID",
        ) from exc


@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
    """Create a workspace and return its id, name and domain.

    The row is added and flushed so that ``ws.id`` is available for the
    response. No commit is issued here.
    NOTE(review): presumably ``get_db`` commits on success -- confirm.
    """
    ws = Workspace(name=req.name, description=req.description, domain=req.domain)
    db.add(ws)
    await db.flush()
    return {"id": str(ws.id), "name": ws.name, "domain": ws.domain}


@router.post("/files/upload")
async def upload_file(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    workspace_id: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
):
    """Store an uploaded file, record it, and queue it for processing.

    Args:
        background_tasks: Unused by this handler; kept so the signature
            stays unchanged.
        file: The uploaded file.
        workspace_id: Optional workspace UUID as a string; empty means the
            file is not attached to a workspace.
        db: Async database session.

    Returns:
        A dict with ``file_id``, ``task_id`` and ``status`` ("processing").

    Raises:
        HTTPException: 400 if ``workspace_id`` is non-empty and not a UUID.
    """
    # Validate before touching the disk so a bad request leaves no orphan file.
    workspace_uuid = _parse_uuid(workspace_id, "workspace_id") if workspace_id else None

    content = await file.read()

    file_uuid = uuid.uuid4()
    file_id = str(file_uuid)
    suffix = Path(file.filename or "doc").suffix
    stored_name = f"{file_id}{suffix}"
    save_path = UPLOAD_DIR / stored_name

    # write_bytes blocks; run it in a worker thread to keep the event loop free.
    await asyncio.to_thread(save_path.write_bytes, content)

    file_record = FileRecord(
        id=file_uuid,
        filename=stored_name,
        original_name=file.filename or "unknown",
        file_type=suffix.lstrip("."),
        file_size=len(content),
        storage_path=str(save_path),
        workspace_id=workspace_uuid,
        status="uploaded",
    )
    db.add(file_record)

    task = Task(
        task_type="parse_and_vectorize",
        status="pending",
        file_id=file_uuid,
        payload={"workspace_id": workspace_id},
    )
    db.add(task)
    # Flush so task.id is populated before it is handed to the worker.
    await db.flush()

    # Trigger the Celery task asynchronously.
    # NOTE(review): the task is dispatched after flush() but no commit is
    # visible in this function; if the worker starts before the session is
    # committed it may not see these rows -- confirm how get_db commits.
    celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
    task.celery_task_id = celery_task.id
    await db.flush()

    logger.info(
        "Queued file %s for processing (task %s, celery task %s)",
        file_id,
        task.id,
        celery_task.id,
    )
    return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}


@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
    """Return the status of a processing task.

    Raises:
        HTTPException: 400 if ``task_id`` is not a UUID, 404 if no task
            with that id exists.
    """
    task_uuid = _parse_uuid(task_id, "task_id")
    result = await db.execute(select(Task).where(Task.id == task_uuid))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    return {
        "task_id": str(task.id),
        "status": task.status,
        "progress": task.progress,
        "file_id": str(task.file_id) if task.file_id else None,
        "error_msg": task.error_msg,
        "completed_at": task.completed_at.isoformat() if task.completed_at else None,
    }


@router.post("/qa")
async def qa(req: QARequest):
    """Answer a query: hybrid search, rerank, then generate an answer.

    The search fetches ``2 * top_k`` candidates and reranking narrows them
    to ``top_k``. The "sources" key is dropped from the result when
    ``return_sources`` is False.
    """
    chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k * 2)
    ranked = await rerank(req.query, chunks, top_k=req.top_k)
    result = await generate_answer(req.query, ranked)
    if not req.return_sources:
        result.pop("sources", None)
    return result


@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
    """Run hybrid search only and return the chunks with their count."""
    chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k)
    return {"chunks": chunks, "total": len(chunks)}