"""Database service - PostgreSQL"""
|
|
|
|
import enum
from datetime import datetime
from typing import Optional, List
from urllib.parse import quote_plus

from sqlalchemy import create_engine, Column, String, Integer, DateTime, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

from app.core.config import settings
from app.utils.logger import logger
|
|
|
|
|
|
# Database connection.
# The user name and password are percent-encoded so that reserved URL
# characters in them (for example "@", ":" or "/") cannot corrupt the URL.
# NOTE(review): assumes the settings hold raw credentials, i.e. values that
# are not already percent-encoded - confirm against the deployment config.
DATABASE_URL = (
    f"postgresql://{quote_plus(str(settings.postgres_user))}"
    f":{quote_plus(str(settings.postgres_password))}"
    f"@{settings.postgres_host}:{settings.postgres_port}/{settings.postgres_db}"
)

# Engine, session factory and declarative base shared by the whole module.
engine = create_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
|
|
|
|
|
|
class DocStatus(str, enum.Enum):
    """Processing status of a document.

    Members also subclass ``str``, so each one compares equal to its value.
    """

    # Uploaded
    uploaded = "uploaded"
    # Parsing in progress
    parsing = "parsing"
    # Parsed
    parsed = "parsed"
    # Embedding (vectorization) in progress
    embedding = "embedding"
    # Indexed
    indexed = "indexed"
    # Processing failed
    failed = "failed"
|
|
|
|
|
|
class Document(Base):
    """ORM model mapped to the ``documents`` table."""

    __tablename__ = "documents"

    # Primary key, supplied by the caller.
    id = Column(String(64), primary_key=True)
    filename = Column(String(255), nullable=False)
    original_name = Column(String(255), nullable=False)
    # MinIO storage path
    minio_path = Column(String(512), nullable=False)
    size = Column(Integer, default=0)
    # Stored as a plain string; defaults to DocStatus.uploaded.
    status = Column(String(32), default=DocStatus.uploaded.value)
    chunks = Column(Integer, default=0)
    vectors = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    def to_dict(self):
        """Return a summary dict of this document.

        The dict holds ``id``, ``name`` (taken from ``original_name``),
        ``chunks``, ``status`` and ``created_at`` as an ISO-8601 string,
        or ``None`` when ``created_at`` is not set.
        """
        created = self.created_at
        created_iso = created.isoformat() if created else None
        return {
            "id": self.id,
            "name": self.original_name,
            "chunks": self.chunks,
            "status": self.status,
            "created_at": created_iso,
        }
|
|
|
|
|
|
class ParseTask(Base):
    """ORM model mapped to the ``parse_tasks`` table."""

    __tablename__ = "parse_tasks"

    # Primary key, supplied by the caller.
    id = Column(String(64), primary_key=True)
    # Identifier of the related document; declared as a plain string column
    # (no foreign-key constraint is defined here).
    doc_id = Column(String(64), nullable=False)
    # Free-form status string; new rows default to "pending".
    status = Column(String(32), default="pending")
    progress = Column(Integer, default=0)
    message = Column(Text, nullable=True)
    # Timestamps: creation is filled by default, the other two start empty.
    created_at = Column(DateTime, default=datetime.now)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
|
|
|
|
|
|
def init_db():
    """Create the tables registered on ``Base.metadata``.

    Any exception raised while doing so is logged as an error and is not
    re-raised, so the caller continues even if table creation failed.
    """
    try:
        metadata = Base.metadata
        metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}")
|
|
|
|
|
|
def get_db() -> Session:
    """Return a new database session.

    The session is handed back still open; the caller is responsible for
    closing it.
    """
    session = SessionLocal()
    return session
|
|
|
|
|
|
class DatabaseService:
    """Persistence helpers for ``Document`` and ``ParseTask`` rows.

    Every method opens its own session from ``self.SessionLocal`` and closes
    it in a ``finally`` clause, so returned objects come from a session that
    has already been closed.
    """

    def __init__(self):
        # Reuse the module-level engine and session factory.
        self.engine = engine
        self.SessionLocal = SessionLocal

    def create_document(
        self,
        doc_id: str,
        filename: str,
        original_name: str,
        minio_path: str,
        size: int,
    ) -> Document:
        """Insert a new document row with status ``uploaded``.

        Args:
            doc_id: Primary key of the new row.
            filename: Stored file name.
            original_name: File name as provided by the uploader.
            minio_path: MinIO storage path of the file.
            size: Value stored in the ``size`` column.

        Returns:
            The committed and refreshed ``Document``.
        """
        db = self.SessionLocal()
        try:
            doc = Document(
                id=doc_id,
                filename=filename,
                original_name=original_name,
                minio_path=minio_path,
                size=size,
                status=DocStatus.uploaded.value,
            )
            db.add(doc)
            db.commit()
            # Reload so column defaults are populated before the session closes.
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Return the document with the given id, or ``None`` if absent."""
        db = self.SessionLocal()
        try:
            return db.query(Document).filter(Document.id == doc_id).first()
        finally:
            db.close()

    def update_document_status(
        self,
        doc_id: str,
        status: str,
        chunks: Optional[int] = None,
        vectors: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Document]:
        """Update the status (and optional counters) of a document.

        Args:
            doc_id: Primary key of the document to update.
            status: New status string.
            chunks: New chunk count; left untouched when ``None``.
            vectors: New vector count; left untouched when ``None``.
            error_message: New error text; left untouched when ``None`` or
                empty.

        Returns:
            The refreshed ``Document``, or ``None`` if no row matched.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc is None:
                return None

            doc.status = status
            doc.updated_at = datetime.now()
            if chunks is not None:
                doc.chunks = chunks
            if vectors is not None:
                doc.vectors = vectors
            if error_message:
                doc.error_message = error_message

            db.commit()
            db.refresh(doc)
            return doc
        finally:
            db.close()

    def list_documents(self) -> List[Document]:
        """Return all documents, newest ``created_at`` first."""
        db = self.SessionLocal()
        try:
            return db.query(Document).order_by(Document.created_at.desc()).all()
        finally:
            db.close()

    def delete_document(self, doc_id: str) -> bool:
        """Delete a document row.

        Returns:
            ``True`` if a row was found and deleted, ``False`` otherwise.
        """
        db = self.SessionLocal()
        try:
            doc = db.query(Document).filter(Document.id == doc_id).first()
            if doc is None:
                return False
            db.delete(doc)
            db.commit()
            return True
        finally:
            db.close()

    def create_parse_task(self, task_id: str, doc_id: str) -> ParseTask:
        """Insert a new parse task for ``doc_id`` and return it refreshed."""
        db = self.SessionLocal()
        try:
            task = ParseTask(id=task_id, doc_id=doc_id)
            db.add(task)
            db.commit()
            # Reload so column defaults are populated before the session closes.
            db.refresh(task)
            return task
        finally:
            db.close()

    def get_parse_task(self, task_id: str) -> Optional[ParseTask]:
        """Return the parse task with the given id, or ``None`` if absent."""
        db = self.SessionLocal()
        try:
            return db.query(ParseTask).filter(ParseTask.id == task_id).first()
        finally:
            db.close()

    def update_parse_task(
        self,
        task_id: str,
        status: str,
        progress: Optional[int] = None,
        message: Optional[str] = None,
    ) -> Optional[ParseTask]:
        """Update the status of a parse task.

        Args:
            task_id: Primary key of the task to update.
            status: New status string. ``"running"`` records ``started_at``
                the first time it is seen; ``"completed"`` and ``"failed"``
                record ``completed_at``.
            progress: New progress value; left untouched when ``None``.
            message: New message; left untouched when ``None`` or empty.

        Returns:
            The refreshed ``ParseTask``, or ``None`` if no row matched.
        """
        db = self.SessionLocal()
        try:
            task = db.query(ParseTask).filter(ParseTask.id == task_id).first()
            if task is None:
                return None

            task.status = status
            if progress is not None:
                task.progress = progress
            if message:
                task.message = message

            if status == "running":
                # Keep the first start time: repeated "running" updates
                # (e.g. progress reports) must not reset it.
                if task.started_at is None:
                    task.started_at = datetime.now()
            elif status in ("completed", "failed"):
                task.completed_at = datetime.now()

            db.commit()
            db.refresh(task)
            return task
        finally:
            db.close()
|
|
|
|
|
|
# Singleton instance
db_service = DatabaseService()