"""文档管理 API""" from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks import os import uuid from datetime import datetime from typing import Optional from app.schemas.doc import ( DocumentUploadResponse, DocumentListResponse, DocumentInfo, ParseResponse, EmbedResponse, TaskStatusResponse, ) from app.core.config import settings from app.services.minio import minio_service from app.services.database import db_service, init_db, DocStatus from app.services.tasks import generate_task_id, task_manager, get_task_status from app.workflows.document_workflow import ( generate_doc_id, run_parse_workflow, run_embedding_workflow, ) from app.utils.logger import logger router = APIRouter(prefix="/docs", tags=["文档管理"]) # 启动时初始化数据库 init_db() def get_content_type(filename: str) -> str: """根据文件扩展名获取 Content-Type""" ext = os.path.splitext(filename)[1].lower() content_types = { ".pdf": "application/pdf", ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".doc": "application/msword", ".txt": "text/plain", } return content_types.get(ext, "application/octet-stream") @router.post("/upload", response_model=DocumentUploadResponse) async def upload_document( file: UploadFile = File(...), background_tasks: BackgroundTasks = None, ): """ 上传法规文档到 MinIO,并自动触发异步解析 流程: 1. 验证文件格式 2. 生成文档ID 3. 上传到 MinIO 4. 创建数据库记录 5. 触发异步解析任务(后续可替换为 RabbitMQ) """ # 检查文件格式 allowed_ext = [".pdf", ".docx", ".doc", ".txt"] ext = os.path.splitext(file.filename)[1].lower() if ext not in allowed_ext: raise HTTPException(400, f"Unsupported file format: {ext}") # 检查文件大小 content = await file.read() max_size = 50 * 1024 * 1024 # 50MB if len(content) > max_size: raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB") # 生成文档ID doc_id = generate_doc_id() # 构建 MinIO 存储路径 storage_filename = f"{doc_id}_{file.filename}" minio_path = f"documents/{storage_filename}" try: # 上传到 MinIO content_type = get_content_type(file.filename) minio_url = minio_service.upload_file( minio_path, content, content_type, ) # 创建数据库记录 doc = db_service.create_document( doc_id=doc_id, filename=storage_filename, original_name=file.filename, minio_path=minio_path, size=len(content), ) logger.info(f"Document uploaded: {doc_id} - {file.filename}") # 触发异步解析任务 parse_task_id = generate_task_id() db_service.create_parse_task(parse_task_id, doc_id) # 使用 asyncio 异步执行解析(后续替换为 RabbitMQ) background_tasks.add_task( run_parse_workflow_sync, parse_task_id, doc_id, ) return DocumentUploadResponse( doc_id=doc_id, filename=file.filename, size=len(content), status="uploaded", parse_task_id=parse_task_id, ) except Exception as e: logger.error(f"Upload failed: {e}") raise HTTPException(500, f"Upload failed: {str(e)}") def run_parse_workflow_sync(task_id: str, doc_id: str): """同步包装器,用于 BackgroundTasks""" import asyncio asyncio.run(run_parse_workflow(task_id, doc_id)) @router.get("/list", response_model=DocumentListResponse) async def list_documents(): """获取已索引文档列表""" docs = db_service.list_documents() return DocumentListResponse( docs=[ DocumentInfo( id=d.id, name=d.original_name, chunks=d.chunks, status=d.status, created_at=d.created_at, ) for d in docs ] ) @router.get("/{doc_id}", response_model=DocumentInfo) async def get_document(doc_id: str): """获取单个文档信息""" doc = db_service.get_document(doc_id) if not doc: raise HTTPException(404, "Document not found") return DocumentInfo( id=doc.id, name=doc.original_name, chunks=doc.chunks, status=doc.status, created_at=doc.created_at, ) @router.post("/parse/{doc_id}", response_model=ParseResponse) async def parse_document( doc_id: str, background_tasks: BackgroundTasks = None, ): """ 手动触发文档解析(如果文档已上传但未解析) """ doc = db_service.get_document(doc_id) if not doc: raise HTTPException(404, "Document not found") if doc.status not in [DocStatus.uploaded.value, DocStatus.failed.value]: raise HTTPException(400, f"Document status is {doc.status}, cannot parse") # 创建解析任务 task_id = generate_task_id() db_service.create_parse_task(task_id, doc_id) # 异步执行 background_tasks.add_task( run_parse_workflow_sync, task_id, doc_id, ) return ParseResponse( doc_id=doc_id, task_id=task_id, status="parsing", ) @router.post("/embed/{doc_id}", response_model=EmbedResponse) async def embed_document( doc_id: str, background_tasks: BackgroundTasks = None, ): """ 触发文档向量化(需要文档已解析) """ doc = db_service.get_document(doc_id) if not doc: raise HTTPException(404, "Document not found") if doc.status != DocStatus.parsed.value: raise HTTPException(400, f"Document must be parsed first. Current status: {doc.status}") # 创建向量化任务 task_id = generate_task_id() db_service.create_parse_task(task_id, doc_id) # 异步执行 background_tasks.add_task( run_embedding_workflow_sync, task_id, doc_id, ) return EmbedResponse( doc_id=doc_id, task_id=task_id, status="embedding", ) def run_embedding_workflow_sync(task_id: str, doc_id: str): """同步包装器,用于 BackgroundTasks""" import asyncio asyncio.run(run_embedding_workflow(task_id, doc_id)) @router.get("/task/{task_id}", response_model=TaskStatusResponse) async def get_task_status_api(task_id: str): """获取任务状态""" status = get_task_status(task_id) if not status: # 检查数据库中的任务记录 task = db_service.get_parse_task(task_id) if task: return TaskStatusResponse( task_id=task_id, status=task.status, progress=task.progress or 0, message=task.message, ) raise HTTPException(404, "Task not found") return TaskStatusResponse( task_id=task_id, status=status.get("status", "unknown"), progress=status.get("progress", 0), message=status.get("message"), result=status.get("result"), ) @router.delete("/delete/{doc_id}") async def delete_document(doc_id: str): """ 删除文档 同时删除: - MinIO 中的文件 - 数据库中的记录 - 解析后的文本文件 """ doc = db_service.get_document(doc_id) if not doc: raise HTTPException(404, "Document not found") try: # 删除 MinIO 文件 minio_service.delete_file(doc.minio_path) # 删除本地解析文件 parsed_path = f"{settings.data_parsed_dir}/{doc_id}.txt" if os.path.exists(parsed_path): os.remove(parsed_path) # 删除本地临时文件 temp_path = f"{settings.data_raw_dir}/{doc.filename}" if os.path.exists(temp_path): os.remove(temp_path) # 删除数据库记录 db_service.delete_document(doc_id) logger.info(f"Document deleted: {doc_id}") return {"success": True, "doc_id": doc_id} except Exception as e: logger.error(f"Delete failed: {e}") raise HTTPException(500, f"Delete failed: {str(e)}")