first commit
This commit is contained in:
114
services/compliance-backend/app/api/kb.py
Normal file
114
services/compliance-backend/app/api/kb.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from ..core.deps import get_db
|
||||
from ..models.db import Workspace, File as FileRecord, Task
|
||||
from ..services.rag import hybrid_search, rerank, generate_answer
|
||||
from ..worker import process_file_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/kb", tags=["知识库"])
|
||||
|
||||
UPLOAD_DIR = Path("/app/data/uploads")
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class WorkspaceCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
domain: str = "general"
|
||||
|
||||
|
||||
class QARequest(BaseModel):
|
||||
query: str
|
||||
workspace_id: str | None = None
|
||||
top_k: int = 5
|
||||
return_sources: bool = True
|
||||
|
||||
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
|
||||
ws = Workspace(name=req.name, description=req.description, domain=req.domain)
|
||||
db.add(ws)
|
||||
await db.flush()
|
||||
return {"id": str(ws.id), "name": ws.name, "domain": ws.domain}
|
||||
|
||||
|
||||
@router.post("/files/upload")
|
||||
async def upload_file(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
workspace_id: str = Form(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
content = await file.read()
|
||||
file_id = str(uuid.uuid4())
|
||||
suffix = Path(file.filename or "doc").suffix
|
||||
save_path = UPLOAD_DIR / f"{file_id}{suffix}"
|
||||
save_path.write_bytes(content)
|
||||
|
||||
file_record = FileRecord(
|
||||
id=uuid.UUID(file_id),
|
||||
filename=f"{file_id}{suffix}",
|
||||
original_name=file.filename or "unknown",
|
||||
file_type=suffix.lstrip("."),
|
||||
file_size=len(content),
|
||||
storage_path=str(save_path),
|
||||
workspace_id=uuid.UUID(workspace_id) if workspace_id else None,
|
||||
status="uploaded",
|
||||
)
|
||||
db.add(file_record)
|
||||
|
||||
task = Task(
|
||||
task_type="parse_and_vectorize",
|
||||
status="pending",
|
||||
file_id=uuid.UUID(file_id),
|
||||
payload={"workspace_id": workspace_id},
|
||||
)
|
||||
db.add(task)
|
||||
await db.flush()
|
||||
|
||||
# 异步触发 Celery 任务
|
||||
celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
|
||||
task.celery_task_id = celery_task.id
|
||||
await db.flush()
|
||||
|
||||
return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
return {
|
||||
"task_id": str(task.id),
|
||||
"status": task.status,
|
||||
"progress": task.progress,
|
||||
"file_id": str(task.file_id) if task.file_id else None,
|
||||
"error_msg": task.error_msg,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/qa")
|
||||
async def qa(req: QARequest):
|
||||
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k * 2)
|
||||
ranked = await rerank(req.query, chunks, top_k=req.top_k)
|
||||
result = await generate_answer(req.query, ranked)
|
||||
if not req.return_sources:
|
||||
result.pop("sources", None)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/knowledge/retrieval")
|
||||
async def retrieval(req: QARequest):
|
||||
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k)
|
||||
return {"chunks": chunks, "total": len(chunks)}
|
||||
Reference in New Issue
Block a user