115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
|
|
import uuid
|
||
|
|
import logging
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import select
|
||
|
|
|
||
|
|
from ..core.deps import get_db
|
||
|
|
from ..models.db import Workspace, File as FileRecord, Task
|
||
|
|
from ..services.rag import hybrid_search, rerank, generate_answer
|
||
|
|
from ..worker import process_file_task
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint declared on this router is served under the /api/kb prefix.
router = APIRouter(
    prefix="/api/kb",
    tags=["知识库"],
)

# Directory that receives uploaded files. It is created at import time
# (including missing parents) and an already-existing directory is not an error.
UPLOAD_DIR = Path("/app/data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||
|
|
|
||
|
|
|
||
|
|
class WorkspaceCreate(BaseModel):
    """Request body accepted by the workspace-creation endpoint."""

    # Workspace name; the only required field.
    name: str

    # Free-form description; empty string when omitted.
    description: str = ""

    # Domain label stored on the workspace; "general" when omitted.
    domain: str = "general"
|
||
|
|
|
||
|
|
|
||
|
|
class QARequest(BaseModel):
    """Request body used by both the QA and the retrieval endpoints."""

    # Query text passed to search, reranking and answer generation.
    query: str

    # Forwarded to hybrid_search as ``workspace_id``; None when omitted.
    workspace_id: str | None = None

    # Number of chunks requested; 5 when omitted.
    top_k: int = 5

    # When False, the QA endpoint removes the "sources" key from its result.
    return_sources: bool = True
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
    """Create a workspace from the request body and return a short summary.

    Args:
        req: Name, description and domain for the new workspace.
        db: Database session injected through ``get_db``.

    Returns:
        A dict with the workspace ``id`` (as a string), ``name`` and ``domain``.
    """
    workspace = Workspace(
        name=req.name,
        description=req.description,
        domain=req.domain,
    )
    db.add(workspace)

    # Only a flush is issued here; this function never commits itself.
    # NOTE(review): presumably get_db commits when the request ends - confirm.
    await db.flush()

    summary = {
        "id": str(workspace.id),
        "name": workspace.name,
        "domain": workspace.domain,
    }
    return summary
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/files/upload")
async def upload_file(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    workspace_id: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
):
    """Save an uploaded file, record it, and queue it for processing.

    The upload is written to ``UPLOAD_DIR`` under a freshly generated UUID
    (keeping the original suffix), a ``FileRecord`` and a ``Task`` row are added
    to the session, and ``process_file_task`` is dispatched for the new file.

    Args:
        background_tasks: Injected by FastAPI; not used by this function and
            kept only so the signature stays unchanged.
        file: The uploaded file.
        workspace_id: Optional workspace UUID as a string; empty means none.
        db: Database session injected through ``get_db``.

    Returns:
        A dict with ``file_id``, ``task_id`` and ``status`` ("processing").

    Raises:
        HTTPException: 400 when ``workspace_id`` is non-empty but is not a
            valid UUID.
    """
    # Local import keeps this block self-contained; only the disk write needs it.
    import asyncio

    # Validate the workspace id first: a malformed value used to surface as an
    # unhandled ValueError (HTTP 500) after the file had already been written.
    workspace_uuid = None
    if workspace_id:
        try:
            workspace_uuid = uuid.UUID(workspace_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="workspace_id 格式无效") from None

    content = await file.read()

    file_uuid = uuid.uuid4()
    file_id = str(file_uuid)
    suffix = Path(file.filename or "doc").suffix
    stored_name = f"{file_id}{suffix}"
    save_path = UPLOAD_DIR / stored_name

    # Path.write_bytes blocks; run it in a worker thread so the event loop
    # keeps serving other requests while the upload is written to disk.
    await asyncio.to_thread(save_path.write_bytes, content)

    file_record = FileRecord(
        id=file_uuid,
        filename=stored_name,
        original_name=file.filename or "unknown",
        file_type=suffix.lstrip("."),
        file_size=len(content),
        storage_path=str(save_path),
        workspace_id=workspace_uuid,
        status="uploaded",
    )
    db.add(file_record)

    task = Task(
        task_type="parse_and_vectorize",
        status="pending",
        file_id=file_uuid,
        # The payload keeps the raw form value (empty string when not supplied).
        payload={"workspace_id": workspace_id},
    )
    db.add(task)
    # Flush before reading ``task.id`` for the Celery call below.
    await db.flush()

    # Trigger the Celery task asynchronously.
    # NOTE(review): the task is dispatched after a flush, not a commit. Whether
    # the worker can already see these rows depends on how get_db commits -
    # confirm there is no race with the worker.
    celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
    task.celery_task_id = celery_task.id
    await db.flush()

    return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
    """Return the status of a single task.

    Args:
        task_id: Task UUID as a string, taken from the URL path.
        db: Database session injected through ``get_db``.

    Returns:
        A dict with ``task_id``, ``status``, ``progress``, ``file_id``,
        ``error_msg`` and ``completed_at`` (ISO 8601 string or None).

    Raises:
        HTTPException: 404 when ``task_id`` is not a valid UUID or no task
            with that id exists.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError:
        # A malformed id cannot match any row. Report it as "not found" instead
        # of letting the ValueError escape as an HTTP 500.
        raise HTTPException(status_code=404, detail="任务不存在") from None

    result = await db.execute(select(Task).where(Task.id == task_uuid))
    task = result.scalar_one_or_none()
    if task is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    return {
        "task_id": str(task.id),
        "status": task.status,
        "progress": task.progress,
        "file_id": str(task.file_id) if task.file_id else None,
        "error_msg": task.error_msg,
        "completed_at": task.completed_at.isoformat() if task.completed_at else None,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/qa")
async def qa(req: QARequest):
    """Answer a query by running search, rerank and answer generation in turn.

    Args:
        req: Query text plus optional workspace id, ``top_k`` and
            ``return_sources``.

    Returns:
        The result of ``generate_answer``, with its "sources" key removed when
        ``req.return_sources`` is False.
    """
    # Search asks for twice ``top_k``; rerank then cuts back to ``top_k``.
    # NOTE(review): presumably so the reranker has spare candidates - confirm.
    candidates = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k * 2,
    )
    top_chunks = await rerank(req.query, candidates, top_k=req.top_k)
    answer = await generate_answer(req.query, top_chunks)

    if req.return_sources:
        return answer

    # The caller opted out of sources; a missing key is tolerated.
    answer.pop("sources", None)
    return answer
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
    """Run a hybrid search only and return the matching chunks.

    Args:
        req: Query text, optional workspace id and ``top_k``.
            ``return_sources`` is not read by this endpoint.

    Returns:
        A dict with ``chunks`` (the search result) and ``total`` (its length).
    """
    hits = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k,
    )
    response = {"chunks": hits, "total": len(hits)}
    return response
|