"""Database service - PostgreSQL.

Defines the SQLAlchemy engine/session factory, the ORM tables for documents and
parse tasks, and a small ``DatabaseService`` facade used through the
module-level ``db_service`` singleton.
"""
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import Optional, List
from urllib.parse import quote
import enum

from app.core.config import settings
from app.utils.logger import logger


def _quote_url_part(value) -> str:
    """Percent-encode one URL component so reserved characters survive URL parsing."""
    # safe="" also escapes "/", ":" and "@". quote (rather than quote_plus) encodes
    # a space as %20, which decodes correctly whether the URL parser uses unquote
    # or unquote_plus, and encodes a literal "+" as %2B.
    return quote(str(value), safe="")


# Database connection.
# Credentials are percent-encoded: interpolated raw, a password containing "@",
# ":" or "/" would be parsed as URL delimiters and produce the wrong
# credentials/host. NOTE(review): if any deployment pre-encoded its password in
# config to work around this, it must now supply the raw value instead.
DATABASE_URL = (
    f"postgresql://{_quote_url_part(settings.postgres_user)}"
    f":{_quote_url_part(settings.postgres_password)}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)
# pool_pre_ping checks a pooled connection before handing it out, so stale
# connections are transparently replaced.
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


class DocStatus(str, enum.Enum):
    """Processing status of a document (stored in ``Document.status`` as its string value)."""

    uploaded = "uploaded"  # uploaded
    parsing = "parsing"  # parsing in progress
    parsed = "parsed"  # parsed
    embedding = "embedding"  # vectorization in progress
    indexed = "indexed"  # indexed
    failed = "failed"  # processing failed


class Document(Base):
    """Document table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    minio_path = Column(String(512), nullable=False)  # storage path in MinIO
    size = Column(Integer, default=0)
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a summary dict: id, display name, chunk count, status and ISO creation time."""
        return {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }


class ParseTask(Base):
    """Parse task table."""

    __tablename__ = "parse_tasks"

    id = Column(String(64), primary_key=True)
    doc_id = Column(String(64), nullable=False)
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)


def init_db():
    """Create the database tables declared on ``Base``.

    Errors are logged, not raised, so the caller continues even when table
    creation fails.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")


def get_db() -> Session:
    """Open and return a new database session.

    The caller owns the returned session and is responsible for closing it.
    """
    return SessionLocal()


class DatabaseService:
    """Database service.

    Every method opens its own session and closes it before returning, so any
    ORM object returned is detached from its session; its column attributes are
    loaded (via query or ``refresh``) before the session is closed.
    """

    def __init__(self):
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document record with status ``uploaded`` and return it."""
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            # Reload server/default-populated columns before the session closes.
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with ``doc_id``, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update a document's status and, optionally, its counters and error.

        ``chunks``/``vectors`` are written only when not ``None``;
        ``error_message`` only when truthy (so it cannot be cleared here).
        Returns the updated document, or ``None`` if ``doc_id`` is unknown.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                doc.status = status
                doc.updated_at = datetime.now()
                if chunks is not None:
                    doc.chunks = chunks
                if vectors is not None:
                    doc.vectors = vectors
                if error_message:
                    doc.error_message = error_message
                db.commit()
                db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest (by ``created_at``) first."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete the document row with ``doc_id``; return whether a row was deleted."""
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                db.delete(doc)
                db.commit()
                return True
            return False
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` (status defaults to ``pending``) and return it."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with ``task_id``, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update a parse task's status, progress and message.

        Setting status ``running`` stamps ``started_at``; ``completed`` or
        ``failed`` stamps ``completed_at``. Returns the updated task, or
        ``None`` if ``task_id`` is unknown.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task:
                task.status = status
                if progress is not None:
                    task.progress = progress
                if message:
                    task.message = message
                if status == "running":
                    task.started_at = datetime.now()
                elif status in ("completed", "failed"):
                    task.completed_at = datetime.now()
                db.commit()
                db.refresh(task)
            return task
        finally:
            db.close()


# Module-level singleton.
db_service = DatabaseService()