Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,115 +1,299 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
"""文档管理 API"""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.doc import (
|
||||
DocumentUploadResponse,
|
||||
DocumentListResponse,
|
||||
DocumentInfo,
|
||||
ParseResponse,
|
||||
EmbedResponse,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
from app.services.mock_data import get_mock_documents, generate_doc_id
|
||||
from app.core.config import settings
|
||||
from app.services.minio import minio_service
|
||||
from app.services.database import db_service, init_db, DocStatus
|
||||
from app.services.tasks import generate_task_id, task_manager, get_task_status
|
||||
from app.workflows.document_workflow import (
|
||||
generate_doc_id,
|
||||
run_parse_workflow,
|
||||
run_embedding_workflow,
|
||||
)
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/docs", tags=["文档管理"])
|
||||
|
||||
# 临时存储文档信息(包含预设的mock文档)
|
||||
documents_store: dict[str, dict] = {}
|
||||
# 启动时初始化数据库
|
||||
init_db()
|
||||
|
||||
# 初始化时加载mock文档
|
||||
for doc in get_mock_documents():
|
||||
documents_store[doc["id"]] = doc
|
||||
|
||||
def get_content_type(filename: str) -> str:
    """Return the Content-Type matching the extension of *filename*.

    The extension is compared case-insensitively. Anything other than
    ``.pdf``, ``.docx``, ``.doc`` or ``.txt`` maps to
    ``application/octet-stream``.
    """
    _, suffix = os.path.splitext(filename)
    suffix = suffix.lower()

    if suffix == ".pdf":
        return "application/pdf"
    if suffix == ".docx":
        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    if suffix == ".doc":
        return "application/msword"
    if suffix == ".txt":
        return "text/plain"
    return "application/octet-stream"
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(file: UploadFile = File(...)):
|
||||
"""上传法规文档"""
|
||||
async def upload_document(
|
||||
file: UploadFile = File(...),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
上传法规文档到 MinIO,并自动触发异步解析
|
||||
|
||||
流程:
|
||||
1. 验证文件格式
|
||||
2. 生成文档ID
|
||||
3. 上传到 MinIO
|
||||
4. 创建数据库记录
|
||||
5. 触发异步解析任务(后续可替换为 RabbitMQ)
|
||||
"""
|
||||
# 检查文件格式
|
||||
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in allowed_ext:
|
||||
raise HTTPException(400, f"Unsupported file format: {ext}")
|
||||
|
||||
# 检查文件大小
|
||||
content = await file.read()
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB")
|
||||
|
||||
# 生成文档ID
|
||||
doc_id = generate_doc_id()
|
||||
|
||||
# 保存文件
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
|
||||
# 构建 MinIO 存储路径
|
||||
storage_filename = f"{doc_id}_{file.filename}"
|
||||
minio_path = f"documents/{storage_filename}"
|
||||
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
try:
|
||||
# 上传到 MinIO
|
||||
content_type = get_content_type(file.filename)
|
||||
minio_url = minio_service.upload_file(
|
||||
minio_path,
|
||||
content,
|
||||
content_type,
|
||||
)
|
||||
|
||||
# 记录文档信息
|
||||
documents_store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"name": file.filename,
|
||||
"path": file_path,
|
||||
"size": len(content),
|
||||
"status": "uploaded",
|
||||
"chunks": 0,
|
||||
"created_at": datetime.now(),
|
||||
}
|
||||
# 创建数据库记录
|
||||
doc = db_service.create_document(
|
||||
doc_id=doc_id,
|
||||
filename=storage_filename,
|
||||
original_name=file.filename,
|
||||
minio_path=minio_path,
|
||||
size=len(content),
|
||||
)
|
||||
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
filename=file.filename,
|
||||
size=len(content),
|
||||
)
|
||||
logger.info(f"Document uploaded: {doc_id} - {file.filename}")
|
||||
|
||||
# 触发异步解析任务
|
||||
parse_task_id = generate_task_id()
|
||||
db_service.create_parse_task(parse_task_id, doc_id)
|
||||
|
||||
# 使用 asyncio 异步执行解析(后续替换为 RabbitMQ)
|
||||
background_tasks.add_task(
|
||||
run_parse_workflow_sync,
|
||||
parse_task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
filename=file.filename,
|
||||
size=len(content),
|
||||
status="uploaded",
|
||||
parse_task_id=parse_task_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Upload failed: {e}")
|
||||
raise HTTPException(500, f"Upload failed: {str(e)}")
|
||||
|
||||
|
||||
def run_parse_workflow_sync(task_id: str, doc_id: str):
    """Synchronous wrapper around ``run_parse_workflow`` for BackgroundTasks.

    Args:
        task_id: ID of the parse task to run.
        doc_id: ID of the document to parse.
    """
    import asyncio

    # asyncio.run drives the coroutine to completion on an event loop of its own.
    asyncio.run(run_parse_workflow(task_id, doc_id))
|
||||
|
||||
|
||||
@router.get("/list", response_model=DocumentListResponse)
|
||||
async def list_documents():
|
||||
"""获取已索引文档列表"""
|
||||
docs = [
|
||||
DocumentInfo(
|
||||
id=d["id"],
|
||||
name=d["name"],
|
||||
chunks=d["chunks"],
|
||||
status=d["status"],
|
||||
created_at=d.get("created_at"),
|
||||
)
|
||||
for d in documents_store.values()
|
||||
]
|
||||
return DocumentListResponse(docs=docs)
|
||||
docs = db_service.list_documents()
|
||||
return DocumentListResponse(
|
||||
docs=[
|
||||
DocumentInfo(
|
||||
id=d.id,
|
||||
name=d.original_name,
|
||||
chunks=d.chunks,
|
||||
status=d.status,
|
||||
created_at=d.created_at,
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{doc_id}", response_model=DocumentInfo)
async def get_document(doc_id: str):
    """Return the stored information for a single document.

    Raises:
        HTTPException: 404 when no document with *doc_id* exists.
    """
    record = db_service.get_document(doc_id)
    if not record:
        raise HTTPException(404, "Document not found")

    # Map the database row onto the public response schema.
    fields = {
        "id": record.id,
        "name": record.original_name,
        "chunks": record.chunks,
        "status": record.status,
        "created_at": record.created_at,
    }
    return DocumentInfo(**fields)
|
||||
|
||||
|
||||
@router.post("/parse/{doc_id}", response_model=ParseResponse)
|
||||
async def parse_document(doc_id: str):
|
||||
"""解析文档并分块"""
|
||||
if doc_id not in documents_store:
|
||||
async def parse_document(
|
||||
doc_id: str,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
手动触发文档解析(如果文档已上传但未解析)
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟解析逻辑
|
||||
doc["status"] = "parsed"
|
||||
# 根据文件大小计算chunks数量
|
||||
file_size = doc.get("size", 100000)
|
||||
doc["chunks"] = max(20, file_size // 8000)
|
||||
if doc.status not in [DocStatus.uploaded.value, DocStatus.failed.value]:
|
||||
raise HTTPException(400, f"Document status is {doc.status}, cannot parse")
|
||||
|
||||
return ParseResponse(doc_id=doc_id, chunks=doc["chunks"])
|
||||
# 创建解析任务
|
||||
task_id = generate_task_id()
|
||||
db_service.create_parse_task(task_id, doc_id)
|
||||
|
||||
# 异步执行
|
||||
background_tasks.add_task(
|
||||
run_parse_workflow_sync,
|
||||
task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return ParseResponse(
|
||||
doc_id=doc_id,
|
||||
task_id=task_id,
|
||||
status="parsing",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
|
||||
async def embed_document(doc_id: str):
|
||||
"""嵌入并存入向量库"""
|
||||
if doc_id not in documents_store:
|
||||
async def embed_document(
|
||||
doc_id: str,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
触发文档向量化(需要文档已解析)
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟嵌入逻辑
|
||||
doc["status"] = "indexed"
|
||||
if doc.status != DocStatus.parsed.value:
|
||||
raise HTTPException(400, f"Document must be parsed first. Current status: {doc.status}")
|
||||
|
||||
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
|
||||
# 创建向量化任务
|
||||
task_id = generate_task_id()
|
||||
db_service.create_parse_task(task_id, doc_id)
|
||||
|
||||
# 异步执行
|
||||
background_tasks.add_task(
|
||||
run_embedding_workflow_sync,
|
||||
task_id,
|
||||
doc_id,
|
||||
)
|
||||
|
||||
return EmbedResponse(
|
||||
doc_id=doc_id,
|
||||
task_id=task_id,
|
||||
status="embedding",
|
||||
)
|
||||
|
||||
|
||||
def run_embedding_workflow_sync(task_id: str, doc_id: str):
    """Synchronous wrapper around ``run_embedding_workflow`` for BackgroundTasks.

    Args:
        task_id: ID of the embedding task to run.
        doc_id: ID of the document to embed.
    """
    import asyncio

    # asyncio.run drives the coroutine to completion on an event loop of its own.
    asyncio.run(run_embedding_workflow(task_id, doc_id))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_api(task_id: str):
    """Return the status of a task.

    The in-memory task store is consulted first; when it has no entry the
    persisted parse-task record is used instead.

    Raises:
        HTTPException: 404 when neither source knows *task_id*.
    """
    live = get_task_status(task_id)
    if live:
        return TaskStatusResponse(
            task_id=task_id,
            status=live.get("status", "unknown"),
            progress=live.get("progress", 0),
            message=live.get("message"),
            result=live.get("result"),
        )

    # Nothing in memory: fall back to the task record in the database.
    persisted = db_service.get_parse_task(task_id)
    if not persisted:
        raise HTTPException(404, "Task not found")

    return TaskStatusResponse(
        task_id=task_id,
        status=persisted.status,
        progress=persisted.progress or 0,
        message=persisted.message,
    )
|
||||
|
||||
|
||||
@router.delete("/delete/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""删除文档"""
|
||||
if doc_id not in documents_store:
|
||||
"""
|
||||
删除文档
|
||||
|
||||
同时删除:
|
||||
- MinIO 中的文件
|
||||
- 数据库中的记录
|
||||
- 解析后的文本文件
|
||||
"""
|
||||
doc = db_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
del documents_store[doc_id]
|
||||
return {"success": True}
|
||||
try:
|
||||
# 删除 MinIO 文件
|
||||
minio_service.delete_file(doc.minio_path)
|
||||
|
||||
# 删除本地解析文件
|
||||
parsed_path = f"{settings.data_parsed_dir}/{doc_id}.txt"
|
||||
if os.path.exists(parsed_path):
|
||||
os.remove(parsed_path)
|
||||
|
||||
# 删除本地临时文件
|
||||
temp_path = f"{settings.data_raw_dir}/{doc.filename}"
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
# 删除数据库记录
|
||||
db_service.delete_document(doc_id)
|
||||
|
||||
logger.info(f"Document deleted: {doc_id}")
|
||||
|
||||
return {"success": True, "doc_id": doc_id}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Delete failed: {e}")
|
||||
raise HTTPException(500, f"Delete failed: {str(e)}")
|
||||
@@ -3,13 +3,51 @@ from typing import Optional
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# DashScope API
|
||||
dashscope_api_key: str = ""
|
||||
# Qwen API配置
|
||||
qwen_api_key: str = ""
|
||||
qwen_base_url: str = "https://dashscope.aliyuncs.com/api/v1"
|
||||
qwen_model: str = "qwen-max"
|
||||
qwen_vl_model: str = "qwen-vl-plus"
|
||||
|
||||
# DeepSeek API配置
|
||||
deepseek_api_key: str = ""
|
||||
deepseek_base_url: str = "https://api.deepseek.com/v1"
|
||||
deepseek_model: str = "deepseek-v3"
|
||||
|
||||
# PostgreSQL
|
||||
postgres_host: str = "localhost"
|
||||
postgres_port: int = 5432
|
||||
postgres_user: str = "postgresql"
|
||||
postgres_password: str = "postgresql123456"
|
||||
postgres_db: str = "mydb"
|
||||
|
||||
# Redis
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_password: str = ""
|
||||
|
||||
# MinIO
|
||||
minio_endpoint: str = "localhost:9000"
|
||||
minio_access_key: str = "minioadmin"
|
||||
minio_secret_key: str = "minioadmin"
|
||||
minio_bucket: str = "regulation-docs"
|
||||
minio_secure: bool = False
|
||||
|
||||
# Milvus
|
||||
milvus_host: str = "localhost"
|
||||
milvus_port: int = 19530
|
||||
|
||||
# Neo4j
|
||||
neo4j_uri: str = "bolt://localhost:7687"
|
||||
neo4j_user: str = "neo4j"
|
||||
neo4j_password: str = "neo4j123"
|
||||
|
||||
# RabbitMQ
|
||||
rabbitmq_host: str = "localhost"
|
||||
rabbitmq_port: int = 5672
|
||||
rabbitmq_user: str = "admin"
|
||||
rabbitmq_password: str = "admin@123"
|
||||
|
||||
# LLM配置
|
||||
llm_model: str = "qwen-max"
|
||||
embedding_model: str = "text-embedding-v3"
|
||||
@@ -32,6 +70,10 @@ class Settings(BaseSettings):
|
||||
regulations_collection: str = "vehicle_regulations"
|
||||
compliance_collection: str = "compliance_cache"
|
||||
|
||||
# 数据目录
|
||||
data_raw_dir: str = "/airegulation/demo-mao/backend/data/raw"
|
||||
data_parsed_dir: str = "/airegulation/demo-mao/backend/data/parsed"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
"""文档相关数据模型"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
    """Response returned after a document upload."""

    doc_id: str
    filename: str
    size: int
    status: str = "uploaded"
    parse_task_id: Optional[str] = None  # ID of the parse task
|
||||
|
||||
|
||||
class DocumentInfo(BaseModel):
|
||||
"""文档信息"""
|
||||
id: str
|
||||
name: str
|
||||
chunks: int
|
||||
@@ -19,10 +24,12 @@ class DocumentInfo(BaseModel):
|
||||
|
||||
|
||||
class DocumentListResponse(BaseModel):
|
||||
"""文档列表响应"""
|
||||
docs: list[DocumentInfo]
|
||||
|
||||
|
||||
class ChunkInfo(BaseModel):
|
||||
"""文本块信息"""
|
||||
chunk_id: str
|
||||
doc_name: str
|
||||
clause_id: Optional[str] = None
|
||||
@@ -33,12 +40,25 @@ class ChunkInfo(BaseModel):
|
||||
|
||||
|
||||
class ParseResponse(BaseModel):
|
||||
"""解析响应"""
|
||||
doc_id: str
|
||||
chunks: int
|
||||
status: str = "parsed"
|
||||
task_id: Optional[str] = None
|
||||
chunks: int = 0
|
||||
status: str = "parsing"
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
"""嵌入响应"""
|
||||
doc_id: str
|
||||
vectors: int
|
||||
status: str = "embedded"
|
||||
task_id: Optional[str] = None
|
||||
vectors: int = 0
|
||||
status: str = "embedding"
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
    """Response describing the status of a task."""

    task_id: str
    status: str
    progress: int
    message: Optional[str] = None
    result: Optional[Any] = None
|
||||
@@ -1,4 +1,9 @@
|
||||
# Import mock data service
|
||||
# Import services
|
||||
from .minio import minio_service, MinioService
|
||||
from .database import db_service, DatabaseService, init_db, Document, ParseTask
|
||||
from .tasks import task_manager, get_task_status, set_task_status, generate_task_id
|
||||
|
||||
# Import mock data service (for development)
|
||||
from .mock_data import (
|
||||
get_mock_documents,
|
||||
get_mock_quick_questions,
|
||||
@@ -29,6 +34,18 @@ except ImportError:
|
||||
get_document_service = None
|
||||
|
||||
__all__ = [
|
||||
# Core services
|
||||
"minio_service",
|
||||
"MinioService",
|
||||
"db_service",
|
||||
"DatabaseService",
|
||||
"init_db",
|
||||
"Document",
|
||||
"ParseTask",
|
||||
"task_manager",
|
||||
"get_task_status",
|
||||
"set_task_status",
|
||||
"generate_task_id",
|
||||
# Mock data services
|
||||
"get_mock_documents",
|
||||
"get_mock_quick_questions",
|
||||
|
||||
228
app/services/database.py
Normal file
228
app/services/database.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""数据库服务 - PostgreSQL"""
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
import enum
|
||||
from app.core.config import settings
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
# Database connection
# NOTE(review): user and password are interpolated into the URL verbatim;
# credentials containing URL-reserved characters (e.g. "@" or "/") would
# presumably need escaping first — confirm against the deployed settings.
DATABASE_URL = f"postgresql://{settings.postgres_user}:{settings.postgres_password}@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"

engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
|
||||
|
||||
|
||||
class DocStatus(str, enum.Enum):
    """Processing status of a document."""

    uploaded = "uploaded"  # uploaded
    parsing = "parsing"  # parsing in progress
    parsed = "parsed"  # parsed
    embedding = "embedding"  # embedding in progress
    indexed = "indexed"  # indexed
    failed = "failed"  # processing failed
|
||||
|
||||
|
||||
class Document(Base):
    """ORM model for the ``documents`` table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    # Storage path of the object in MinIO
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a summary dict: id, name, chunks, status, created_at (ISO string or None)."""
        summary = dict(
            id=self.id,
            name=self.original_name,
            chunks=self.chunks,
            status=self.status,
        )
        if self.created_at:
            summary["created_at"] = self.created_at.isoformat()
        else:
            summary["created_at"] = None
        return summary
|
||||
|
||||
|
||||
class ParseTask(Base):
    """ORM model for the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    id = Column(String(64), primary_key=True)
    doc_id = Column(String(64), nullable=False)
    # New rows start as "pending" with progress 0.
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
def init_db():
    """Create the database tables registered on ``Base.metadata``.

    Failures are logged and swallowed, so the caller continues even when
    the tables could not be created.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
|
||||
|
||||
|
||||
def get_db() -> Session:
    """Open and return a new database session.

    Note: the caller is responsible for closing the returned session.
    """
    return SessionLocal()
|
||||
|
||||
|
||||
class DatabaseService:
    """Persistence helpers for the ``documents`` and ``parse_tasks`` tables.

    Every method opens its own session from ``SessionLocal`` and closes it
    before returning, so returned ORM objects are no longer attached to an
    open session.
    """

    def __init__(self):
        self.SessionLocal = SessionLocal
        self.engine = engine

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a document row with status ``uploaded`` and return it."""
        session = self.SessionLocal()
        try:
            record = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            session.add(record)
            session.commit()
            # Reload the row so its attributes are populated after the commit.
            session.refresh(record)
            return record
        finally:
            session.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with *doc_id*, or None when it does not exist."""
        session = self.SessionLocal()
        try:
            lookup = session.query(Document).filter(Document.id == doc_id)
            return lookup.first()
        finally:
            session.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: int = None,
        vectors: int = None,
        error_message: str = None,
    ) -> Optional[Document]:
        """Set a document's status and, when given, its chunk/vector counts and error.

        Returns the updated document, or None when *doc_id* is unknown.
        """
        session = self.SessionLocal()
        try:
            record = session.query(Document).filter(Document.id == doc_id).first()
            if not record:
                return record

            record.status = status
            record.updated_at = datetime.now()
            if chunks is not None:
                record.chunks = chunks
            if vectors is not None:
                record.vectors = vectors
            # An empty/None message leaves any previously stored error untouched.
            if error_message:
                record.error_message = error_message
            session.commit()
            session.refresh(record)
            return record
        finally:
            session.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest first."""
        session = self.SessionLocal()
        try:
            newest_first = session.query(Document).order_by(Document.created_at.desc())
            return newest_first.all()
        finally:
            session.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete the document row; return True if it existed, False otherwise."""
        session = self.SessionLocal()
        try:
            record = session.query(Document).filter(Document.id == doc_id).first()
            if not record:
                return False
            session.delete(record)
            session.commit()
            return True
        finally:
            session.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a parse-task row for *doc_id* and return it."""
        session = self.SessionLocal()
        try:
            job = ParseTask(id=task_id, doc_id=doc_id)
            session.add(job)
            session.commit()
            session.refresh(job)
            return job
        finally:
            session.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with *task_id*, or None when it does not exist."""
        session = self.SessionLocal()
        try:
            lookup = session.query(ParseTask).filter(ParseTask.id == task_id)
            return lookup.first()
        finally:
            session.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: int = None,
        message: str = None,
    ) -> Optional[ParseTask]:
        """Update a parse task's status, progress and message.

        ``started_at`` is stamped when the status is "running";
        ``completed_at`` when it is "completed" or "failed".
        Returns the updated task, or None when *task_id* is unknown.
        """
        session = self.SessionLocal()
        try:
            job = session.query(ParseTask).filter(ParseTask.id == task_id).first()
            if not job:
                return job

            job.status = status
            if progress is not None:
                job.progress = progress
            if message:
                job.message = message
            if status == "running":
                job.started_at = datetime.now()
            elif status in ("completed", "failed"):
                job.completed_at = datetime.now()
            session.commit()
            session.refresh(job)
            return job
        finally:
            session.close()
|
||||
|
||||
|
||||
# Singleton
db_service = DatabaseService()
|
||||
122
app/services/minio.py
Normal file
122
app/services/minio.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""MinIO 文件存储服务"""
|
||||
|
||||
import io
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from app.core.config import settings
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class MinioService:
    """File storage service backed by a single MinIO bucket."""

    def __init__(self):
        """Create the MinIO client from ``settings`` and make sure the bucket exists."""
        self.client = Minio(
            settings.minio_endpoint,
            access_key=settings.minio_access_key,
            secret_key=settings.minio_secret_key,
            secure=settings.minio_secure,
        )
        self.bucket = settings.minio_bucket
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Create the configured bucket if it is missing.

        An ``S3Error`` is logged and swallowed (best effort).
        """
        try:
            if not self.client.bucket_exists(self.bucket):
                self.client.make_bucket(self.bucket)
                logger.info(f"Created MinIO bucket: {self.bucket}")
        except S3Error as e:
            logger.error(f"MinIO bucket check failed: {e}")

    def upload_file(
        self,
        object_name: str,
        file_data: bytes,
        content_type: str = "application/octet-stream",
    ) -> str:
        """Upload a file to MinIO.

        Args:
            object_name: Object name (path inside the bucket).
            file_data: Raw file bytes.
            content_type: Content type stored with the object.

        Returns:
            ``<endpoint>/<bucket>/<object_name>`` (no URL scheme is prepended).

        Raises:
            S3Error: Logged and re-raised when the upload fails.
        """
        try:
            data_stream = io.BytesIO(file_data)
            self.client.put_object(
                self.bucket,
                object_name,
                data_stream,
                length=len(file_data),
                content_type=content_type,
            )
            url = f"{settings.minio_endpoint}/{self.bucket}/{object_name}"
            logger.info(f"Uploaded file to MinIO: {object_name}")
            return url
        except S3Error as e:
            logger.error(f"MinIO upload failed: {e}")
            raise

    def get_file(self, object_name: str) -> bytes:
        """Download a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            The object's raw bytes.

        Raises:
            S3Error: Logged and re-raised when the download fails.
        """
        response = None
        try:
            response = self.client.get_object(self.bucket, object_name)
            return response.read()
        except S3Error as e:
            logger.error(f"MinIO get file failed: {e}")
            raise
        finally:
            # Always hand the connection back, even when read() fails;
            # previously an error while reading leaked the response.
            if response is not None:
                try:
                    response.close()
                finally:
                    response.release_conn()

    def delete_file(self, object_name: str) -> bool:
        """Delete a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            True on success, False when an ``S3Error`` occurred (it is logged).
        """
        try:
            self.client.remove_object(self.bucket, object_name)
            logger.info(f"Deleted file from MinIO: {object_name}")
            return True
        except S3Error as e:
            logger.error(f"MinIO delete failed: {e}")
            return False

    def list_files(self, prefix: str = "") -> list[str]:
        """List object names in the bucket.

        Args:
            prefix: Only objects whose name starts with this prefix are listed.

        Returns:
            The object names, or an empty list when an ``S3Error`` occurred
            (it is logged).
        """
        try:
            objects = self.client.list_objects(self.bucket, prefix=prefix)
            return [obj.object_name for obj in objects]
        except S3Error as e:
            logger.error(f"MinIO list files failed: {e}")
            return []
|
||||
|
||||
|
||||
# Singleton; constructing it builds the MinIO client and runs the bucket check.
minio_service = MinioService()
|
||||
89
app/services/tasks.py
Normal file
89
app/services/tasks.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""异步任务处理模块
|
||||
|
||||
TODO: 后续替换为 RabbitMQ 消息队列
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Callable, Awaitable
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
# Task status store, keyed by task ID (to be replaced by Redis later).
# Annotated with the builtin generic ``dict``: this module only imports
# Callable and Awaitable from typing, so ``Dict[str, Dict]`` raised NameError
# as soon as the module was imported.
_task_store: dict[str, dict] = {}
|
||||
|
||||
|
||||
def generate_task_id() -> str:
    """Return a new task ID of the form ``task-<12 hex characters>``."""
    suffix = uuid.uuid4().hex[:12]
    return "task-" + suffix
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
"""异步任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
def create_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_func: Callable[[str], Awaitable[None]],
|
||||
) -> asyncio.Task:
|
||||
"""
|
||||
创建异步任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
task_func: 任务执行函数
|
||||
|
||||
Returns:
|
||||
asyncio.Task
|
||||
"""
|
||||
task = asyncio.create_task(self._run_task(task_id, task_func))
|
||||
self._running_tasks[task_id] = task
|
||||
return task
|
||||
|
||||
async def _run_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_func: Callable[[str], Awaitable[None]],
|
||||
):
|
||||
"""运行任务并处理状态"""
|
||||
try:
|
||||
await task_func(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Task {task_id} failed: {e}")
|
||||
_task_store[task_id] = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completed_at": datetime.now(),
|
||||
}
|
||||
finally:
|
||||
if task_id in self._running_tasks:
|
||||
del self._running_tasks[task_id]
|
||||
|
||||
def get_task_status(self, task_id: str) -> dict | None:
|
||||
"""获取任务状态"""
|
||||
return _task_store.get(task_id)
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消任务"""
|
||||
if task_id in self._running_tasks:
|
||||
self._running_tasks[task_id].cancel()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Singleton
task_manager = AsyncTaskManager()
|
||||
|
||||
|
||||
def get_task_status(task_id: str) -> dict | None:
    """Return the status dict stored for *task_id*, or None when nothing is stored."""
    return _task_store.get(task_id)
|
||||
|
||||
|
||||
def set_task_status(task_id: str, status: dict):
    """Store *status* for *task_id*, replacing any previously stored dict."""
    _task_store[task_id] = status
|
||||
252
app/workflows/document_workflow.py
Normal file
252
app/workflows/document_workflow.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""文档解析工作流 - 异步处理"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
import io
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.minio import minio_service
|
||||
from app.services.database import db_service, DocStatus
|
||||
from app.services.tasks import set_task_status, get_task_status
|
||||
from app.services.document import DocumentService
|
||||
from app.utils.chunking import TextChunker
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
def generate_doc_id() -> str:
    """Return a new document ID of the form ``doc-<12 hex characters>``."""
    suffix = uuid.uuid4().hex[:12]
    return "doc-" + suffix
|
||||
|
||||
|
||||
def generate_chunk_id(doc_id: str, index: int) -> str:
    """Return the chunk ID ``<doc_id>-chunk-<index>``."""
    return "{}-chunk-{}".format(doc_id, index)
|
||||
|
||||
|
||||
async def run_parse_workflow(task_id: str, doc_id: str):
    """Run the document parsing workflow for one document.

    Steps:
        1. Fetch - download the file from MinIO into the local raw directory.
        2. Parse - extract the text content and save the parsed text.
        3. Chunk - split by clause, falling back to fixed-size chunks.
        4. Save  - currently only counts the chunks (storing them is a TODO).

    Progress is published through ``set_task_status``. The start and the
    final outcome are also written to the persisted parse-task record, and
    the document row moves ``parsing`` -> ``parsed`` (or ``failed``).
    Exceptions raised by any step are logged and recorded as a failed task;
    they are not re-raised.

    Args:
        task_id: Task ID.
        doc_id: Document ID.
    """
    import os  # local import: this module has no file-level ``import os``

    chunker = TextChunker()
    doc_service = DocumentService(settings.data_raw_dir, settings.data_parsed_dir)

    try:
        # Step 1: fetch the file
        fetch_message = "正在从存储获取文件..."
        set_task_status(task_id, {
            "status": "running",
            "step": "fetching",
            "progress": 10,
            "message": fetch_message,
            "started_at": datetime.now(),
        })
        # Mirror the start into the persisted task row; otherwise the row
        # stays "pending" and the DB fallback of the status API is useless.
        db_service.update_parse_task(task_id, "running", progress=10, message=fetch_message)
        db_service.update_document_status(doc_id, DocStatus.parsing.value)

        doc = db_service.get_document(doc_id)
        if not doc:
            raise ValueError(f"Document {doc_id} not found")

        # Download from MinIO
        file_data = minio_service.get_file(doc.minio_path)

        # Save to the local temp directory for parsing. ``doc.filename``
        # already starts with "<doc_id>_", so it is used as-is; prefixing
        # doc_id again produced a path that document deletion never matched.
        os.makedirs(settings.data_raw_dir, exist_ok=True)
        temp_path = os.path.join(settings.data_raw_dir, doc.filename)
        with open(temp_path, "wb") as f:
            f.write(file_data)

        await asyncio.sleep(0.5)  # simulated delay

        # Step 2: parse the document
        set_task_status(task_id, {
            "status": "running",
            "step": "parsing",
            "progress": 30,
            "message": "正在解析文档内容...",
        })

        text = doc_service.parse_document(temp_path)
        if not text:
            raise ValueError("Document parsing returned empty content")

        # Save the parsed text
        parsed_path = doc_service.save_parsed_text(doc_id, text)

        await asyncio.sleep(0.5)

        # Step 3: chunk the text
        set_task_status(task_id, {
            "status": "running",
            "step": "chunking",
            "progress": 50,
            "message": "正在进行文本分块...",
        })

        # Try clause-based chunking first; fall back to fixed-size chunks
        # when the text is not in regulation format.
        chunks = chunker.chunk_by_clause(text)
        if not chunks:
            chunks = chunker.chunk_by_size(text)

        await asyncio.sleep(0.5)

        # Step 4: save the chunking result
        set_task_status(task_id, {
            "status": "running",
            "step": "saving",
            "progress": 80,
            "message": f"正在保存 {len(chunks)} 个文本块...",
        })

        # TODO: store the chunks in the database or vector store;
        # for now only the count is recorded.
        chunk_count = len(chunks)

        await asyncio.sleep(0.5)

        # Step 5: done
        done_message = f"解析完成,共生成 {chunk_count} 个文本块"
        set_task_status(task_id, {
            "status": "completed",
            "step": "done",
            "progress": 100,
            "message": done_message,
            "completed_at": datetime.now(),
            "result": {
                "doc_id": doc_id,
                "chunks": chunk_count,
                "parsed_path": parsed_path,
            }
        })
        db_service.update_parse_task(task_id, "completed", progress=100, message=done_message)

        db_service.update_document_status(
            doc_id,
            DocStatus.parsed.value,
            chunks=chunk_count,
        )

        logger.info(f"Parse workflow completed for doc {doc_id}: {chunk_count} chunks")

    except Exception as e:
        logger.error(f"Parse workflow failed for doc {doc_id}: {e}")

        set_task_status(task_id, {
            "status": "failed",
            "step": "error",
            "progress": 0,
            "message": str(e),
            "completed_at": datetime.now(),
        })
        db_service.update_parse_task(task_id, "failed", progress=0, message=str(e))

        db_service.update_document_status(
            doc_id,
            DocStatus.failed.value,
            error_message=str(e),
        )
|
||||
|
||||
|
||||
async def run_embedding_workflow(task_id: str, doc_id: str):
    """Run the embedding workflow for one document.

    Steps:
        1. Fetch the chunk data.
        2. Generate vector embeddings (currently simulated).
        3. Store them in the vector database (currently a TODO).

    Progress is published through ``set_task_status``. The start and the
    final outcome are also written to the persisted task record, and the
    document row moves ``embedding`` -> ``indexed`` (or ``failed``).
    Exceptions raised by any step are logged and recorded as a failed task;
    they are not re-raised.

    Args:
        task_id: Task ID.
        doc_id: Document ID.
    """
    try:
        # Step 1: fetch the chunks
        fetch_message = "正在获取文本分块..."
        set_task_status(task_id, {
            "status": "running",
            "step": "fetching_chunks",
            "progress": 10,
            "message": fetch_message,
            "started_at": datetime.now(),
        })
        # Mirror the start into the persisted task row; otherwise the row
        # stays "pending" and the DB fallback of the status API is useless.
        db_service.update_parse_task(task_id, "running", progress=10, message=fetch_message)
        db_service.update_document_status(doc_id, DocStatus.embedding.value)

        doc = db_service.get_document(doc_id)
        if not doc:
            raise ValueError(f"Document {doc_id} not found")

        await asyncio.sleep(0.5)

        # Step 2: generate embeddings
        set_task_status(task_id, {
            "status": "running",
            "step": "embedding",
            "progress": 40,
            "message": "正在生成向量嵌入...",
        })

        # TODO: call the embedding service to generate vectors;
        # for now one vector per stored chunk is assumed.
        vector_count = doc.chunks

        await asyncio.sleep(1)

        # Step 3: store in the vector database
        set_task_status(task_id, {
            "status": "running",
            "step": "storing",
            "progress": 70,
            "message": "正在存入向量数据库...",
        })

        # TODO: store in Milvus
        await asyncio.sleep(0.5)

        # Step 4: done
        done_message = f"向量化完成,共处理 {vector_count} 个向量"
        set_task_status(task_id, {
            "status": "completed",
            "step": "done",
            "progress": 100,
            "message": done_message,
            "completed_at": datetime.now(),
            "result": {
                "doc_id": doc_id,
                "vectors": vector_count,
            }
        })
        db_service.update_parse_task(task_id, "completed", progress=100, message=done_message)

        db_service.update_document_status(
            doc_id,
            DocStatus.indexed.value,
            vectors=vector_count,
        )

        logger.info(f"Embedding workflow completed for doc {doc_id}")

    except Exception as e:
        logger.error(f"Embedding workflow failed for doc {doc_id}: {e}")

        set_task_status(task_id, {
            "status": "failed",
            "step": "error",
            "progress": 0,
            "message": str(e),
            "completed_at": datetime.now(),
        })
        db_service.update_parse_task(task_id, "failed", progress=0, message=str(e))

        db_service.update_document_status(
            doc_id,
            DocStatus.failed.value,
            error_message=str(e),
        )
|
||||
Reference in New Issue
Block a user