Files
AIRegulation-Demo-Test-Backend/app/services/database.py

228 lines
6.8 KiB
Python
Raw Normal View History

"""数据库服务 - PostgreSQL"""
from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import Optional, List
import enum
from app.core.config import settings
from app.utils.logger import logger
# Database connection.
# The URL is assembled from the PostgreSQL fields on `settings` and evaluated once, at import time.
# NOTE(review): user name and password are interpolated verbatim; values containing
# URL-reserved characters (e.g. "@", ":", "/") presumably need percent-encoding before
# SQLAlchemy can parse the URL correctly — confirm against the deployed configuration.
DATABASE_URL = f"postgresql://{settings.postgres_user}:{settings.postgres_password}@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
# Engine shared by the whole module; SQL echo is off and pool_pre_ping is enabled.
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
class DocStatus(str, enum.Enum):
    """Processing status of a document.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Uploaded.
    uploaded = "uploaded"
    # Parsing in progress.
    parsing = "parsing"
    # Parsed.
    parsed = "parsed"
    # Embedding (vectorization) in progress.
    embedding = "embedding"
    # Indexed.
    indexed = "indexed"
    # Processing failed.
    failed = "failed"
class Document(Base):
    """ORM model mapped to the ``documents`` table."""

    __tablename__ = "documents"

    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    # Storage path of the file in MinIO.
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)
    # Stored as a plain string; new rows default to DocStatus.uploaded's value.
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    # Both timestamps come from datetime.now (naive local time).
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a plain-dict summary of the document.

        Keys are ``id``, ``name`` (taken from ``original_name``), ``chunks``,
        ``status`` and ``created_at``. ``created_at`` is an ISO-8601 string,
        or ``None`` when the timestamp is not set.
        """
        summary = {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
        }
        created = self.created_at
        if created:
            summary["created_at"] = created.isoformat()
        else:
            summary["created_at"] = None
        return summary
class ParseTask(Base):
    """ORM model mapped to the ``parse_tasks`` table (parse task records)."""
    __tablename__ = "parse_tasks"
    id = Column(String(64), primary_key=True)
    # Id of the document this task refers to; stored as a plain string column
    # (no ForeignKey constraint is declared here).
    doc_id = Column(String(64), nullable=False)
    # New tasks start as "pending"; the service in this module also uses
    # "running", "completed" and "failed".
    status = Column(String(32), default="pending")
    # NOTE(review): presumably a percentage in 0-100 — confirm with the code that reports progress.
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    # Creation time defaults to datetime.now (naive local time).
    created_at = Column(DateTime, default=datetime.now)
    # Both remain NULL until set explicitly by the code that updates the task.
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
def init_db():
    """Create all tables registered on ``Base.metadata``.

    Any exception raised while doing so is logged at error level and then
    suppressed, so this function never raises and always returns ``None``.
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        # Deliberately not re-raised: a failed initialization is only reported.
        logger.error(f"Database initialization failed: {exc}")
def get_db() -> Session:
    """Open and return a new database session.

    The session is NOT closed here: the caller owns it and is responsible
    for closing it when done.
    """
    return SessionLocal()
class DatabaseService:
    """Data-access helper for ``Document`` and ``ParseTask`` rows.

    Every method opens its own session from ``self.SessionLocal`` and closes
    it in a ``finally`` block, so a session never outlives a single call.
    Returned ORM objects have therefore already left their session; they are
    read (and refreshed after writes) before the session is closed.
    """

    def __init__(self):
        # Reuse the module-level engine and session factory.
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document row with status ``uploaded`` and return it.

        Args:
            doc_id: Primary key for the new row.
            filename: Stored file name.
            original_name: Original file name.
            minio_path: Storage path of the file in MinIO.
            size: Value stored in the ``size`` column.

        Returns:
            The newly created ``Document``, refreshed from the database.
        """
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            # Reload the row so column defaults are populated on the object.
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with the given id, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update a document's status and, optionally, its counters and error text.

        Args:
            doc_id: Id of the document to update.
            status: New status string.
            chunks: New ``chunks`` value; left unchanged when ``None``.
            vectors: New ``vectors`` value; left unchanged when ``None``.
            error_message: New error text; left unchanged when ``None`` or
                empty, so an existing message is never cleared by this method.

        Returns:
            The refreshed ``Document``, or ``None`` when no row matches ``doc_id``.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                doc.status = status
                # Set explicitly so the timestamp moves even when no other
                # column value actually changes.
                doc.updated_at = datetime.now()
                if chunks is not None:
                    doc.chunks = chunks
                if vectors is not None:
                    doc.vectors = vectors
                if error_message:
                    doc.error_message = error_message
                db.commit()
                db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest first (ordered by ``created_at`` descending)."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete the document with the given id.

        Only the ``documents`` row is removed; ``parse_tasks`` rows are not touched.

        Returns:
            ``True`` if a row was deleted, ``False`` if no row matched ``doc_id``.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc:
                db.delete(doc)
                db.commit()
                return True
            return False
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` and return it refreshed from the database."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            # Reload the row so column defaults are populated on the object.
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with the given id, or ``None`` if it does not exist."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update a parse task's status, progress and message.

        ``started_at`` is stamped when the task enters ``"running"`` (or if it
        was never stamped), not on every call made while it is already
        running, so repeated progress updates keep the original start time.
        ``completed_at`` is stamped whenever ``status`` is ``"completed"`` or
        ``"failed"``.

        Args:
            task_id: Id of the task to update.
            status: New status string.
            progress: New progress value; left unchanged when ``None``.
            message: New message; left unchanged when ``None`` or empty.

        Returns:
            The refreshed ``ParseTask``, or ``None`` when no row matches ``task_id``.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task:
                # Remember the old status so a transition can be told apart
                # from a progress update on an already-running task.
                previous_status = task.status
                task.status = status
                if progress is not None:
                    task.progress = progress
                if message:
                    task.message = message
                if status == "running":
                    if previous_status != "running" or task.started_at is None:
                        task.started_at = datetime.now()
                elif status in ("completed", "failed"):
                    task.completed_at = datetime.now()
                db.commit()
                db.refresh(task)
            return task
        finally:
            db.close()
# Shared module-level instance, created at import time and used as a singleton
# by convention (the class itself does not enforce a single instance).
db_service = DatabaseService()