Files

299 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""文档管理 API"""
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
import os
import uuid
from datetime import datetime
from typing import Optional
from app.schemas.doc import (
DocumentUploadResponse,
DocumentListResponse,
DocumentInfo,
ParseResponse,
EmbedResponse,
TaskStatusResponse,
)
from app.core.config import settings
from app.services.minio import minio_service
from app.services.database import db_service, init_db, DocStatus
from app.services.tasks import generate_task_id, task_manager, get_task_status
from app.workflows.document_workflow import (
generate_doc_id,
run_parse_workflow,
run_embedding_workflow,
)
from app.utils.logger import logger
# Router for all document-management endpoints; every route below is mounted
# under the "/docs" prefix.
router = APIRouter(prefix="/docs", tags=["文档管理"])
# Initialise the database when this module is imported (i.e. at app startup),
# before any route can be served.
init_db()
def get_content_type(filename: str) -> str:
    """Return the MIME type to store a file under, based on its extension.

    The extension match is case-insensitive. Unknown extensions, and names
    without one, fall back to ``application/octet-stream``.
    """
    _stem, suffix = os.path.splitext(filename)
    mime_by_suffix = {
        ".txt": "text/plain",
        ".doc": "application/msword",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".pdf": "application/pdf",
    }
    fallback = "application/octet-stream"
    return mime_by_suffix.get(suffix.lower(), fallback)
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """Upload a regulation document to MinIO and schedule its parsing.

    Steps:
    1. Validate the file extension and size.
    2. Generate a document ID.
    3. Upload the raw bytes to MinIO.
    4. Create the database record.
    5. Schedule the parse task as a background task (may later be replaced
       by RabbitMQ).

    Raises:
        HTTPException(400): missing/unsupported file extension, or the file
            exceeds the size limit.
        HTTPException(500): storage, database, or task-scheduling failure.
    """
    # The client-supplied name is untrusted. Keep only its final path
    # component so it cannot add extra segments to the MinIO object key or to
    # local paths later derived from the stored ``filename``. A missing name
    # yields an empty extension and is rejected below as unsupported (instead
    # of crashing with TypeError inside os.path.splitext).
    client_name = file.filename or ""
    safe_name = os.path.basename(client_name.replace("\\", "/"))

    # Check the file format.
    allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed_ext:
        raise HTTPException(400, f"Unsupported file format: {ext}")

    # Check the file size. Reading one byte past the limit is enough to detect
    # an oversized upload without loading all of it into memory.
    max_size = 50 * 1024 * 1024  # 50MB
    content = await file.read(max_size + 1)
    if len(content) > max_size:
        raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB")

    doc_id = generate_doc_id()
    storage_filename = f"{doc_id}_{safe_name}"
    minio_path = f"documents/{storage_filename}"

    try:
        minio_service.upload_file(
            minio_path,
            content,
            get_content_type(safe_name),
        )

        try:
            db_service.create_document(
                doc_id=doc_id,
                filename=storage_filename,
                original_name=file.filename,
                minio_path=minio_path,
                size=len(content),
            )
        except Exception:
            # Without a DB row nothing references the stored object; remove it
            # (best effort) so failed uploads do not leave orphans in MinIO.
            try:
                minio_service.delete_file(minio_path)
            except Exception as cleanup_err:
                logger.error(f"Failed to remove orphaned object {minio_path}: {cleanup_err}")
            raise

        logger.info(f"Document uploaded: {doc_id} - {file.filename}")

        # Register the parse task and run it after the response is sent
        # (FastAPI BackgroundTasks; may later be replaced by RabbitMQ).
        parse_task_id = generate_task_id()
        db_service.create_parse_task(parse_task_id, doc_id)
        background_tasks.add_task(
            run_parse_workflow_sync,
            parse_task_id,
            doc_id,
        )

        return DocumentUploadResponse(
            doc_id=doc_id,
            filename=file.filename,
            size=len(content),
            status="uploaded",
            parse_task_id=parse_task_id,
        )
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(500, f"Upload failed: {str(e)}") from e
def run_parse_workflow_sync(task_id: str, doc_id: str):
    """Drive the async parse workflow to completion from synchronous code.

    Used as a BackgroundTasks callable: a fresh event loop is created with
    ``asyncio.run`` for each invocation, and the workflow's return value is
    discarded.
    """
    import asyncio as _asyncio

    workflow = run_parse_workflow(task_id, doc_id)
    _asyncio.run(workflow)
@router.get("/list", response_model=DocumentListResponse)
async def list_documents():
    """Return the documents reported by ``db_service.list_documents()``.

    Each record is exposed with its ID, original upload name, chunk count,
    status, and creation time.
    """
    infos = []
    for record in db_service.list_documents():
        infos.append(
            DocumentInfo(
                id=record.id,
                name=record.original_name,
                chunks=record.chunks,
                status=record.status,
                created_at=record.created_at,
            )
        )
    return DocumentListResponse(docs=infos)
@router.get("/{doc_id}", response_model=DocumentInfo)
async def get_document(doc_id: str):
    """Return the metadata of a single document.

    Raises:
        HTTPException(404): no document exists with ``doc_id``.
    """
    record = db_service.get_document(doc_id)
    if record:
        return DocumentInfo(
            id=record.id,
            name=record.original_name,
            chunks=record.chunks,
            status=record.status,
            created_at=record.created_at,
        )
    raise HTTPException(404, "Document not found")
@router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Manually start parsing for a document that has not been parsed yet.

    Only documents in the ``uploaded`` or ``failed`` status are accepted; a
    new parse task is registered and run as a background task.

    Raises:
        HTTPException(404): no document exists with ``doc_id``.
        HTTPException(400): the document is in any other status.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    parsable_statuses = (DocStatus.uploaded.value, DocStatus.failed.value)
    if doc.status not in parsable_statuses:
        raise HTTPException(400, f"Document status is {doc.status}, cannot parse")

    new_task_id = generate_task_id()
    db_service.create_parse_task(new_task_id, doc_id)
    background_tasks.add_task(run_parse_workflow_sync, new_task_id, doc_id)

    return ParseResponse(doc_id=doc_id, task_id=new_task_id, status="parsing")
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Start vectorisation (embedding) of a document that is already parsed.

    Raises:
        HTTPException(404): no document exists with ``doc_id``.
        HTTPException(400): the document's status is not ``parsed``.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    is_parsed = doc.status == DocStatus.parsed.value
    if not is_parsed:
        raise HTTPException(400, f"Document must be parsed first. Current status: {doc.status}")

    # NOTE(review): the embedding task is recorded through create_parse_task,
    # the same store get_task_status_api falls back to via get_parse_task.
    # Confirm there is no dedicated embedding-task record intended here.
    new_task_id = generate_task_id()
    db_service.create_parse_task(new_task_id, doc_id)
    background_tasks.add_task(run_embedding_workflow_sync, new_task_id, doc_id)

    return EmbedResponse(doc_id=doc_id, task_id=new_task_id, status="embedding")
def run_embedding_workflow_sync(task_id: str, doc_id: str):
    """Drive the async embedding workflow to completion from synchronous code.

    Used as a BackgroundTasks callable: a fresh event loop is created with
    ``asyncio.run`` for each invocation, and the workflow's return value is
    discarded.
    """
    import asyncio as _asyncio

    workflow = run_embedding_workflow(task_id, doc_id)
    _asyncio.run(workflow)
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_api(task_id: str):
    """Report the status of a parse or embedding task.

    The status from ``get_task_status`` is preferred. If it returns nothing,
    the persisted task record (``db_service.get_parse_task``) is used
    instead; that fallback does not carry a ``result``.

    Raises:
        HTTPException(404): the task is unknown to both sources.
    """
    live = get_task_status(task_id)
    if live:
        return TaskStatusResponse(
            task_id=task_id,
            status=live.get("status", "unknown"),
            progress=live.get("progress", 0),
            message=live.get("message"),
            result=live.get("result"),
        )

    # Fall back to the task record stored in the database.
    record = db_service.get_parse_task(task_id)
    if not record:
        raise HTTPException(404, "Task not found")
    return TaskStatusResponse(
        task_id=task_id,
        status=record.status,
        progress=record.progress or 0,
        message=record.message,
    )
def _remove_local_file(base_dir: str, name: str) -> None:
    """Delete ``name`` under ``base_dir`` if it exists, never leaving ``base_dir``.

    ``name`` may embed the client-supplied upload file name (the stored
    ``doc.filename`` is ``f"{doc_id}_{file.filename}"``), so the joined path
    is normalised and checked to still lie strictly inside ``base_dir`` before
    anything is removed. A missing file is not an error.
    """
    base = os.path.abspath(base_dir)
    target = os.path.abspath(os.path.join(base, name))
    try:
        inside = target != base and os.path.commonpath([base, target]) == base
    except ValueError:  # e.g. paths on different drives (Windows)
        inside = False
    if not inside:
        logger.error(f"Refusing to delete path outside {base_dir}: {name}")
        return
    try:
        os.remove(target)
    except FileNotFoundError:
        pass


@router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
    """Delete a document and its stored artifacts.

    Removes, in order:
    - the file stored in MinIO,
    - the local parsed text file ``{data_parsed_dir}/{doc_id}.txt``,
    - the local raw/temporary file ``{data_raw_dir}/{doc.filename}``,
    - the database record.

    Raises:
        HTTPException(404): no document exists with ``doc_id``.
        HTTPException(500): any deletion step failed (the DB record is removed
            last, so a failed delete can be retried).
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")
    try:
        minio_service.delete_file(doc.minio_path)
        _remove_local_file(settings.data_parsed_dir, f"{doc_id}.txt")
        _remove_local_file(settings.data_raw_dir, doc.filename)
        db_service.delete_document(doc_id)
        logger.info(f"Document deleted: {doc_id}")
        return {"success": True, "doc_id": doc_id}
    except Exception as e:
        logger.error(f"Delete failed: {e}")
        raise HTTPException(500, f"Delete failed: {str(e)}") from e