273 lines
10 KiB
Python
273 lines
10 KiB
Python
"""Define API routes for documents."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from io import BytesIO
|
||
from urllib.parse import quote
|
||
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
|
||
from fastapi.responses import StreamingResponse
|
||
from loguru import logger
|
||
|
||
from app.api.dependencies.auth import get_current_user
|
||
from app.api.models import DocumentUploadResponse
|
||
from app.application.documents import DocumentProcessResult
|
||
from app.config.settings import settings
|
||
from app.domain.auth.models import UserClaims
|
||
from app.shared.bootstrap import get_document_command_service, get_document_query_service
|
||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||
|
||
|
||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||
|
||
|
||
def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
    """Convert a service-layer processing result into the API response model.

    Copies the identifier, name, status, message, chunk count and summary
    fields of ``result`` onto a new ``DocumentUploadResponse``.
    """
    # Fields are read in the same order the response model is populated.
    copied_fields = (
        "doc_id",
        "doc_name",
        "status",
        "message",
        "num_chunks",
        "summary",
        "summary_latency_ms",
    )
    payload = {field: getattr(result, field) for field in copied_fields}
    return DocumentUploadResponse(**payload)
|
||
|
||
|
||
def _run_process_in_background(
    *,
    doc_id: str,
    file_name: str,
    final_doc_name: str,
    content: bytes,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None,
) -> None:
    """Process an already-stored document from a FastAPI BackgroundTask.

    FastAPI runs non-async background tasks in a threadpool executor, so the
    blocking work performed by the command service (parser API calls,
    embedding, Milvus upsert) is safe here.

    All arguments are keyword-only and are forwarded unchanged to the command
    service's ``_process_document``. Any exception is logged together with
    ``doc_id`` and then swallowed: nothing is re-raised to the caller.
    """
    job = {
        "doc_id": doc_id,
        "file_name": file_name,
        "final_doc_name": final_doc_name,
        "content": content,
        "regulation_type": regulation_type,
        "version": version,
        "generate_summary": generate_summary,
        "run_id": run_id,
    }
    try:
        # Resolve the service inside the guard so lookup failures are logged too.
        get_document_command_service()._process_document(**job)
    except Exception:
        logger.exception("BackgroundTask document processing failed: doc_id={}", doc_id)
|
||
|
||
|
||
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="上传的文档文件"),
    doc_id: str | None = Form(None, description="客户端预分配的文档ID,不传则自动生成"),
    doc_name: str | None = Form(None, description="文档名称"),
    regulation_type: str | None = Form(None, description="法规类型"),
    version: str | None = Form(None, description="文档版本"),
    generate_summary: bool = Form(False, description="是否生成摘要"),
    sync: bool = Form(False, description="同步处理(演示/测试用,默认异步处理)"),
    current_user: UserClaims = Depends(get_current_user),
):
    """Upload a document and process it asynchronously.

    Default path (sync=false):
      1. Store binary to MinIO immediately — returns within seconds.
      2. Schedule parse→embed→index as a FastAPI BackgroundTask (same process,
         threadpool) OR enqueue to Celery workers when USE_CELERY_WORKER=true.
      3. Poll GET /documents/status/{doc_id} for progress.

    sync=true path: full inline processing, blocks until complete (demo / CI use).

    Args:
        background_tasks: FastAPI task queue used when Celery is disabled.
        file: The uploaded file; must have a filename and non-empty content.
        doc_id: Optional client-assigned document id, forwarded to the service.
        doc_name: Optional display name; the file name is used when omitted.
        regulation_type: Optional regulation type; ``None`` is sent as ``""``.
        version: Optional document version; ``None`` is sent as ``""``.
        generate_summary: Whether a summary should be generated.
        sync: Process inline instead of scheduling background work.
        current_user: Authenticated user claims. Not read in the body; the
            dependency is declared so the route requires authentication.

    Returns:
        A ``DocumentUploadResponse`` built from the processing result (sync) or
        from a ``"stored"`` placeholder result (async).

    Raises:
        HTTPException: 400 when the filename or content is empty; 500 when
            the service reports ``status == "failed"`` or raises an error.
    """
    # Reject a nameless upload before reading the whole body into memory.
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="上传文件为空")

    try:
        svc = get_document_command_service()

        if sync:
            # Synchronous fallback: full inline processing.
            # NOTE(review): this call appears to block while running inside an
            # async handler — confirm that is acceptable for demo/CI usage.
            result = svc.upload_and_process(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
        else:
            # Step 1: store binary and create the document record (fast, sync).
            stored_doc_id, run_id = svc.store_document(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
            final_doc_name = doc_name or file.filename

            # Step 2: schedule processing via Celery worker OR FastAPI BackgroundTask.
            if settings.use_celery_worker:
                # Imported lazily so the task module is only loaded when Celery is enabled.
                from app.infrastructure.tasks.document_tasks import process_document_task

                # The raw bytes are not passed to the task; presumably the
                # worker reloads them from storage — verify against the task.
                process_document_task.delay(
                    doc_id=stored_doc_id,
                    file_name=file.filename,
                    doc_name=final_doc_name,
                    regulation_type=regulation_type or "",
                    version=version or "",
                    generate_summary=generate_summary,
                    run_id=run_id,
                )
                processing_note = "已入 Celery 队列,由 Worker 处理。"
            else:
                # Default: run in FastAPI's threadpool — no external worker needed.
                background_tasks.add_task(
                    _run_process_in_background,
                    doc_id=stored_doc_id,
                    file_name=file.filename,
                    final_doc_name=final_doc_name,
                    content=content,
                    regulation_type=regulation_type or "",
                    version=version or "",
                    generate_summary=generate_summary,
                    run_id=run_id,
                )
                processing_note = "正在后台处理。"

            # The doubled braces emit the literal text "{doc_id}" as a URL template.
            result = DocumentProcessResult(
                doc_id=stored_doc_id,
                doc_name=final_doc_name,
                status="stored",
                message=f"文件已存储,{processing_note}请轮询 GET /documents/status/{{doc_id}} 查看进度。",
            )

        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)

    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as exc:
        logger.exception("文档上传失败")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||
|
||
|
||
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
async def get_document_status(doc_id: str):
    """Return the current status record of one document.

    Looks the document up through the query service and maps it onto a
    ``DocumentUploadResponse``. The response message is the stored error
    message when one is set, otherwise a fixed success text.

    Raises:
        HTTPException: 404 when the query service returns nothing for ``doc_id``.

    NOTE(review): unlike the upload/list/delete routes, this route declares no
    ``get_current_user`` dependency — confirm that is intended.
    """
    record = get_document_query_service().get(doc_id)
    if record:
        return DocumentUploadResponse(
            doc_id=record.doc_id,
            doc_name=record.doc_name,
            status=record.status.value,
            message=record.error_message or "查询成功",
            num_chunks=record.chunk_count,
            summary=record.summary,
            summary_latency_ms=record.summary_latency_ms,
            regulation_type=record.regulation_type,
            version=record.version,
        )
    raise HTTPException(status_code=404, detail="文档不存在")
|
||
|
||
|
||
@router.get("/download/{doc_id}")
async def download_document(doc_id: str):
    """Stream the stored binary of a document back as an attachment.

    The query service returns the document record together with its bytes;
    the bytes are wrapped in a ``BytesIO`` and sent as a ``StreamingResponse``.
    The file name is sent in the RFC 5987 ``filename*=UTF-8''…`` form so that
    non-ASCII names survive the header.

    Raises:
        HTTPException: 404 when the service raises ``FileNotFoundError``;
            500 for any other error.

    NOTE(review): unlike the upload/list/delete routes, this route declares no
    ``get_current_user`` dependency — confirm that is intended.
    """
    try:
        document, file_data = get_document_query_service().download(doc_id)
        # safe="" also escapes "/", which is not a valid RFC 5987 attr-char.
        encoded_name = quote(document.file_name, safe="")
        return StreamingResponse(
            BytesIO(file_data),
            media_type=document.content_type or "application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("文档下载失败")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||
|
||
|
||
@router.get("/list")
async def list_documents(current_user: UserClaims = Depends(get_current_user)):
    """Return a compact listing of all documents.

    Each entry carries the document id, name, status value, chunk count and
    ISO-formatted update time; ``total`` is the number of documents returned
    by the query service. ``current_user`` is not read in the body; the
    dependency is declared so the route requires authentication.
    """
    docs = get_document_query_service().list_documents()
    rows = []
    for doc in docs:
        rows.append(
            {
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "status": doc.status.value,
                "chunk_count": doc.chunk_count,
                "updated_at": doc.updated_at.isoformat(),
            }
        )
    return {"documents": rows, "total": len(docs)}
|
||
|
||
|
||
@router.get("/management-list")
async def get_document_management_list():
    """Return the detailed document listing used for management views.

    Compared with ``/list``, every entry additionally includes the size in
    bytes, the summary, the regulation type and the version. ``total`` is the
    number of documents returned by the query service.

    NOTE(review): unlike ``/list``, this route declares no
    ``get_current_user`` dependency — confirm that is intended.
    """
    docs = get_document_query_service().list_documents()

    def _as_row(doc):
        # Flatten one document record into its JSON-serialisable row.
        return {
            "doc_id": doc.doc_id,
            "doc_name": doc.doc_name,
            "status": doc.status.value,
            "chunk_count": doc.chunk_count,
            "size_bytes": doc.size_bytes,
            "summary": doc.summary,
            "updated_at": doc.updated_at.isoformat(),
            "regulation_type": doc.regulation_type,
            "version": doc.version,
        }

    return {
        "documents": [_as_row(doc) for doc in docs],
        "total": len(docs),
    }
|
||
|
||
|
||
@router.delete("/{doc_id}")
async def delete_document(doc_id: str, current_user: UserClaims = Depends(get_current_user)):
    """Delete a document through the command service.

    Returns a small confirmation payload when the service reports a truthy
    result. ``current_user`` is not read in the body; the dependency is
    declared so the route requires authentication.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    if get_document_command_service().delete(doc_id):
        return {"doc_id": doc_id, "deleted": True}
    raise HTTPException(status_code=404, detail="文档不存在")
|
||
|
||
|
||
@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse)
async def retry_document(doc_id: str):
    """Re-process a failed document.

    Delegates to the command service's ``retry`` and converts its result into
    the API response model.

    Raises:
        HTTPException: 500 when the retry result has ``status == "failed"``
            or when the service raises any other error.

    NOTE(review): unlike the upload/list/delete routes, this route declares no
    ``get_current_user`` dependency — confirm that is intended.
    """
    try:
        result = get_document_command_service().retry(doc_id)
        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)
    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as exc:
        logger.exception("文档重试失败")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|