first commit

This commit is contained in:
2026-04-23 09:58:47 +08:00
commit 448e078d99
49 changed files with 5188 additions and 0 deletions

View File

@@ -0,0 +1,114 @@
import uuid
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from ..core.deps import get_db
from ..models.db import Workspace, File as FileRecord, Task
from ..services.rag import hybrid_search, rerank, generate_answer
from ..worker import process_file_task
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/kb", tags=["知识库"])
UPLOAD_DIR = Path("/app/data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
class WorkspaceCreate(BaseModel):
name: str
description: str = ""
domain: str = "general"
class QARequest(BaseModel):
query: str
workspace_id: str | None = None
top_k: int = 5
return_sources: bool = True
@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
ws = Workspace(name=req.name, description=req.description, domain=req.domain)
db.add(ws)
await db.flush()
return {"id": str(ws.id), "name": ws.name, "domain": ws.domain}
@router.post("/files/upload")
async def upload_file(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
workspace_id: str = Form(default=""),
db: AsyncSession = Depends(get_db),
):
content = await file.read()
file_id = str(uuid.uuid4())
suffix = Path(file.filename or "doc").suffix
save_path = UPLOAD_DIR / f"{file_id}{suffix}"
save_path.write_bytes(content)
file_record = FileRecord(
id=uuid.UUID(file_id),
filename=f"{file_id}{suffix}",
original_name=file.filename or "unknown",
file_type=suffix.lstrip("."),
file_size=len(content),
storage_path=str(save_path),
workspace_id=uuid.UUID(workspace_id) if workspace_id else None,
status="uploaded",
)
db.add(file_record)
task = Task(
task_type="parse_and_vectorize",
status="pending",
file_id=uuid.UUID(file_id),
payload={"workspace_id": workspace_id},
)
db.add(task)
await db.flush()
# 异步触发 Celery 任务
celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
task.celery_task_id = celery_task.id
await db.flush()
return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
task = result.scalar_one_or_none()
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
return {
"task_id": str(task.id),
"status": task.status,
"progress": task.progress,
"file_id": str(task.file_id) if task.file_id else None,
"error_msg": task.error_msg,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
}
@router.post("/qa")
async def qa(req: QARequest):
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k * 2)
ranked = await rerank(req.query, chunks, top_k=req.top_k)
result = await generate_answer(req.query, ranked)
if not req.return_sources:
result.pop("sources", None)
return result
@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k)
return {"chunks": chunks, "total": len(chunks)}