Refactor code structure for improved readability and maintainability

This commit is contained in:
2026-05-18 11:41:20 +08:00
parent d39de39f96
commit 3f154a3077
43 changed files with 5046 additions and 113 deletions

View File

@@ -1,4 +1,9 @@
# Import mock data service
# Import services
from .minio import minio_service, MinioService
from .database import db_service, DatabaseService, init_db, Document, ParseTask
from .tasks import task_manager, get_task_status, set_task_status, generate_task_id
# Import mock data service (for development)
from .mock_data import (
get_mock_documents,
get_mock_quick_questions,
@@ -29,6 +34,18 @@ except ImportError:
get_document_service = None
__all__ = [
# Core services
"minio_service",
"MinioService",
"db_service",
"DatabaseService",
"init_db",
"Document",
"ParseTask",
"task_manager",
"get_task_status",
"set_task_status",
"generate_task_id",
# Mock data services
"get_mock_documents",
"get_mock_quick_questions",

228
app/services/database.py Normal file
View File

@@ -0,0 +1,228 @@
"""数据库服务 - PostgreSQL"""
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import Optional, List
import enum
from app.core.config import settings
from app.utils.logger import logger
# Database connection.
# The user name and password are percent-encoded so that characters such as
# "@", ":" or "/" inside them cannot corrupt the connection URL.
from urllib.parse import quote_plus

DATABASE_URL = (
    f"postgresql://{quote_plus(str(settings.postgres_user))}"
    f":{quote_plus(str(settings.postgres_password))}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
# Session factory bound to the engine above; commits and flushes are explicit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
class DocStatus(str, enum.Enum):
    """Processing states of a document."""

    # The file has been uploaded.
    uploaded = "uploaded"
    # Parsing is in progress.
    parsing = "parsing"
    # Parsing has finished.
    parsed = "parsed"
    # Vectorization (embedding) is in progress.
    embedding = "embedding"
    # The document has been indexed.
    indexed = "indexed"
    # Processing failed.
    failed = "failed"
class Document(Base):
    """ORM model for the ``documents`` table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    # MinIO storage path
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a plain-dict summary of the document.

        Returns:
            A dict with the keys ``id``, ``name`` (the original file name),
            ``chunks``, ``status`` and ``created_at`` (ISO-8601 string, or
            ``None`` when no creation time is set).
        """
        created = self.created_at
        summary = {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
            "created_at": None,
        }
        if created:
            summary["created_at"] = created.isoformat()
        return summary
class ParseTask(Base):
    """ORM model for the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    id = Column(String(64), primary_key=True)
    # ID of the document this task belongs to (stored as a plain string column).
    doc_id = Column(String(64), nullable=False)
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    # Both timestamps stay NULL until they are set explicitly.
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
def init_db():
    """Create every table registered on ``Base``.

    Any exception raised during creation is logged and not re-raised.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}")
def get_db() -> Session:
    """Open and return a new database session.

    The session is not closed here; the caller is responsible for closing it.
    """
    return SessionLocal()
class DatabaseService:
    """CRUD helpers for the ``Document`` and ``ParseTask`` tables.

    Every method opens its own session from ``self.SessionLocal`` and closes
    it in a ``finally`` block before returning.
    """

    def __init__(self):
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document row with status ``uploaded``.

        Args:
            doc_id: Primary key of the new document.
            filename: Stored file name.
            original_name: File name as originally provided.
            minio_path: Storage path of the file in MinIO.
            size: Value stored in the ``size`` column.

        Returns:
            The persisted ``Document``, refreshed after the commit.
        """
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with the given ID, or ``None`` if there is none."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update the status (and optional counters) of a document.

        Args:
            doc_id: ID of the document to update.
            status: New status value.
            chunks: New chunk count; left unchanged when ``None``.
            vectors: New vector count; left unchanged when ``None``.
            error_message: New error message; left unchanged when ``None``
                or empty.

        Returns:
            The refreshed ``Document``, or ``None`` if no such document exists.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                doc.status = status
                doc.updated_at = datetime.now()
                if chunks is not None:
                    doc.chunks = chunks
                if vectors is not None:
                    doc.vectors = vectors
                if error_message:
                    doc.error_message = error_message
                db.commit()
                db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest first (ordered by ``created_at``)."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete a document row.

        Returns:
            ``True`` if a row was found and deleted, ``False`` otherwise.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                db.delete(doc)
                db.commit()
                return True
            return False
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` and return it refreshed."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with the given ID, or ``None`` if there is none."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update the status, progress and message of a parse task.

        Args:
            task_id: ID of the task to update.
            status: New status value. ``"running"`` stamps ``started_at``
                (first time only); ``"completed"`` and ``"failed"`` stamp
                ``completed_at``.
            progress: New progress value; left unchanged when ``None``.
            message: New message; left unchanged when ``None`` or empty.

        Returns:
            The refreshed ``ParseTask``, or ``None`` if no such task exists.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task:
                task.status = status
                if progress is not None:
                    task.progress = progress
                if message:
                    task.message = message
                if status == "running":
                    # Stamp the start time only once, so repeated calls with
                    # status "running" do not overwrite the original value.
                    if task.started_at is None:
                        task.started_at = datetime.now()
                elif status in ("completed", "failed"):
                    task.completed_at = datetime.now()
                db.commit()
                db.refresh(task)
            return task
        finally:
            db.close()
# Singleton: module-level DatabaseService instance
db_service = DatabaseService()

122
app/services/minio.py Normal file
View File

@@ -0,0 +1,122 @@
"""MinIO 文件存储服务"""
import io
from minio import Minio
from minio.error import S3Error
from app.core.config import settings
from app.utils.logger import logger
class MinioService:
    """File storage service backed by a single MinIO bucket."""

    def __init__(self):
        self.client = Minio(
            settings.minio_endpoint,
            access_key=settings.minio_access_key,
            secret_key=settings.minio_secret_key,
            secure=settings.minio_secure,
        )
        self.bucket = settings.minio_bucket
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Create the configured bucket if it does not exist yet.

        An ``S3Error`` is logged and not re-raised.
        """
        try:
            if not self.client.bucket_exists(self.bucket):
                self.client.make_bucket(self.bucket)
                logger.info(f"Created MinIO bucket: {self.bucket}")
        except S3Error as e:
            logger.error(f"MinIO bucket check failed: {e}")

    def upload_file(
        self,
        object_name: str,
        file_data: bytes,
        content_type: str = "application/octet-stream",
    ) -> str:
        """Upload a file to MinIO.

        Args:
            object_name: Object name (file path inside the bucket).
            file_data: Binary content of the file.
            content_type: Content type stored with the object.

        Returns:
            The string ``"<endpoint>/<bucket>/<object_name>"``.

        Raises:
            S3Error: If the upload fails (logged, then re-raised).
        """
        try:
            data_stream = io.BytesIO(file_data)
            self.client.put_object(
                self.bucket,
                object_name,
                data_stream,
                length=len(file_data),
                content_type=content_type,
            )
            url = f"{settings.minio_endpoint}/{self.bucket}/{object_name}"
            logger.info(f"Uploaded file to MinIO: {object_name}")
            return url
        except S3Error as e:
            logger.error(f"MinIO upload failed: {e}")
            raise

    def get_file(self, object_name: str) -> bytes:
        """Fetch a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            The binary content of the object.

        Raises:
            S3Error: If the download fails (logged, then re-raised).
        """
        try:
            response = self.client.get_object(self.bucket, object_name)
            try:
                return response.read()
            finally:
                # Always close the response and release the connection, even
                # when read() fails, so the connection is not leaked.
                response.close()
                response.release_conn()
        except S3Error as e:
            logger.error(f"MinIO get file failed: {e}")
            raise

    def delete_file(self, object_name: str) -> bool:
        """Delete a file from MinIO.

        Args:
            object_name: Object name.

        Returns:
            ``True`` on success, ``False`` if an ``S3Error`` occurred (the
            error is logged).
        """
        try:
            self.client.remove_object(self.bucket, object_name)
            logger.info(f"Deleted file from MinIO: {object_name}")
            return True
        except S3Error as e:
            logger.error(f"MinIO delete failed: {e}")
            return False

    def list_files(self, prefix: str = "") -> list[str]:
        """List the files stored in MinIO.

        Args:
            prefix: Only objects whose name starts with this prefix are listed.

        Returns:
            The object names, or an empty list if an ``S3Error`` occurred
            (the error is logged).
        """
        try:
            objects = self.client.list_objects(self.bucket, prefix=prefix)
            return [obj.object_name for obj in objects]
        except S3Error as e:
            logger.error(f"MinIO list files failed: {e}")
            return []
# Singleton: module-level MinioService instance
minio_service = MinioService()

89
app/services/tasks.py Normal file
View File

@@ -0,0 +1,89 @@
"""异步任务处理模块
TODO: 后续替换为 RabbitMQ 消息队列
"""
import asyncio
import uuid
from datetime import datetime
from typing import Callable, Awaitable
from app.utils.logger import logger
# Task status store, keyed by task ID (to be replaced by Redis later).
# Annotated with the builtin generic: typing.Dict is not imported in this module.
_task_store: dict[str, dict] = {}
def generate_task_id() -> str:
    """Generate a task ID: ``"task-"`` followed by 12 hex characters of a UUID4."""
    suffix = uuid.uuid4().hex[:12]
    return "task-" + suffix
class AsyncTaskManager:
"""异步任务管理器"""
def __init__(self):
self._running_tasks: dict[str, asyncio.Task] = {}
def create_task(
self,
task_id: str,
task_func: Callable[[str], Awaitable[None]],
) -> asyncio.Task:
"""
创建异步任务
Args:
task_id: 任务ID
task_func: 任务执行函数
Returns:
asyncio.Task
"""
task = asyncio.create_task(self._run_task(task_id, task_func))
self._running_tasks[task_id] = task
return task
async def _run_task(
self,
task_id: str,
task_func: Callable[[str], Awaitable[None]],
):
"""运行任务并处理状态"""
try:
await task_func(task_id)
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
_task_store[task_id] = {
"status": "failed",
"error": str(e),
"completed_at": datetime.now(),
}
finally:
if task_id in self._running_tasks:
del self._running_tasks[task_id]
def get_task_status(self, task_id: str) -> dict | None:
"""获取任务状态"""
return _task_store.get(task_id)
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
if task_id in self._running_tasks:
self._running_tasks[task_id].cancel()
return True
return False
# Singleton: module-level AsyncTaskManager instance
task_manager = AsyncTaskManager()
def get_task_status(task_id: str) -> dict | None:
    """Look up the stored status for ``task_id``; ``None`` if nothing is stored."""
    stored = _task_store.get(task_id)
    return stored
def set_task_status(task_id: str, status: dict):
    """Store ``status`` for ``task_id``, replacing any previously stored value."""
    _task_store.update({task_id: status})