Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,115 +1,299 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
"""文档管理 API"""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.doc import (
|
||||
DocumentUploadResponse,
|
||||
DocumentListResponse,
|
||||
DocumentInfo,
|
||||
ParseResponse,
|
||||
EmbedResponse,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
from app.services.mock_data import get_mock_documents, generate_doc_id
|
||||
from app.core.config import settings
|
||||
from app.services.minio import minio_service
|
||||
from app.services.database import db_service, init_db, DocStatus
|
||||
from app.services.tasks import generate_task_id, task_manager, get_task_status
|
||||
from app.workflows.document_workflow import (
|
||||
generate_doc_id,
|
||||
run_parse_workflow,
|
||||
run_embedding_workflow,
|
||||
)
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/docs", tags=["文档管理"])
|
||||
|
||||
# 临时存储文档信息(包含预设的mock文档)
|
||||
documents_store: dict[str, dict] = {}
|
||||
# 启动时初始化数据库
|
||||
init_db()
|
||||
|
||||
# 初始化时加载mock文档
|
||||
for doc in get_mock_documents():
|
||||
documents_store[doc["id"]] = doc
|
||||
|
||||
def get_content_type(filename: str) -> str:
|
||||
"""根据文件扩展名获取 Content-Type"""
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
content_types = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".doc": "application/msword",
|
||||
".txt": "text/plain",
|
||||
}
|
||||
return content_types.get(ext, "application/octet-stream")
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(file: UploadFile = File(...)):
|
||||
"""上传法规文档"""
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
上传法规文档到 MinIO,并自动触发异步解析
|
||||
|
||||
流程:
|
||||
1. 验证文件格式
|
||||
2. 生成文档ID
|
||||
3. 上传到 MinIO
|
||||
4. 创建数据库记录
|
||||
5. 触发异步解析任务(后续可替换为 RabbitMQ)
|
||||
"""
|
||||
# 检查文件格式
|
||||
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in allowed_ext:
|
||||
raise HTTPException(400, f"Unsupported file format: {ext}")
|
||||
|
||||
# 检查文件大小
|
||||
content = await file.read()
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB")
|
||||
|
||||
# 生成文档ID
|
||||
doc_id = generate_doc_id()
|
||||
|
||||
# 保存文件
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
|
||||
# 构建 MinIO 存储路径
|
||||
storage_filename = f"{doc_id}_{file.filename}"
|
||||
minio_path = f"documents/{storage_filename}"
|
||||
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
try:
|
||||
# 上传到 MinIO
|
||||
content_type = get_content_type(file.filename)
|
||||
minio_url = minio_service.upload_file(
|
||||
minio_path,
|
||||
content,
|
||||
content_type,
|
||||
)
|
||||
|
||||
# 记录文档信息
|
||||
documents_store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"name": file.filename,
|
||||
"path": file_path,
|
||||
"size": len(content),
|
||||
"status": "uploaded",
|
||||
"chunks": 0,
|
||||
"created_at": datetime.now(),
|
||||
}
|
||||
# 创建数据库记录
|
||||
doc = db_service.create_document(
|
||||
doc_id=doc_id,
|
||||
filename=storage_filename,
|
||||
original_name=file.filename,
|
||||
minio_path=minio_path,
|
||||
size=len(content),
|
||||
)
|
||||
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
filename=file.filename,
|
||||
size=len(content),
|
||||
)
|
||||
logger.info(f"Document uploaded: {doc_id} - {file.filename}")
|
||||
|
||||
# 触发异步解析任务
|
||||
parse_task_id = generate_task_id()
|
||||
db_service.create_parse_task(parse_task_id, doc_id)
|
||||
|
||||
# 使用 asyncio 异步执行解析(后续替换为 RabbitMQ)
|
||||
background_tasks.add_task(
|
||||
run_parse_workflow_sync,
|
||||
parse_task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
filename=file.filename,
|
||||
size=len(content),
|
||||
status="uploaded",
|
||||
parse_task_id=parse_task_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Upload failed: {e}")
|
||||
raise HTTPException(500, f"Upload failed: {str(e)}")
|
||||
|
||||
|
||||
def run_parse_workflow_sync(task_id: str, doc_id: str):
|
||||
"""同步包装器,用于 BackgroundTasks"""
|
||||
import asyncio
|
||||
asyncio.run(run_parse_workflow(task_id, doc_id))
|
||||
|
||||
|
||||
@router.get("/list", response_model=DocumentListResponse)
|
||||
async def list_documents():
|
||||
"""获取已索引文档列表"""
|
||||
docs = [
|
||||
DocumentInfo(
|
||||
id=d["id"],
|
||||
name=d["name"],
|
||||
chunks=d["chunks"],
|
||||
status=d["status"],
|
||||
created_at=d.get("created_at"),
|
||||
)
|
||||
for d in documents_store.values()
|
||||
]
|
||||
return DocumentListResponse(docs=docs)
|
||||
docs = db_service.list_documents()
|
||||
return DocumentListResponse(
|
||||
docs=[
|
||||
DocumentInfo(
|
||||
id=d.id,
|
||||
name=d.original_name,
|
||||
chunks=d.chunks,
|
||||
status=d.status,
|
||||
created_at=d.created_at,
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{doc_id}", response_model=DocumentInfo)
|
||||
async def get_document(doc_id: str):
|
||||
"""获取单个文档信息"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
return DocumentInfo(
|
||||
id=doc.id,
|
||||
name=doc.original_name,
|
||||
chunks=doc.chunks,
|
||||
status=doc.status,
|
||||
created_at=doc.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/parse/{doc_id}", response_model=ParseResponse)
|
||||
async def parse_document(doc_id: str):
|
||||
"""解析文档并分块"""
|
||||
if doc_id not in documents_store:
|
||||
async def parse_document(
|
||||
doc_id: str,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
手动触发文档解析(如果文档已上传但未解析)
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟解析逻辑
|
||||
doc["status"] = "parsed"
|
||||
# 根据文件大小计算chunks数量
|
||||
file_size = doc.get("size", 100000)
|
||||
doc["chunks"] = max(20, file_size // 8000)
|
||||
if doc.status not in [DocStatus.uploaded.value, DocStatus.failed.value]:
|
||||
raise HTTPException(400, f"Document status is {doc.status}, cannot parse")
|
||||
|
||||
return ParseResponse(doc_id=doc_id, chunks=doc["chunks"])
|
||||
# 创建解析任务
|
||||
task_id = generate_task_id()
|
||||
db_service.create_parse_task(task_id, doc_id)
|
||||
|
||||
# 异步执行
|
||||
background_tasks.add_task(
|
||||
run_parse_workflow_sync,
|
||||
task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return ParseResponse(
|
||||
doc_id=doc_id,
|
||||
task_id=task_id,
|
||||
status="parsing",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
|
||||
async def embed_document(doc_id: str):
|
||||
"""嵌入并存入向量库"""
|
||||
if doc_id not in documents_store:
|
||||
async def embed_document(
|
||||
doc_id: str,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
触发文档向量化(需要文档已解析)
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟嵌入逻辑
|
||||
doc["status"] = "indexed"
|
||||
if doc.status != DocStatus.parsed.value:
|
||||
raise HTTPException(400, f"Document must be parsed first. Current status: {doc.status}")
|
||||
|
||||
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
|
||||
# 创建向量化任务
|
||||
task_id = generate_task_id()
|
||||
db_service.create_parse_task(task_id, doc_id)
|
||||
|
||||
# 异步执行
|
||||
background_tasks.add_task(
|
||||
run_embedding_workflow_sync,
|
||||
task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return EmbedResponse(
|
||||
doc_id=doc_id,
|
||||
task_id=task_id,
|
||||
status="embedding",
|
||||
)
|
||||
|
||||
|
||||
def run_embedding_workflow_sync(task_id: str, doc_id: str):
|
||||
"""同步包装器,用于 BackgroundTasks"""
|
||||
import asyncio
|
||||
asyncio.run(run_embedding_workflow(task_id, doc_id))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status_api(task_id: str):
|
||||
"""获取任务状态"""
|
||||
status = get_task_status(task_id)
|
||||
if not status:
|
||||
# 检查数据库中的任务记录
|
||||
task = db_service.get_parse_task(task_id)
|
||||
if task:
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=task.status,
|
||||
progress=task.progress or 0,
|
||||
message=task.message,
|
||||
)
|
||||
raise HTTPException(404, "Task not found")
|
||||
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=status.get("status", "unknown"),
|
||||
progress=status.get("progress", 0),
|
||||
message=status.get("message"),
|
||||
result=status.get("result"),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/delete/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""删除文档"""
|
||||
if doc_id not in documents_store:
|
||||
"""
|
||||
删除文档
|
||||
|
||||
同时删除:
|
||||
- MinIO 中的文件
|
||||
- 数据库中的记录
|
||||
- 解析后的文本文件
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
del documents_store[doc_id]
|
||||
return {"success": True}
|
||||
try:
|
||||
# 删除 MinIO 文件
|
||||
minio_service.delete_file(doc.minio_path)
|
||||
|
||||
# 删除本地解析文件
|
||||
parsed_path = f"{settings.data_parsed_dir}/{doc_id}.txt"
|
||||
if os.path.exists(parsed_path):
|
||||
os.remove(parsed_path)
|
||||
|
||||
# 删除本地临时文件
|
||||
temp_path = f"{settings.data_raw_dir}/{doc.filename}"
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
# 删除数据库记录
|
||||
db_service.delete_document(doc_id)
|
||||
|
||||
logger.info(f"Document deleted: {doc_id}")
|
||||
|
||||
return {"success": True, "doc_id": doc_id}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Delete failed: {e}")
|
||||
raise HTTPException(500, f"Delete failed: {str(e)}")
|
||||
Reference in New Issue
Block a user