Refactor code structure for improved readability and maintainability

This commit is contained in:
2026-05-18 11:41:20 +08:00
parent d39de39f96
commit 3f154a3077
43 changed files with 5046 additions and 113 deletions

View File

@@ -1,115 +1,299 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
"""文档管理 API"""
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
import os
import uuid
from datetime import datetime
from typing import Optional
from app.schemas.doc import (
DocumentUploadResponse,
DocumentListResponse,
DocumentInfo,
ParseResponse,
EmbedResponse,
TaskStatusResponse,
)
from app.services.mock_data import get_mock_documents, generate_doc_id
from app.core.config import settings
from app.services.minio import minio_service
from app.services.database import db_service, init_db, DocStatus
from app.services.tasks import generate_task_id, task_manager, get_task_status
from app.workflows.document_workflow import (
generate_doc_id,
run_parse_workflow,
run_embedding_workflow,
)
from app.utils.logger import logger
router = APIRouter(prefix="/docs", tags=["文档管理"])
# Temporary in-memory store of document info, including the preset mock documents.
documents_store: dict[str, dict] = {}
# Initialise the database at start-up (runs when this module is imported).
init_db()
# Load the mock documents into the in-memory store at initialisation.
for doc in get_mock_documents():
    documents_store[doc["id"]] = doc
def get_content_type(filename: str) -> str:
    """Return the Content-Type for *filename*, chosen by its extension.

    The extension is compared case-insensitively. Anything that is not a
    known document type is reported as ``application/octet-stream``.
    """
    _, suffix = os.path.splitext(filename)
    known_types = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".doc": "application/msword",
        ".txt": "text/plain",
    }
    try:
        return known_types[suffix.lower()]
    except KeyError:
        return "application/octet-stream"
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """Upload a regulation document to MinIO and schedule asynchronous parsing.

    Steps:
        1. Validate the file name and extension.
        2. Enforce the upload size limit.
        3. Generate a document ID and upload the bytes to MinIO.
        4. Create the database record and a parse task.
        5. Queue the parse workflow as a background task
           (to be replaced by RabbitMQ later).

    Args:
        file: The uploaded file.
        background_tasks: Background-task queue the parse workflow is added to.

    Returns:
        DocumentUploadResponse carrying the document ID, file name, size and
        the ID of the parse task that was queued.

    Raises:
        HTTPException: 400 for a missing file name, an unsupported extension
            or an oversized file; 500 when storing or registering the
            document fails.
    """
    # Keep only the last path component. The name becomes part of the MinIO
    # object key and of local file paths, so directory parts supplied by the
    # client ("../x.pdf", "C:\\dir\\x.pdf") must not survive.
    original_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not original_name:
        raise HTTPException(400, "Missing file name")

    # Check the file format.
    allowed_ext = (".pdf", ".docx", ".doc", ".txt")
    ext = os.path.splitext(original_name)[1].lower()
    if ext not in allowed_ext:
        raise HTTPException(400, f"Unsupported file format: {ext}")

    # Check the file size. The body is read once and reused below.
    content = await file.read()
    max_size = 50 * 1024 * 1024  # 50MB
    if len(content) > max_size:
        raise HTTPException(400, f"File size exceeds limit: {max_size // 1024 // 1024}MB")

    # Generate the document ID and derive the MinIO storage path from it.
    doc_id = generate_doc_id()
    storage_filename = f"{doc_id}_{original_name}"
    minio_path = f"documents/{storage_filename}"

    try:
        # Upload to MinIO.
        minio_service.upload_file(
            minio_path,
            content,
            get_content_type(original_name),
        )

        # Create the database record.
        db_service.create_document(
            doc_id=doc_id,
            filename=storage_filename,
            original_name=original_name,
            minio_path=minio_path,
            size=len(content),
        )
        logger.info(f"Document uploaded: {doc_id} - {original_name}")

        # Register the parse task, then queue the workflow so it runs after
        # the response has been sent (to be replaced by RabbitMQ later).
        parse_task_id = generate_task_id()
        db_service.create_parse_task(parse_task_id, doc_id)
        background_tasks.add_task(
            run_parse_workflow_sync,
            parse_task_id,
            doc_id,
        )

        return DocumentUploadResponse(
            doc_id=doc_id,
            filename=original_name,
            size=len(content),
            status="uploaded",
            parse_task_id=parse_task_id,
        )
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(500, f"Upload failed: {str(e)}") from e
def run_parse_workflow_sync(task_id: str, doc_id: str):
    """Synchronous wrapper so BackgroundTasks can run the async parse workflow.

    Builds the ``run_parse_workflow`` coroutine and drives it to completion
    with ``asyncio.run``.
    """
    from asyncio import run as run_coroutine

    workflow = run_parse_workflow(task_id, doc_id)
    run_coroutine(workflow)
@router.get("/list", response_model=DocumentListResponse)
async def list_documents():
    """Return the list of documents recorded in the database.

    Returns:
        DocumentListResponse with one DocumentInfo per stored document, in
        the order delivered by ``db_service.list_documents()``.
    """
    docs = db_service.list_documents()
    return DocumentListResponse(
        docs=[
            DocumentInfo(
                id=d.id,
                # The API exposes the name the user uploaded, not the storage name.
                name=d.original_name,
                chunks=d.chunks,
                status=d.status,
                created_at=d.created_at,
            )
            for d in docs
        ]
    )
@router.get("/{doc_id}", response_model=DocumentInfo)
async def get_document(doc_id: str):
    """Return the information of a single document.

    Raises:
        HTTPException: 404 when no document with ``doc_id`` exists.
    """
    record = db_service.get_document(doc_id)
    if record:
        return DocumentInfo(
            id=record.id,
            name=record.original_name,
            chunks=record.chunks,
            status=record.status,
            created_at=record.created_at,
        )
    raise HTTPException(404, "Document not found")
@router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Manually trigger parsing of a document that has not been parsed yet.

    Only documents whose status is ``uploaded`` or ``failed`` are accepted.

    Args:
        doc_id: ID of the document to parse.
        background_tasks: Background-task queue the parse workflow is added to.

    Returns:
        ParseResponse with the ID of the queued task and status ``"parsing"``.

    Raises:
        HTTPException: 404 when the document does not exist; 400 when its
            current status does not allow parsing.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    if doc.status not in (DocStatus.uploaded.value, DocStatus.failed.value):
        raise HTTPException(400, f"Document status is {doc.status}, cannot parse")

    # Create the parse task record.
    task_id = generate_task_id()
    db_service.create_parse_task(task_id, doc_id)

    # Run the workflow asynchronously, after the response has been sent.
    background_tasks.add_task(
        run_parse_workflow_sync,
        task_id,
        doc_id,
    )

    return ParseResponse(
        doc_id=doc_id,
        task_id=task_id,
        status="parsing",
    )
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(
    doc_id: str,
    background_tasks: BackgroundTasks = None,
):
    """Trigger vectorisation of a document that has already been parsed.

    Args:
        doc_id: ID of the document to embed.
        background_tasks: Background-task queue the embedding workflow is
            added to.

    Returns:
        EmbedResponse with the ID of the queued task and status
        ``"embedding"``.

    Raises:
        HTTPException: 404 when the document does not exist; 400 when the
            document is not in the ``parsed`` state.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    if doc.status != DocStatus.parsed.value:
        raise HTTPException(400, f"Document must be parsed first. Current status: {doc.status}")

    # Create the embedding task. It is stored through the same task table
    # helper that parse tasks use.
    task_id = generate_task_id()
    db_service.create_parse_task(task_id, doc_id)

    # Run the workflow asynchronously, after the response has been sent.
    background_tasks.add_task(
        run_embedding_workflow_sync,
        task_id,
        doc_id,
    )

    return EmbedResponse(
        doc_id=doc_id,
        task_id=task_id,
        status="embedding",
    )
def run_embedding_workflow_sync(task_id: str, doc_id: str):
    """Synchronous wrapper so BackgroundTasks can run the async embedding workflow.

    Builds the ``run_embedding_workflow`` coroutine and drives it to
    completion with ``asyncio.run``.
    """
    from asyncio import run as run_coroutine

    workflow = run_embedding_workflow(task_id, doc_id)
    run_coroutine(workflow)
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_api(task_id: str):
    """Return the status of a task.

    The in-memory task store is consulted first. When it has no (or an
    empty) entry, the task record in the database is used instead.

    Raises:
        HTTPException: 404 when the task is found in neither place.
    """
    live = get_task_status(task_id)
    if live:
        return TaskStatusResponse(
            task_id=task_id,
            status=live.get("status", "unknown"),
            progress=live.get("progress", 0),
            message=live.get("message"),
            result=live.get("result"),
        )

    # Fall back to the task record kept in the database.
    stored = db_service.get_parse_task(task_id)
    if not stored:
        raise HTTPException(404, "Task not found")
    return TaskStatusResponse(
        task_id=task_id,
        status=stored.status,
        progress=stored.progress or 0,
        message=stored.message,
    )
@router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
    """Delete a document together with its stored artefacts.

    Removes, in this order:
        - the file in MinIO,
        - the parsed text file on local disk,
        - the temporary raw copy on local disk,
        - the database record.

    Args:
        doc_id: ID of the document to delete.

    Returns:
        ``{"success": True, "doc_id": doc_id}`` on success.

    Raises:
        HTTPException: 404 when the document does not exist; 500 when any
            deletion step raises.
    """
    doc = db_service.get_document(doc_id)
    if not doc:
        raise HTTPException(404, "Document not found")

    try:
        # Delete the MinIO object.
        minio_service.delete_file(doc.minio_path)

        # Local files to clean up. The temporary raw copy may exist with or
        # without an additional "<doc_id>_" prefix, depending on how the
        # parse workflow named it, so both candidates are removed.
        stale_paths = (
            f"{settings.data_parsed_dir}/{doc_id}.txt",
            f"{settings.data_raw_dir}/{doc.filename}",
            f"{settings.data_raw_dir}/{doc_id}_{doc.filename}",
        )
        for path in stale_paths:
            if os.path.exists(path):
                os.remove(path)

        # Delete the database record last, so a failure above leaves the
        # document visible and the deletion can be retried.
        db_service.delete_document(doc_id)

        logger.info(f"Document deleted: {doc_id}")
        return {"success": True, "doc_id": doc_id}
    except Exception as e:
        logger.error(f"Delete failed: {e}")
        raise HTTPException(500, f"Delete failed: {str(e)}") from e

View File

@@ -3,13 +3,51 @@ from typing import Optional
class Settings(BaseSettings):
# DashScope API
dashscope_api_key: str = ""
# Qwen API配置
qwen_api_key: str = ""
qwen_base_url: str = "https://dashscope.aliyuncs.com/api/v1"
qwen_model: str = "qwen-max"
qwen_vl_model: str = "qwen-vl-plus"
# DeepSeek API配置
deepseek_api_key: str = ""
deepseek_base_url: str = "https://api.deepseek.com/v1"
deepseek_model: str = "deepseek-v3"
# PostgreSQL
postgres_host: str = "localhost"
postgres_port: int = 5432
postgres_user: str = "postgresql"
postgres_password: str = "postgresql123456"
postgres_db: str = "mydb"
# Redis
redis_host: str = "localhost"
redis_port: int = 6379
redis_password: str = ""
# MinIO
minio_endpoint: str = "localhost:9000"
minio_access_key: str = "minioadmin"
minio_secret_key: str = "minioadmin"
minio_bucket: str = "regulation-docs"
minio_secure: bool = False
# Milvus
milvus_host: str = "localhost"
milvus_port: int = 19530
# Neo4j
neo4j_uri: str = "bolt://localhost:7687"
neo4j_user: str = "neo4j"
neo4j_password: str = "neo4j123"
# RabbitMQ
rabbitmq_host: str = "localhost"
rabbitmq_port: int = 5672
rabbitmq_user: str = "admin"
rabbitmq_password: str = "admin@123"
# LLM配置
llm_model: str = "qwen-max"
embedding_model: str = "text-embedding-v3"
@@ -32,6 +70,10 @@ class Settings(BaseSettings):
regulations_collection: str = "vehicle_regulations"
compliance_collection: str = "compliance_cache"
# 数据目录
data_raw_dir: str = "/airegulation/demo-mao/backend/data/raw"
data_parsed_dir: str = "/airegulation/demo-mao/backend/data/parsed"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"

View File

@@ -1,16 +1,21 @@
"""文档相关数据模型"""
from pydantic import BaseModel
from typing import Optional
from typing import Optional, Any
from datetime import datetime
class DocumentUploadResponse(BaseModel):
    """Response returned after a document upload."""

    doc_id: str
    filename: str
    size: int
    status: str = "uploaded"
    parse_task_id: Optional[str] = None  # ID of the parse task
class DocumentInfo(BaseModel):
"""文档信息"""
id: str
name: str
chunks: int
@@ -19,10 +24,12 @@ class DocumentInfo(BaseModel):
class DocumentListResponse(BaseModel):
    """Response carrying the list of documents."""

    docs: list[DocumentInfo]
class ChunkInfo(BaseModel):
"""文本块信息"""
chunk_id: str
doc_name: str
clause_id: Optional[str] = None
@@ -33,12 +40,25 @@ class ChunkInfo(BaseModel):
class ParseResponse(BaseModel):
    """Response returned when parsing of a document is triggered."""

    doc_id: str
    # ID of the background task performing the parse, when one was created.
    task_id: Optional[str] = None
    # Each field is declared once; the declarations without defaults and
    # with status "parsed" were superseded by the ones below.
    chunks: int = 0
    status: str = "parsing"
class EmbedResponse(BaseModel):
    """Response returned when embedding of a document is triggered."""

    doc_id: str
    # ID of the background task performing the embedding, when one was created.
    task_id: Optional[str] = None
    # Each field is declared once; the declarations without defaults and
    # with status "embedded" were superseded by the ones below.
    vectors: int = 0
    status: str = "embedding"
class TaskStatusResponse(BaseModel):
    """Response describing the status of a task."""

    task_id: str
    status: str
    progress: int
    message: Optional[str] = None
    result: Optional[Any] = None

View File

@@ -1,4 +1,9 @@
# Import mock data service
# Import services
from .minio import minio_service, MinioService
from .database import db_service, DatabaseService, init_db, Document, ParseTask
from .tasks import task_manager, get_task_status, set_task_status, generate_task_id
# Import mock data service (for development)
from .mock_data import (
get_mock_documents,
get_mock_quick_questions,
@@ -29,6 +34,18 @@ except ImportError:
get_document_service = None
__all__ = [
# Core services
"minio_service",
"MinioService",
"db_service",
"DatabaseService",
"init_db",
"Document",
"ParseTask",
"task_manager",
"get_task_status",
"set_task_status",
"generate_task_id",
# Mock data services
"get_mock_documents",
"get_mock_quick_questions",

228
app/services/database.py Normal file
View File

@@ -0,0 +1,228 @@
"""数据库服务 - PostgreSQL"""
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import Optional, List
import enum
from app.core.config import settings
from app.utils.logger import logger
# Database connection.
# The user name and password are percent-encoded before being placed in the
# URL, so reserved characters such as "@", ":" or "/" in the credentials
# cannot corrupt the connection string.
from urllib.parse import quote_plus

DATABASE_URL = (
    f"postgresql://{quote_plus(settings.postgres_user)}:{quote_plus(settings.postgres_password)}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
class DocStatus(str, enum.Enum):
    """Processing status of a document."""

    uploaded = "uploaded"  # uploaded
    parsing = "parsing"  # parsing in progress
    parsed = "parsed"  # parsed
    embedding = "embedding"  # vectorisation in progress
    indexed = "indexed"  # indexed
    failed = "failed"  # processing failed
class Document(Base):
    """ORM model for the ``documents`` table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    # Storage path of the file inside MinIO.
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return the API-facing fields; ``created_at`` becomes an ISO string or None."""
        created = self.created_at
        return dict(
            id=self.id,
            name=self.original_name,
            chunks=self.chunks,
            status=self.status,
            created_at=created.isoformat() if created else None,
        )
class ParseTask(Base):
    """ORM model for the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    id = Column(String(64), primary_key=True)
    # ID of the document the task belongs to (plain column, no foreign key declared).
    doc_id = Column(String(64), nullable=False)
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
def init_db():
    """Create the database tables declared on ``Base``.

    A failure is logged and swallowed, so callers never see an exception.
    """
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
    else:
        logger.info("Database tables created successfully")
def get_db() -> Session:
    """Open and return a new database session.

    The session is handed over still open: the caller is responsible for
    closing it.
    """
    return SessionLocal()
class DatabaseService:
    """Database access service for documents and parse tasks.

    Every public method works in its own short-lived session, which is
    closed before the method returns.
    """

    def __init__(self):
        self.engine = engine
        self.SessionLocal = SessionLocal

    def _in_session(self, work):
        """Call ``work(session)`` with a fresh session that is always closed afterwards."""
        session = self.SessionLocal()
        try:
            return work(session)
        finally:
            session.close()

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a document record in the ``uploaded`` state and return it."""

        def insert(session):
            record = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            session.add(record)
            session.commit()
            session.refresh(record)
            return record

        return self._in_session(insert)

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with ``doc_id``, or None when there is none."""
        return self._in_session(
            lambda session: session.query(Document).filter(Document.id == doc_id).first()
        )

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: int = None,
        vectors: int = None,
        error_message: str = None,
    ) -> Optional[Document]:
        """Set a document's status and, when given, its counters and error message.

        Returns the refreshed document, or None when ``doc_id`` is unknown.
        """

        def apply(session):
            record = session.query(Document).filter(Document.id == doc_id).first()
            if record is None:
                return None
            record.status = status
            record.updated_at = datetime.now()
            if chunks is not None:
                record.chunks = chunks
            if vectors is not None:
                record.vectors = vectors
            # An empty or missing message leaves the stored one untouched.
            if error_message:
                record.error_message = error_message
            session.commit()
            session.refresh(record)
            return record

        return self._in_session(apply)

    def list_documents(self) -> List[Document]:
        """Return all documents, newest first."""
        return self._in_session(
            lambda session: session.query(Document).order_by(Document.created_at.desc()).all()
        )

    def delete_document(self, doc_id: str) -> bool:
        """Delete a document record; return True if it existed, False otherwise."""

        def remove(session):
            record = session.query(Document).filter(Document.id == doc_id).first()
            if record is None:
                return False
            session.delete(record)
            session.commit()
            return True

        return self._in_session(remove)

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a task record for ``doc_id`` and return it."""

        def insert(session):
            task = ParseTask(id=task_id, doc_id=doc_id)
            session.add(task)
            session.commit()
            session.refresh(task)
            return task

        return self._in_session(insert)

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the task with ``task_id``, or None when there is none."""
        return self._in_session(
            lambda session: session.query(ParseTask).filter(ParseTask.id == task_id).first()
        )

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: int = None,
        message: str = None,
    ) -> Optional[ParseTask]:
        """Update a task's status, progress and message.

        ``started_at`` is stamped when the status is ``"running"``;
        ``completed_at`` when it is ``"completed"`` or ``"failed"``.
        Returns the refreshed task, or None when ``task_id`` is unknown.
        """

        def apply(session):
            task = session.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task is None:
                return None
            task.status = status
            if progress is not None:
                task.progress = progress
            if message:
                task.message = message
            if status == "running":
                task.started_at = datetime.now()
            elif status in ("completed", "failed"):
                task.completed_at = datetime.now()
            session.commit()
            session.refresh(task)
            return task

        return self._in_session(apply)
# 单例
db_service = DatabaseService()

122
app/services/minio.py Normal file
View File

@@ -0,0 +1,122 @@
"""MinIO 文件存储服务"""
import io
from minio import Minio
from minio.error import S3Error
from app.core.config import settings
from app.utils.logger import logger
class MinioService:
    """MinIO file storage service.

    Wraps a ``Minio`` client bound to the bucket configured in ``settings``.
    """

    def __init__(self):
        self.client = Minio(
            settings.minio_endpoint,
            access_key=settings.minio_access_key,
            secret_key=settings.minio_secret_key,
            secure=settings.minio_secure,
        )
        self.bucket = settings.minio_bucket
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Make sure the bucket exists, creating it when missing.

        An ``S3Error`` is logged and not re-raised.
        """
        try:
            if not self.client.bucket_exists(self.bucket):
                self.client.make_bucket(self.bucket)
                logger.info(f"Created MinIO bucket: {self.bucket}")
        except S3Error as e:
            logger.error(f"MinIO bucket check failed: {e}")

    def upload_file(
        self,
        object_name: str,
        file_data: bytes,
        content_type: str = "application/octet-stream",
    ) -> str:
        """Upload a file to MinIO.

        Args:
            object_name: Object name (path of the file inside the bucket).
            file_data: Binary content of the file.
            content_type: Content type stored with the object.

        Returns:
            ``"<endpoint>/<bucket>/<object_name>"`` for the stored object.

        Raises:
            S3Error: When the upload fails (logged, then re-raised).
        """
        try:
            data_stream = io.BytesIO(file_data)
            self.client.put_object(
                self.bucket,
                object_name,
                data_stream,
                length=len(file_data),
                content_type=content_type,
            )
            url = f"{settings.minio_endpoint}/{self.bucket}/{object_name}"
            logger.info(f"Uploaded file to MinIO: {object_name}")
            return url
        except S3Error as e:
            logger.error(f"MinIO upload failed: {e}")
            raise

    def get_file(self, object_name: str) -> bytes:
        """Fetch a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            Binary content of the object.

        Raises:
            S3Error: When the download fails (logged, then re-raised).
        """
        try:
            response = self.client.get_object(self.bucket, object_name)
            try:
                return response.read()
            finally:
                # Always hand the response and its connection back, even
                # when read() fails part-way through.
                response.close()
                response.release_conn()
        except S3Error as e:
            logger.error(f"MinIO get file failed: {e}")
            raise

    def delete_file(self, object_name: str) -> bool:
        """Delete a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            True when the object was removed, False when an ``S3Error``
            occurred (the error is logged).
        """
        try:
            self.client.remove_object(self.bucket, object_name)
            logger.info(f"Deleted file from MinIO: {object_name}")
            return True
        except S3Error as e:
            logger.error(f"MinIO delete failed: {e}")
            return False

    def list_files(self, prefix: str = "") -> list[str]:
        """List the files stored in MinIO.

        Args:
            prefix: Only objects whose name starts with this prefix are listed.

        Returns:
            List of object names; empty when an ``S3Error`` occurred
            (the error is logged).
        """
        try:
            objects = self.client.list_objects(self.bucket, prefix=prefix)
            return [obj.object_name for obj in objects]
        except S3Error as e:
            logger.error(f"MinIO list files failed: {e}")
            return []
# 单例
minio_service = MinioService()

89
app/services/tasks.py Normal file
View File

@@ -0,0 +1,89 @@
"""异步任务处理模块
TODO: 后续替换为 RabbitMQ 消息队列
"""
import asyncio
import uuid
from datetime import datetime
from typing import Callable, Awaitable
from app.utils.logger import logger
# 任务状态存储(后续替换为 Redis
_task_store: Dict[str, Dict] = {}
def generate_task_id() -> str:
    """Generate a task ID: ``"task-"`` followed by 12 hex characters of a random UUID."""
    random_part = uuid.uuid4().hex[:12]
    return "task-" + random_part
class AsyncTaskManager:
"""异步任务管理器"""
def __init__(self):
self._running_tasks: dict[str, asyncio.Task] = {}
def create_task(
self,
task_id: str,
task_func: Callable[[str], Awaitable[None]],
) -> asyncio.Task:
"""
创建异步任务
Args:
task_id: 任务ID
task_func: 任务执行函数
Returns:
asyncio.Task
"""
task = asyncio.create_task(self._run_task(task_id, task_func))
self._running_tasks[task_id] = task
return task
async def _run_task(
self,
task_id: str,
task_func: Callable[[str], Awaitable[None]],
):
"""运行任务并处理状态"""
try:
await task_func(task_id)
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
_task_store[task_id] = {
"status": "failed",
"error": str(e),
"completed_at": datetime.now(),
}
finally:
if task_id in self._running_tasks:
del self._running_tasks[task_id]
def get_task_status(self, task_id: str) -> dict | None:
"""获取任务状态"""
return _task_store.get(task_id)
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
if task_id in self._running_tasks:
self._running_tasks[task_id].cancel()
return True
return False
# 单例
task_manager = AsyncTaskManager()
def get_task_status(task_id: str) -> dict | None:
    """Return the status stored for ``task_id``, or None when there is none."""
    if task_id in _task_store:
        return _task_store[task_id]
    return None
def set_task_status(task_id: str, status: dict):
    """Store ``status`` for ``task_id``, replacing any previous entry."""
    _task_store.update({task_id: status})

View File

@@ -0,0 +1,252 @@
"""文档解析工作流 - 异步处理"""
import asyncio
import uuid
from datetime import datetime
from typing import List
import io
from app.core.config import settings
from app.services.minio import minio_service
from app.services.database import db_service, DocStatus
from app.services.tasks import set_task_status, get_task_status
from app.services.document import DocumentService
from app.utils.chunking import TextChunker
from app.utils.logger import logger
def generate_doc_id() -> str:
    """Generate a document ID: ``"doc-"`` followed by 12 hex characters of a random UUID."""
    random_part = uuid.uuid4().hex[:12]
    return "doc-" + random_part
def generate_chunk_id(doc_id: str, index: int) -> str:
    """Generate the ID of chunk number ``index`` of document ``doc_id``."""
    return "{}-chunk-{}".format(doc_id, index)
async def run_parse_workflow(task_id: str, doc_id: str):
    """Run the document parsing workflow.

    Steps:
        1. Fetch the file from MinIO and store a temporary local copy.
        2. Parse the document and extract its text.
        3. Split the text into chunks, by clause or by fixed size.
        4. Record the result.

    Progress is published to the in-memory task store at every step. The
    persisted task row is updated when the task starts, completes or fails,
    so the status remains available from the database as well.

    Any exception is caught: the task and the document are marked as failed
    and the exception is not propagated.

    Args:
        task_id: Task ID.
        doc_id: Document ID.
    """
    import os  # local import: this module has no top-level ``os`` import

    chunker = TextChunker()
    doc_service = DocumentService(settings.data_raw_dir, settings.data_parsed_dir)
    try:
        # Step 1: fetch the file.
        set_task_status(task_id, {
            "status": "running",
            "step": "fetching",
            "progress": 10,
            "message": "正在从存储获取文件...",
            "started_at": datetime.now(),
        })
        db_service.update_parse_task(
            task_id,
            "running",
            progress=10,
            message="正在从存储获取文件...",
        )
        db_service.update_document_status(doc_id, DocStatus.parsing.value)

        doc = db_service.get_document(doc_id)
        if not doc:
            raise ValueError(f"Document {doc_id} not found")

        # Download from MinIO.
        file_data = minio_service.get_file(doc.minio_path)

        # Save a temporary local copy for the parser. ``doc.filename`` is
        # the storage name and already starts with the document ID, so it is
        # used as-is; this is the path delete_document cleans up.
        os.makedirs(settings.data_raw_dir, exist_ok=True)
        temp_path = f"{settings.data_raw_dir}/{doc.filename}"
        with open(temp_path, "wb") as f:
            f.write(file_data)
        await asyncio.sleep(0.5)  # simulated latency

        # Step 2: parse the document.
        set_task_status(task_id, {
            "status": "running",
            "step": "parsing",
            "progress": 30,
            "message": "正在解析文档内容...",
        })
        text = doc_service.parse_document(temp_path)
        if not text:
            raise ValueError("Document parsing returned empty content")

        # Save the parsed text.
        parsed_path = doc_service.save_parsed_text(doc_id, text)
        await asyncio.sleep(0.5)

        # Step 3: split the text into chunks.
        set_task_status(task_id, {
            "status": "running",
            "step": "chunking",
            "progress": 50,
            "message": "正在进行文本分块...",
        })
        # Try clause-based chunking first; fall back to fixed-size chunks
        # when the text is not in regulation format.
        chunks = chunker.chunk_by_clause(text)
        if not chunks:
            chunks = chunker.chunk_by_size(text)
        await asyncio.sleep(0.5)

        # Step 4: save the chunking result.
        set_task_status(task_id, {
            "status": "running",
            "step": "saving",
            "progress": 80,
            "message": f"正在保存 {len(chunks)} 个文本块...",
        })
        # TODO: store the chunks in the database or the vector store.
        # For now only the number of chunks is recorded.
        chunk_count = len(chunks)
        await asyncio.sleep(0.5)

        # Step 5: done.
        set_task_status(task_id, {
            "status": "completed",
            "step": "done",
            "progress": 100,
            "message": f"解析完成,共生成 {chunk_count} 个文本块",
            "completed_at": datetime.now(),
            "result": {
                "doc_id": doc_id,
                "chunks": chunk_count,
                "parsed_path": parsed_path,
            }
        })
        db_service.update_document_status(
            doc_id,
            DocStatus.parsed.value,
            chunks=chunk_count,
        )
        db_service.update_parse_task(
            task_id,
            "completed",
            progress=100,
            message=f"解析完成,共生成 {chunk_count} 个文本块",
        )
        logger.info(f"Parse workflow completed for doc {doc_id}: {chunk_count} chunks")
    except Exception as e:
        logger.error(f"Parse workflow failed for doc {doc_id}: {e}")
        set_task_status(task_id, {
            "status": "failed",
            "step": "error",
            "progress": 0,
            "message": str(e),
            "completed_at": datetime.now(),
        })
        db_service.update_document_status(
            doc_id,
            DocStatus.failed.value,
            error_message=str(e),
        )
        db_service.update_parse_task(
            task_id,
            "failed",
            progress=0,
            message=str(e),
        )
async def run_embedding_workflow(task_id: str, doc_id: str):
    """Run the vectorisation workflow.

    Steps:
        1. Fetch the chunk data.
        2. Generate vector embeddings.
        3. Store them in the vector database.

    Embedding generation and vector storage are still placeholders (see the
    TODO notes); the vector count is taken from the document's chunk count.

    Progress is published to the in-memory task store at every step. The
    persisted task row is updated when the task starts, completes or fails,
    so the status remains available from the database as well.

    Any exception is caught: the task and the document are marked as failed
    and the exception is not propagated.

    Args:
        task_id: Task ID.
        doc_id: Document ID.
    """
    try:
        # Step 1: fetch the chunks.
        set_task_status(task_id, {
            "status": "running",
            "step": "fetching_chunks",
            "progress": 10,
            "message": "正在获取文本分块...",
            "started_at": datetime.now(),
        })
        db_service.update_parse_task(
            task_id,
            "running",
            progress=10,
            message="正在获取文本分块...",
        )
        db_service.update_document_status(doc_id, DocStatus.embedding.value)

        doc = db_service.get_document(doc_id)
        if not doc:
            raise ValueError(f"Document {doc_id} not found")
        await asyncio.sleep(0.5)

        # Step 2: generate the embeddings.
        set_task_status(task_id, {
            "status": "running",
            "step": "embedding",
            "progress": 40,
            "message": "正在生成向量嵌入...",
        })
        # TODO: call the embedding service to generate the vectors.
        # Simulated for now: one vector per chunk.
        vector_count = doc.chunks
        await asyncio.sleep(1)

        # Step 3: store in the vector database.
        set_task_status(task_id, {
            "status": "running",
            "step": "storing",
            "progress": 70,
            "message": "正在存入向量数据库...",
        })
        # TODO: store the vectors in Milvus.
        await asyncio.sleep(0.5)

        # Step 4: done.
        set_task_status(task_id, {
            "status": "completed",
            "step": "done",
            "progress": 100,
            "message": f"向量化完成,共处理 {vector_count} 个向量",
            "completed_at": datetime.now(),
            "result": {
                "doc_id": doc_id,
                "vectors": vector_count,
            }
        })
        db_service.update_document_status(
            doc_id,
            DocStatus.indexed.value,
            vectors=vector_count,
        )
        db_service.update_parse_task(
            task_id,
            "completed",
            progress=100,
            message=f"向量化完成,共处理 {vector_count} 个向量",
        )
        logger.info(f"Embedding workflow completed for doc {doc_id}")
    except Exception as e:
        logger.error(f"Embedding workflow failed for doc {doc_id}: {e}")
        set_task_status(task_id, {
            "status": "failed",
            "step": "error",
            "progress": 0,
            "message": str(e),
            "completed_at": datetime.now(),
        })
        db_service.update_document_status(
            doc_id,
            DocStatus.failed.value,
            error_message=str(e),
        )
        db_service.update_parse_task(
            task_id,
            "failed",
            progress=0,
            message=str(e),
        )