Refactor code structure for improved readability and maintainability
This commit is contained in:
228
app/services/database.py
Normal file
228
app/services/database.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""数据库服务 - PostgreSQL"""
|
||||
|
||||
import enum
from datetime import datetime
from typing import Optional, List
from urllib.parse import quote

from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

from app.core.config import settings
from app.utils.logger import logger
|
||||
|
||||
|
||||
# Database connection.
# The user name and password are percent-encoded (safe="" also encodes "/")
# so that characters such as "@", ":", "/" or "%" inside the credentials
# cannot be mistaken for URL delimiters when the URL is parsed.
DATABASE_URL = (
    "postgresql://"
    f"{quote(str(settings.postgres_user), safe='')}"
    f":{quote(str(settings.postgres_password), safe='')}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)

# Engine, session factory and declarative base used by the models below.
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
|
||||
|
||||
|
||||
class DocStatus(str, enum.Enum):
    """Document processing status.

    Each member is also a ``str`` whose value equals the member name.
    """

    uploaded = "uploaded"    # file has been uploaded
    parsing = "parsing"      # parsing in progress
    parsed = "parsed"        # parsing finished
    embedding = "embedding"  # vectorization in progress
    indexed = "indexed"      # indexed
    failed = "failed"        # processing failed
|
||||
|
||||
|
||||
class Document(Base):
    """ORM model mapped to the ``documents`` table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    minio_path = Column(String(512), nullable=False)  # MinIO storage path
    size = Column(Integer, default=0)
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a summary dict of this document.

        Keys: ``id``, ``name`` (taken from ``original_name``), ``chunks``,
        ``status`` and ``created_at`` (ISO-format string, or ``None`` when
        no creation time is set).
        """
        created = self.created_at
        if created:
            created_text = created.isoformat()
        else:
            created_text = None

        summary = {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
            "created_at": created_text,
        }
        return summary
|
||||
|
||||
|
||||
class ParseTask(Base):
    """ORM model mapped to the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    # Identity: the task's own id and the id of the document it refers to.
    id = Column(String(64), primary_key=True)
    doc_id = Column(String(64), nullable=False)

    # State: status string (defaults to "pending"), integer progress
    # (defaults to 0) and an optional free-text message.
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)

    # Timestamps: created_at defaults to the insertion time; the other two
    # are nullable and have no default.
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
def init_db():
    """Create all tables registered on ``Base.metadata``.

    Any ``Exception`` raised while doing so is reported through
    ``logger.error`` and is not re-raised.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as err:
        logger.error(f"Database initialization failed: {err}")
|
||||
|
||||
|
||||
def get_db() -> Session:
    """Open and return a new session from ``SessionLocal``.

    The session is not closed here: the caller is responsible for closing it.
    """
    session = SessionLocal()
    # NOTE: intentionally left open; the caller must close the session.
    return session
|
||||
|
||||
|
||||
class DatabaseService:
    """Database service: CRUD helpers for ``Document`` and ``ParseTask`` rows.

    Every method opens its own session from ``self.SessionLocal`` and closes
    it in a ``finally`` clause, so ORM objects are returned after their
    session has already been closed.
    """

    def __init__(self):
        # Reuse the module-level engine and session factory.
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document row with status ``uploaded``.

        Args:
            doc_id: Primary key of the new document.
            filename: Stored file name.
            original_name: Original file name.
            minio_path: MinIO storage path.
            size: Value stored in the ``size`` column.

        Returns:
            The committed and refreshed ``Document``.
        """
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            # Reload the row so its attributes are populated before the
            # session is closed.
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with the given id, or ``None`` if absent."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update a document's status and, optionally, its counters.

        Args:
            doc_id: Id of the document to update.
            status: New status value.
            chunks: New chunk count; left unchanged when ``None``.
            vectors: New vector count; left unchanged when ``None``.
            error_message: New error message; only a non-empty value is
                written, a falsy value leaves the stored message unchanged.

        Returns:
            The refreshed ``Document``, or ``None`` if no row matches
            ``doc_id``.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                doc.status = status
                doc.updated_at = datetime.now()
                if chunks is not None:
                    doc.chunks = chunks
                if vectors is not None:
                    doc.vectors = vectors
                if error_message:
                    doc.error_message = error_message
                db.commit()
                db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest ``created_at`` first."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete the document with the given id.

        Returns:
            ``True`` if a row was found and deleted, ``False`` otherwise.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                db.delete(doc)
                db.commit()
                return True
            return False
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` and return it refreshed."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with the given id, or ``None`` if absent."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update a parse task's status, progress and message.

        ``started_at`` is stamped when the task enters the ``"running"``
        status (or is running without a recorded start time); repeated
        ``"running"`` updates, e.g. progress reports, keep the original
        start time. ``completed_at`` is stamped whenever ``status`` is
        ``"completed"`` or ``"failed"``.

        Args:
            task_id: Id of the task to update.
            status: New status value.
            progress: New progress value; left unchanged when ``None``.
            message: New message; only a non-empty value is written.

        Returns:
            The refreshed ``ParseTask``, or ``None`` if no row matches
            ``task_id``.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task:
                # Remember the old status so a transition into "running" can
                # be told apart from a progress update while already running.
                previous_status = task.status
                task.status = status
                if progress is not None:
                    task.progress = progress
                if message:
                    task.message = message
                if status == "running":
                    if previous_status != "running" or task.started_at is None:
                        task.started_at = datetime.now()
                elif status in ("completed", "failed"):
                    task.completed_at = datetime.now()
                db.commit()
                db.refresh(task)
            return task
        finally:
            db.close()
|
||||
|
||||
|
||||
# Singleton instance, created when this module is imported.
db_service = DatabaseService()
|
||||
Reference in New Issue
Block a user