"""Document management API."""
|
||
|
||
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
|
||
import os
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
from app.schemas.doc import (
|
||
DocumentUploadResponse,
|
||
DocumentListResponse,
|
||
DocumentInfo,
|
||
ParseResponse,
|
||
EmbedResponse,
|
||
TaskStatusResponse,
|
||
)
|
||
from app.core.config import settings
|
||
from app.services.minio import minio_service
|
||
from app.services.database import db_service, init_db, DocStatus
|
||
from app.services.tasks import generate_task_id, task_manager, get_task_status
|
||
from app.workflows.document_workflow import (
|
||
generate_doc_id,
|
||
run_parse_workflow,
|
||
run_embedding_workflow,
|
||
)
|
||
from app.utils.logger import logger
|
||
|
||
|
||
# Router that groups every document-management endpoint under the /docs prefix.
router = APIRouter(
    prefix="/docs",
    tags=["文档管理"],
)

# Run the database initialisation as soon as this module is imported.
init_db()
|
||
|
||
|
||
def get_content_type(filename: Optional[str]) -> str:
    """Return the MIME type to use for *filename*, chosen by its extension.

    The extension is compared case-insensitively. Only ``.pdf``, ``.docx``,
    ``.doc`` and ``.txt`` are mapped; anything else falls back to
    ``"application/octet-stream"``.

    Args:
        filename: File name or path. ``None`` and the empty string are
            accepted and treated as "unknown type".

    Returns:
        The MIME type string for the extension, or
        ``"application/octet-stream"`` when the extension is not mapped.
    """
    fallback = "application/octet-stream"
    if not filename:
        # os.path.splitext(None) raises TypeError, and an upload may arrive
        # without a file name, so answer with the generic type instead.
        return fallback

    content_types = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".txt": "text/plain",
    }
    ext = os.path.splitext(filename)[1].lower()
    return content_types.get(ext, fallback)
|
||
|
||
|
||
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """Upload a regulation document to MinIO and schedule asynchronous parsing.

    Steps:
        1. Validate the file name and extension.
        2. Enforce the size limit.
        3. Generate a document ID.
        4. Upload the content to MinIO.
        5. Create the database record.
        6. Queue the parse workflow as a background task (may later be
           replaced by RabbitMQ).

    Args:
        file: The uploaded file from the multipart form.
        background_tasks: Background-task queue that runs the parse workflow;
            FastAPI supplies the instance for this annotation on each request.

    Returns:
        DocumentUploadResponse carrying the document ID, file name, size in
        bytes, the status ``"uploaded"`` and the ID of the queued parse task.

    Raises:
        HTTPException: 400 when the file name is missing, the extension is
            not supported, or the file exceeds the size limit; 500 when the
            upload, the database write or the task scheduling fails.
    """
    # The client-supplied name is untrusted and ends up in the MinIO key and
    # in the stored file name (which is later joined into local paths), so
    # keep only its final path component. Backslashes are normalised first so
    # Windows-style paths are stripped on POSIX hosts as well.
    raw_name = file.filename or ""
    safe_name = os.path.basename(raw_name.replace("\\", "/"))
    if not safe_name:
        raise HTTPException(400, "Missing file name")

    # Check the file format.
    allowed_ext = (".pdf", ".docx", ".doc", ".txt")
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed_ext:
        raise HTTPException(400, f"Unsupported file format: {ext}")

    # Check the file size. Reading at most one byte past the limit is enough
    # to detect an oversized upload without loading all of it into memory.
    max_size = 50 * 1024 * 1024  # 50MB
    content = await file.read(max_size + 1)
    if len(content) > max_size:
        raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB")

    # Generate the document ID.
    doc_id = generate_doc_id()

    # Build the MinIO storage path.
    storage_filename = f"{doc_id}_{safe_name}"
    minio_path = f"documents/{storage_filename}"

    stored_in_minio = False
    record_created = False
    try:
        # Upload to MinIO.
        content_type = get_content_type(safe_name)
        minio_service.upload_file(
            minio_path,
            content,
            content_type,
        )
        stored_in_minio = True

        # Create the database record.
        db_service.create_document(
            doc_id=doc_id,
            filename=storage_filename,
            original_name=safe_name,
            minio_path=minio_path,
            size=len(content),
        )
        record_created = True

        logger.info(f"Document uploaded: {doc_id} - {safe_name}")

        # Register and queue the asynchronous parse task.
        parse_task_id = generate_task_id()
        db_service.create_parse_task(parse_task_id, doc_id)
        background_tasks.add_task(
            run_parse_workflow_sync,
            parse_task_id,
            doc_id,
        )

        return DocumentUploadResponse(
            doc_id=doc_id,
            filename=safe_name,
            size=len(content),
            status="uploaded",
            parse_task_id=parse_task_id,
        )

    except Exception as e:
        logger.error(f"Upload failed: {e}")
        if stored_in_minio and not record_created:
            # No database row references the object, so nothing could ever
            # delete it later; remove it now on a best-effort basis.
            try:
                minio_service.delete_file(minio_path)
            except Exception as cleanup_error:
                logger.error(f"Failed to remove orphaned upload {minio_path}: {cleanup_error}")
        raise HTTPException(500, f"Upload failed: {str(e)}") from e
|
||
|
||
|
||
def run_parse_workflow_sync(task_id: str, doc_id: str):
    """Blocking adapter so BackgroundTasks can run the async parse workflow.

    Builds the ``run_parse_workflow(task_id, doc_id)`` coroutine and drives it
    to completion with ``asyncio.run``.
    """
    from asyncio import run as run_to_completion

    parse_job = run_parse_workflow(task_id, doc_id)
    run_to_completion(parse_job)
|
||
|
||
|
||
@router.get("/list", response_model=DocumentListResponse)
async def list_documents():
    """Return the list of indexed documents.

    Every record returned by ``db_service.list_documents()`` is converted to a
    ``DocumentInfo`` and the results are wrapped in a ``DocumentListResponse``.
    """
    entries = []
    for record in db_service.list_documents():
        entries.append(
            DocumentInfo(
                id=record.id,
                name=record.original_name,
                chunks=record.chunks,
                status=record.status,
                created_at=record.created_at,
            )
        )
    return DocumentListResponse(docs=entries)
|
||
|
||
|
||
@router.get("/{doc_id}", response_model=DocumentInfo)
async def get_document(doc_id: str):
    """Return the information stored for a single document.

    Raises:
        HTTPException: 404 when no document with ``doc_id`` is found.
    """
    record = db_service.get_document(doc_id)
    if record:
        return DocumentInfo(
            id=record.id,
            name=record.original_name,
            chunks=record.chunks,
            status=record.status,
            created_at=record.created_at,
        )

    raise HTTPException(404, "Document not found")
|
||
|
||
|
||
@router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Manually trigger parsing of a document that is uploaded but not parsed.

    Only documents whose status is ``uploaded`` or ``failed`` are accepted.

    Returns:
        ParseResponse with the document ID, the new task ID and the status
        ``"parsing"``.

    Raises:
        HTTPException: 404 when the document does not exist; 400 when its
            current status does not allow parsing.
    """
    record = db_service.get_document(doc_id)
    if not record:
        raise HTTPException(404, "Document not found")

    parseable_states = [DocStatus.uploaded.value, DocStatus.failed.value]
    if record.status not in parseable_states:
        raise HTTPException(400, f"Document status is {record.status}, cannot parse")

    # Record the parse task first, then hand the workflow to the background queue.
    new_task_id = generate_task_id()
    db_service.create_parse_task(new_task_id, doc_id)
    background_tasks.add_task(run_parse_workflow_sync, new_task_id, doc_id)

    return ParseResponse(doc_id=doc_id, task_id=new_task_id, status="parsing")
|
||
|
||
|
||
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Trigger vectorisation of a document; the document must already be parsed.

    Returns:
        EmbedResponse with the document ID, the new task ID and the status
        ``"embedding"``.

    Raises:
        HTTPException: 404 when the document does not exist; 400 when its
            status is not ``parsed``.
    """
    record = db_service.get_document(doc_id)
    if not record:
        raise HTTPException(404, "Document not found")

    required_status = DocStatus.parsed.value
    if record.status != required_status:
        raise HTTPException(400, f"Document must be parsed first. Current status: {record.status}")

    # Record the embedding task.
    # NOTE(review): the task is stored through create_parse_task, the same call
    # used for parse jobs -- confirm that sharing it for embedding is intended.
    new_task_id = generate_task_id()
    db_service.create_parse_task(new_task_id, doc_id)

    # Hand the workflow to the background queue.
    background_tasks.add_task(run_embedding_workflow_sync, new_task_id, doc_id)

    return EmbedResponse(doc_id=doc_id, task_id=new_task_id, status="embedding")
|
||
|
||
|
||
def run_embedding_workflow_sync(task_id: str, doc_id: str):
    """Blocking adapter so BackgroundTasks can run the async embedding workflow.

    Builds the ``run_embedding_workflow(task_id, doc_id)`` coroutine and drives
    it to completion with ``asyncio.run``.
    """
    from asyncio import run as run_to_completion

    embedding_job = run_embedding_workflow(task_id, doc_id)
    run_to_completion(embedding_job)
|
||
|
||
|
||
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_api(task_id: str):
    """Return the status of a task.

    The result of ``get_task_status`` is used when it is truthy; otherwise the
    task record stored in the database is consulted.

    Raises:
        HTTPException: 404 when neither source knows the task.
    """
    live_state = get_task_status(task_id)
    if live_state:
        return TaskStatusResponse(
            task_id=task_id,
            status=live_state.get("status", "unknown"),
            progress=live_state.get("progress", 0),
            message=live_state.get("message"),
            result=live_state.get("result"),
        )

    # Fall back to the task record kept in the database.
    stored_task = db_service.get_parse_task(task_id)
    if not stored_task:
        raise HTTPException(404, "Task not found")

    return TaskStatusResponse(
        task_id=task_id,
        status=stored_task.status,
        progress=stored_task.progress or 0,
        message=stored_task.message,
    )
|
||
|
||
|
||
@router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
    """Delete a document together with the artefacts derived from it.

    Removed, in this order:
        - the file in MinIO,
        - the locally stored parsed text file,
        - the local temporary (raw) file,
        - the database record.

    Args:
        doc_id: ID of the document to delete.

    Returns:
        ``{"success": True, "doc_id": doc_id}`` once everything is removed.

    Raises:
        HTTPException: 404 when the document does not exist; 500 when any
            deletion step fails.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    try:
        # Delete the MinIO file.
        minio_service.delete_file(doc.minio_path)

        # Delete the local parsed file, then the local temporary file.
        local_files = (
            (settings.data_parsed_dir, f"{doc_id}.txt"),
            (settings.data_raw_dir, doc.filename),
        )
        for base_dir, name in local_files:
            root = os.path.join(os.path.abspath(str(base_dir)), "")
            target = os.path.abspath(f"{base_dir}/{name}")
            if not target.startswith(root):
                # The stored file name derives from a client-supplied upload
                # name; never follow it outside the data directory.
                logger.error(f"Refusing to delete path outside {base_dir}: {target}")
                continue
            try:
                # Remove directly instead of checking existence first: a
                # concurrent delete between check and removal would otherwise
                # surface as a 500.
                os.remove(target)
            except FileNotFoundError:
                pass  # Nothing to clean up for this document.

        # Delete the database record.
        db_service.delete_document(doc_id)

        logger.info(f"Document deleted: {doc_id}")

        return {"success": True, "doc_id": doc_id}

    except Exception as e:
        logger.error(f"Delete failed: {e}")
        raise HTTPException(500, f"Delete failed: {str(e)}") from e