Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
# Import mock data service
|
||||
# Import services
|
||||
from .minio import minio_service, MinioService
|
||||
from .database import db_service, DatabaseService, init_db, Document, ParseTask
|
||||
from .tasks import task_manager, get_task_status, set_task_status, generate_task_id
|
||||
|
||||
# Import mock data service (for development)
|
||||
from .mock_data import (
|
||||
get_mock_documents,
|
||||
get_mock_quick_questions,
|
||||
@@ -29,6 +34,18 @@ except ImportError:
|
||||
get_document_service = None
|
||||
|
||||
__all__ = [
|
||||
# Core services
|
||||
"minio_service",
|
||||
"MinioService",
|
||||
"db_service",
|
||||
"DatabaseService",
|
||||
"init_db",
|
||||
"Document",
|
||||
"ParseTask",
|
||||
"task_manager",
|
||||
"get_task_status",
|
||||
"set_task_status",
|
||||
"generate_task_id",
|
||||
# Mock data services
|
||||
"get_mock_documents",
|
||||
"get_mock_quick_questions",
|
||||
|
||||
228
app/services/database.py
Normal file
228
app/services/database.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""数据库服务 - PostgreSQL"""
|
||||
|
||||
import enum
from datetime import datetime
from typing import Optional, List
from urllib.parse import quote

from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

from app.core.config import settings
from app.utils.logger import logger
|
||||
|
||||
|
||||
# Database connection.
# The user name and password are percent-encoded (safe="" also encodes "/") so
# that reserved characters such as "@", ":" or "/" inside a credential cannot
# corrupt the URL. str() keeps the f-string's implicit conversion for non-str
# settings values.
DATABASE_URL = (
    f"postgresql://{quote(str(settings.postgres_user), safe='')}"
    f":{quote(str(settings.postgres_password), safe='')}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)

engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class shared by the ORM models below.
Base = declarative_base()
|
||||
|
||||
|
||||
class DocStatus(str, enum.Enum):
    """Processing states of a document.

    Members also subclass ``str``, so each member compares equal to its value.
    """

    # File has been uploaded.
    uploaded = "uploaded"
    # Parsing is in progress.
    parsing = "parsing"
    # Parsing has finished.
    parsed = "parsed"
    # Embedding (vectorization) is in progress.
    embedding = "embedding"
    # Document has been indexed.
    indexed = "indexed"
    # Processing failed.
    failed = "failed"
|
||||
|
||||
|
||||
class Document(Base):
    """ORM model for the ``documents`` table."""

    __tablename__ = "documents"

    # Identity and naming.
    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)

    # Storage path of the file in MinIO.
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)

    # Processing state; defaults to DocStatus.uploaded.
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)

    # Timestamps use naive local time (datetime.now).
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Serialize selected fields to a plain dict.

        Returns:
            A dict with the keys ``id``, ``name`` (the original file name),
            ``chunks``, ``status`` and ``created_at`` (ISO-8601 string, or
            ``None`` when no creation time is set).
        """
        created_iso = None
        if self.created_at:
            created_iso = self.created_at.isoformat()

        return {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
            "created_at": created_iso,
        }
|
||||
|
||||
|
||||
class ParseTask(Base):
    """ORM model for the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    # Task identity and the ID of the document it refers to
    # (plain column; no foreign-key constraint is declared here).
    id = Column(String(64), primary_key=True)
    doc_id = Column(String(64), nullable=False)

    # Progress reporting; a new task starts as "pending" with progress 0.
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)

    # Lifecycle timestamps; the latter two stay NULL until they are set.
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
def init_db():
    """Create all tables registered on ``Base.metadata``.

    Any exception is logged and swallowed; nothing is raised to the caller.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}")
|
||||
|
||||
|
||||
def get_db() -> Session:
    """Open and return a new database session.

    The session is NOT closed here: the caller is responsible for closing it.
    """
    return SessionLocal()
|
||||
|
||||
|
||||
class DatabaseService:
    """Persistence helpers for ``Document`` and ``ParseTask`` rows.

    Each method opens a fresh session from ``self.SessionLocal`` and closes it
    in a ``finally`` block before returning.

    NOTE(review): ORM objects are returned after their session has been closed;
    confirm callers only read attributes that were already loaded.
    """

    def __init__(self):
        # Reuse the module-level engine and session factory.
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document row with status ``uploaded``.

        Args:
            doc_id: Primary key of the new document.
            filename: Value stored in the ``filename`` column.
            original_name: Value stored in the ``original_name`` column.
            minio_path: Storage path of the file in MinIO.
            size: Value stored in the ``size`` column.

        Returns:
            The committed and refreshed ``Document``.
        """
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            # Reload so the attributes are populated before the session closes.
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with the given ID, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update a document's status and, optionally, its counters and error text.

        Args:
            doc_id: ID of the document to update.
            status: New status value.
            chunks: New chunk count; left unchanged when ``None``.
            vectors: New vector count; left unchanged when ``None``.
            error_message: New error text; falsy values (``None``, ``""``)
                leave the stored message untouched.

        Returns:
            The refreshed ``Document``, or ``None`` if no such document exists.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc is None:
                return None

            doc.status = status
            doc.updated_at = datetime.now()
            if chunks is not None:
                doc.chunks = chunks
            if vectors is not None:
                doc.vectors = vectors
            if error_message:
                doc.error_message = error_message
            db.commit()
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest first (ordered by ``created_at`` descending)."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete the document with the given ID.

        Returns:
            ``True`` if a row was found and deleted, ``False`` if it did not exist.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc is None:
                return False
            db.delete(doc)
            db.commit()
            return True
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` and return the refreshed row."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with the given ID, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update a parse task's status, progress and message.

        ``started_at`` is stamped the first time the task is reported as
        ``"running"``; ``completed_at`` is stamped when the status is
        ``"completed"`` or ``"failed"``.

        Args:
            task_id: ID of the task to update.
            status: New status value.
            progress: New progress value; left unchanged when ``None``.
            message: New message; falsy values (``None``, ``""``) leave the
                stored message untouched.

        Returns:
            The refreshed ``ParseTask``, or ``None`` if no such task exists.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task is None:
                return None

            task.status = status
            if progress is not None:
                task.progress = progress
            if message:
                task.message = message

            if status == "running":
                # ``status`` is a required argument, so progress-only updates
                # have to repeat "running"; keep the original start time
                # instead of resetting it on every such call.
                if task.started_at is None:
                    task.started_at = datetime.now()
            elif status in ("completed", "failed"):
                task.completed_at = datetime.now()

            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()
|
||||
|
||||
|
||||
# Module-level singleton instance of DatabaseService
|
||||
db_service = DatabaseService()
|
||||
122
app/services/minio.py
Normal file
122
app/services/minio.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""MinIO 文件存储服务"""
|
||||
|
||||
import io
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from app.core.config import settings
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class MinioService:
    """File storage helper built on a MinIO client and a single bucket."""

    def __init__(self):
        self.client = Minio(
            settings.minio_endpoint,
            access_key=settings.minio_access_key,
            secret_key=settings.minio_secret_key,
            secure=settings.minio_secure,
        )
        self.bucket = settings.minio_bucket
        # Contacts the server: the bucket is checked/created on construction.
        self._ensure_bucket()

    def _ensure_bucket(self):
        """Create the configured bucket if it does not exist.

        An ``S3Error`` is logged and swallowed, so construction still succeeds.
        """
        try:
            if not self.client.bucket_exists(self.bucket):
                self.client.make_bucket(self.bucket)
                logger.info(f"Created MinIO bucket: {self.bucket}")
        except S3Error as e:
            logger.error(f"MinIO bucket check failed: {e}")

    def upload_file(
        self,
        object_name: str,
        file_data: bytes,
        content_type: str = "application/octet-stream",
    ) -> str:
        """Upload a file to MinIO.

        Args:
            object_name: Object name (path inside the bucket).
            file_data: Raw file content.
            content_type: MIME type stored with the object.

        Returns:
            The string ``"<endpoint>/<bucket>/<object_name>"``; no URL scheme
            is added here.

        Raises:
            S3Error: Logged and re-raised when the upload fails.
        """
        try:
            data_stream = io.BytesIO(file_data)
            self.client.put_object(
                self.bucket,
                object_name,
                data_stream,
                length=len(file_data),
                content_type=content_type,
            )
            url = f"{settings.minio_endpoint}/{self.bucket}/{object_name}"
            logger.info(f"Uploaded file to MinIO: {object_name}")
            return url
        except S3Error as e:
            logger.error(f"MinIO upload failed: {e}")
            raise

    def get_file(self, object_name: str) -> bytes:
        """Download an object from MinIO.

        Args:
            object_name: Object name.

        Returns:
            The object's content as bytes.

        Raises:
            S3Error: Logged and re-raised when the download fails.
        """
        try:
            response = self.client.get_object(self.bucket, object_name)
            try:
                return response.read()
            finally:
                # Always close the response and hand the connection back,
                # even if read() fails part-way through.
                response.close()
                response.release_conn()
        except S3Error as e:
            logger.error(f"MinIO get file failed: {e}")
            raise

    def delete_file(self, object_name: str) -> bool:
        """Delete an object from MinIO.

        Args:
            object_name: Object name.

        Returns:
            ``True`` on success, ``False`` when an ``S3Error`` occurred
            (the error is logged, not raised).
        """
        try:
            self.client.remove_object(self.bucket, object_name)
            logger.info(f"Deleted file from MinIO: {object_name}")
            return True
        except S3Error as e:
            logger.error(f"MinIO delete failed: {e}")
            return False

    def list_files(self, prefix: str = "") -> list[str]:
        """List object names in the bucket.

        Args:
            prefix: Only objects whose name starts with this prefix are listed.

        Returns:
            The object names, or an empty list when an ``S3Error`` occurred
            (the error is logged, not raised).
        """
        try:
            objects = self.client.list_objects(self.bucket, prefix=prefix)
            return [obj.object_name for obj in objects]
        except S3Error as e:
            logger.error(f"MinIO list files failed: {e}")
            return []
|
||||
|
||||
|
||||
# Module-level singleton; constructing it runs _ensure_bucket(), which contacts MinIO at import time
|
||||
minio_service = MinioService()
|
||||
89
app/services/tasks.py
Normal file
89
app/services/tasks.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""异步任务处理模块
|
||||
|
||||
TODO: 后续替换为 RabbitMQ 消息队列
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Callable, Awaitable
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
# In-memory task status store, keyed by task ID (to be replaced by Redis later).
# Annotated with the builtin generic: ``typing.Dict`` is not imported in this
# module, and a module-level annotation is evaluated at import time.
_task_store: dict[str, dict] = {}
|
||||
|
||||
|
||||
def generate_task_id() -> str:
    """Return a new task ID: ``task-`` plus the first 12 hex digits of a random UUID."""
    suffix = uuid.uuid4().hex[:12]
    return "task-" + suffix
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
"""异步任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
def create_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_func: Callable[[str], Awaitable[None]],
|
||||
) -> asyncio.Task:
|
||||
"""
|
||||
创建异步任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
task_func: 任务执行函数
|
||||
|
||||
Returns:
|
||||
asyncio.Task
|
||||
"""
|
||||
task = asyncio.create_task(self._run_task(task_id, task_func))
|
||||
self._running_tasks[task_id] = task
|
||||
return task
|
||||
|
||||
async def _run_task(
|
||||
self,
|
||||
task_id: str,
|
||||
task_func: Callable[[str], Awaitable[None]],
|
||||
):
|
||||
"""运行任务并处理状态"""
|
||||
try:
|
||||
await task_func(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Task {task_id} failed: {e}")
|
||||
_task_store[task_id] = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completed_at": datetime.now(),
|
||||
}
|
||||
finally:
|
||||
if task_id in self._running_tasks:
|
||||
del self._running_tasks[task_id]
|
||||
|
||||
def get_task_status(self, task_id: str) -> dict | None:
|
||||
"""获取任务状态"""
|
||||
return _task_store.get(task_id)
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消任务"""
|
||||
if task_id in self._running_tasks:
|
||||
self._running_tasks[task_id].cancel()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Module-level singleton instance of AsyncTaskManager
|
||||
task_manager = AsyncTaskManager()
|
||||
|
||||
|
||||
def get_task_status(task_id: str) -> dict | None:
|
||||
"""获取任务状态"""
|
||||
return _task_store.get(task_id)
|
||||
|
||||
|
||||
def set_task_status(task_id: str, status: dict):
    """Store ``status`` as the current status of ``task_id``, replacing any earlier one."""
    _task_store.update({task_id: status})
|
||||
Reference in New Issue
Block a user