Files
AIRegulation-DocAnalysis/backend/app/application/documents/services.py
2026-05-26 12:34:12 +08:00

577 lines
23 KiB
Python

"""Implement application-layer logic for services."""
from __future__ import annotations
import os
import tempfile
import uuid
import json
from dataclasses import dataclass
from datetime import UTC, datetime
from loguru import logger
from app.config.settings import settings
from app.domain.documents import (
ChunkBuilder,
Document,
DocumentArtifact,
DocumentBinaryStore,
DocumentParser,
DocumentProcessingRun,
DocumentProcessingStore,
DocumentRepository,
DocumentStatus,
DocumentStatusEvent,
ParseArtifactStore,
ParsedDocument,
)
from app.domain.retrieval import EmbeddingProvider, VectorIndex
# Keep orchestration logic centralized so use-case flow stays easy to trace.
@dataclass
class DocumentProcessResult:
    """Outcome of one upload/retry processing attempt.

    Produced by ``DocumentCommandService`` on both the success path and the
    handled-failure path. ``status`` holds the final status as a plain string.
    """

    # Identifier of the processed document.
    doc_id: str
    # Display name of the document (may be empty when the document was not found).
    doc_name: str
    # Final status string (the value of a DocumentStatus member, or "failed").
    status: str
    # Human-readable outcome message.
    message: str
    # Number of chunks written to the vector index; stays 0 on failure.
    num_chunks: int = 0
    # Summary text read back from the repository; empty when none is stored.
    summary: str = ""
    # Summary latency in milliseconds as recorded by the repository.
    summary_latency_ms: int = 0
class DocumentCommandService:
    """Write-side document use cases: upload/process, delete and retry.

    The main pipeline is store -> parse -> chunk/embed -> index. The optional
    ``parse_artifact_store`` and ``document_processing_store`` receive
    auxiliary records; failures in those stores are logged (with traceback)
    and never abort the main pipeline.
    """

    def __init__(
        self,
        *,
        document_repository: DocumentRepository,
        binary_store: DocumentBinaryStore,
        parser: DocumentParser,
        chunk_builder: ChunkBuilder,
        embedding_provider: EmbeddingProvider,
        vector_index: VectorIndex,
        parse_artifact_store: ParseArtifactStore | None = None,
        document_processing_store: DocumentProcessingStore | None = None,
    ) -> None:
        """Wire the collaborators used by the processing pipeline.

        Args:
            document_repository: Metadata store for document records and statuses.
            binary_store: Object store for source files and parse-artifact JSON.
            parser: Turns a local file path into a ``ParsedDocument``.
            chunk_builder: Builds indexable chunks from a parsed document.
            embedding_provider: Embeds chunk texts.
            vector_index: Target index for chunk vectors.
            parse_artifact_store: Optional store for structure nodes / semantic blocks.
            document_processing_store: Optional store for processing runs,
                status events and artifact references.
        """
        self.document_repository = document_repository
        self.binary_store = binary_store
        self.parser = parser
        self.chunk_builder = chunk_builder
        self.embedding_provider = embedding_provider
        self.vector_index = vector_index
        self.parse_artifact_store = parse_artifact_store
        self.document_processing_store = document_processing_store

    def _utcnow(self) -> datetime:
        """Return the current timezone-aware UTC timestamp for persisted processing metadata."""
        return datetime.now(UTC)

    def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]:
        """Persist parse outputs as JSON objects in the binary store.

        Keeping our own copy means troubleshooting does not depend on the parse
        provider's retention window.

        Returns:
            Mapping of artifact type (``layouts``, ``structure_nodes``,
            ``semantic_blocks``, ``vector_chunks``) to the object name it was
            saved under.

        Raises:
            Whatever ``json.dumps`` or ``binary_store.save`` raises; the caller
            treats artifact persistence as best-effort.
        """
        # A missing, None or empty prefix falls back to "artifacts" rather than
        # crashing on None.strip() or producing a leading-slash object name.
        base_prefix = str(parsed_document.metadata.get("artifact_prefix") or "artifacts").strip("/") or "artifacts"
        prefix = f"{base_prefix}/{doc_id}"
        artifact_payloads = {
            "layouts": parsed_document.raw_layouts,
            "structure_nodes": parsed_document.structure_nodes,
            "semantic_blocks": parsed_document.semantic_blocks,
            "vector_chunks": parsed_document.vector_chunks,
        }
        artifact_keys: dict[str, str] = {}
        for name, payload in artifact_payloads.items():
            object_name = f"{prefix}/{name}.json"
            self.binary_store.save(
                object_name=object_name,
                data=json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8"),
                content_type="application/json",
                metadata={"doc_id": doc_id, "artifact_type": name},
            )
            artifact_keys[name] = object_name
        return artifact_keys

    def _safe_create_processing_run(self, *, doc_id: str, trigger_type: str, generate_summary: bool) -> str | None:
        """Create a "running" processing-run record when the optional store is configured.

        Returns:
            The stored run id, or None when no store is configured or creation failed.
        """
        if not self.document_processing_store:
            return None
        run = DocumentProcessingRun(
            run_id=str(uuid.uuid4()),
            doc_id=doc_id,
            trigger_type=trigger_type,
            run_status="running",
            parser_backend=settings.parser_backend,
            chunk_backend=settings.chunk_backend,
            embedding_model=settings.embedding_model,
            metadata={"generate_summary": generate_summary},
        )
        try:
            created = self.document_processing_store.create_run(run)
            return created.run_id
        except Exception:
            logger.opt(exception=True).warning("DocumentProcessingStore.create_run failed for doc_id={}", doc_id)
            return None

    def _safe_append_status_event(
        self,
        *,
        doc_id: str,
        run_id: str | None,
        from_status: str,
        to_status: str,
        stage: str,
        message: str = "",
        metadata: dict | None = None,
    ) -> None:
        """Append a status event; auxiliary persistence failures are logged, never raised.

        No-op when no processing store is configured or no run id exists.
        """
        if not self.document_processing_store or not run_id:
            return
        event = DocumentStatusEvent(
            event_id=str(uuid.uuid4()),
            doc_id=doc_id,
            run_id=run_id,
            from_status=from_status,
            to_status=to_status,
            stage=stage,
            message=message,
            metadata=metadata or {},
        )
        try:
            self.document_processing_store.append_status_event(event)
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.append_status_event failed for doc_id={}, run_id={}",
                doc_id,
                run_id,
            )

    def _safe_mark_run_stored(self, *, doc_id: str, run_id: str | None) -> None:
        """Record the store timestamp on the run without affecting the main workflow."""
        if not self.document_processing_store or not run_id:
            return
        try:
            self.document_processing_store.mark_run_stored(run_id, stored_at=self._utcnow())
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.mark_run_stored failed for doc_id={}, run_id={}", doc_id, run_id
            )

    def _safe_mark_run_parsed(self, *, doc_id: str, run_id: str | None, parsed_document: ParsedDocument) -> None:
        """Record parse completion counts on the run without failing the document pipeline."""
        if not self.document_processing_store or not run_id:
            return
        try:
            self.document_processing_store.mark_run_parsed(
                run_id,
                parser_backend=parsed_document.parser_name,
                # Prefer the parser-reported layout count; fall back to the raw layout list.
                layout_count=int(parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)) or 0),
                structure_node_count=len(parsed_document.structure_nodes),
                semantic_block_count=len(parsed_document.semantic_blocks),
                vector_chunk_count=len(parsed_document.vector_chunks),
                parsed_at=self._utcnow(),
                metadata={"parse_task_id": parsed_document.metadata.get("task_id", "")},
            )
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.mark_run_parsed failed for doc_id={}, run_id={}", doc_id, run_id
            )

    def _safe_replace_processing_artifacts(self, *, doc_id: str, run_id: str | None, artifact_keys: dict[str, str]) -> None:
        """Store artifact references for the run; persistence drift never becomes a user-visible failure."""
        if not self.document_processing_store or not run_id:
            return
        # Sizes and checksums are not tracked here; they are recorded as 0 / "".
        artifacts = [
            DocumentArtifact(
                artifact_id=str(uuid.uuid4()),
                doc_id=doc_id,
                run_id=run_id,
                artifact_type=artifact_type,
                object_name=object_name,
                content_type="application/json",
                byte_size=0,
                checksum="",
            )
            for artifact_type, object_name in artifact_keys.items()
        ]
        try:
            self.document_processing_store.replace_artifacts_for_run(run_id, artifacts)
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.replace_artifacts_for_run failed for doc_id={}, run_id={}",
                doc_id,
                run_id,
            )

    def _safe_mark_run_indexed(self, *, doc_id: str, run_id: str | None, chunk_count: int, index_name: str) -> None:
        """Mark the run indexed and finished (same timestamp) without affecting the success path."""
        if not self.document_processing_store or not run_id:
            return
        now = self._utcnow()
        try:
            self.document_processing_store.mark_run_indexed(
                run_id,
                chunk_count=chunk_count,
                index_name=index_name,
                indexed_at=now,
                finished_at=now,
            )
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.mark_run_indexed failed for doc_id={}, run_id={}", doc_id, run_id
            )

    def _safe_mark_run_failed(self, *, doc_id: str, run_id: str | None, failure_stage: str, error_message: str) -> None:
        """Mark the run failed without masking the original error-handling path."""
        if not self.document_processing_store or not run_id:
            return
        try:
            self.document_processing_store.mark_run_failed(
                run_id,
                failure_stage=failure_stage,
                error_message=error_message,
                finished_at=self._utcnow(),
            )
        except Exception:
            logger.opt(exception=True).warning(
                "DocumentProcessingStore.mark_run_failed failed for doc_id={}, run_id={}", doc_id, run_id
            )

    def upload_and_process(
        self,
        *,
        doc_id: str | None = None,
        file_name: str,
        content: bytes,
        content_type: str,
        doc_name: str | None,
        regulation_type: str,
        version: str,
        generate_summary: bool,
        trigger_type: str = "upload",
    ) -> DocumentProcessResult:
        """Create a document record and run it through store -> parse -> embed -> index.

        Status transitions PENDING -> STORED -> PARSED -> INDEXED are written to
        the repository and mirrored as status events on the processing run, when
        a processing store is configured. Any exception raised after the record
        is created marks the document FAILED and is reported through the returned
        result instead of being raised.

        Args:
            doc_id: Existing id to (re)process; a new 8-character id is generated
                when omitted.
            file_name: Original file name. It becomes part of the object name and
                supplies the temp-file suffix handed to the parser.
            content: Raw file bytes.
            content_type: MIME type stored with the binary object.
            doc_name: Display name; falls back to ``file_name`` when empty.
            regulation_type: Passed to the record and the chunk builder.
            version: Passed to the record and the chunk builder.
            generate_summary: Only recorded in document/run metadata here. The
                INDEXED update writes an empty summary.
            trigger_type: Label stored on the processing run.

        Returns:
            A ``DocumentProcessResult`` with the final status; on failure
            ``message`` carries the error text.

        Raises:
            Exceptions from ``document_repository.create``, which runs before the
            guarded pipeline.
        """
        doc_id = doc_id or str(uuid.uuid4())[:8]
        final_doc_name = doc_name or file_name
        # NOTE(review): file_name comes from the caller and is embedded verbatim in
        # the object name; confirm the binary store treats "/" and ".." safely.
        object_name = f"{doc_id}/{file_name}"
        current_status = DocumentStatus.PENDING
        current_stage = "store"
        document = Document(
            doc_id=doc_id,
            doc_name=final_doc_name,
            file_name=file_name,
            object_name=object_name,
            content_type=content_type,
            size_bytes=len(content),
            regulation_type=regulation_type,
            version=version,
            metadata={"generate_summary": generate_summary},
        )
        # NOTE(review): retry() calls this with an existing doc_id, so this relies on
        # create() tolerating an existing record (upsert) — confirm in the repository.
        self.document_repository.create(document)
        run_id = self._safe_create_processing_run(
            doc_id=doc_id,
            trigger_type=trigger_type,
            generate_summary=generate_summary,
        )
        self._safe_append_status_event(
            doc_id=doc_id,
            run_id=run_id,
            from_status="",
            to_status=DocumentStatus.PENDING.value,
            stage="document_created",
            message="Document record created",
        )
        temp_path = ""
        try:
            # --- store: persist the source file first so retry() can re-read it ---
            self.binary_store.save(
                object_name=object_name,
                data=content,
                content_type=content_type,
                metadata={"doc_id": doc_id},
            )
            self.document_repository.update_status(doc_id, DocumentStatus.STORED)
            current_status = DocumentStatus.STORED
            current_stage = "parse"
            self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
            self._safe_append_status_event(
                doc_id=doc_id,
                run_id=run_id,
                from_status=DocumentStatus.PENDING.value,
                to_status=DocumentStatus.STORED.value,
                stage="store",
                message="Source file stored",
            )

            # --- parse: the parser takes a file path, so spill bytes to a temp file ---
            suffix = os.path.splitext(file_name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                # Record the path before writing so the finally block still removes
                # the file when the write itself fails.
                temp_path = temp_file.name
                temp_file.write(content)
            parsed_document = self.parser.parse(
                file_path=temp_path,
                doc_id=doc_id,
                doc_name=final_doc_name,
            )
            self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
            artifact_keys: dict[str, str] = {}
            try:
                artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
            except Exception:
                # Artifacts are diagnostic copies; losing them must not abort processing.
                logger.opt(exception=True).warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
            self.document_repository.update_status(
                doc_id,
                DocumentStatus.PARSED,
                parser_name=parsed_document.parser_name,
                metadata={
                    "parser_backend": parsed_document.parser_name,
                    "parse_task_id": parsed_document.metadata.get("task_id", ""),
                    "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
                    "structure_node_count": len(parsed_document.structure_nodes),
                    "semantic_block_count": len(parsed_document.semantic_blocks),
                    "vector_chunk_count": len(parsed_document.vector_chunks),
                    "artifact_keys": artifact_keys,
                    "processing_stage": "parsed",
                },
            )
            current_status = DocumentStatus.PARSED
            # Chunk building below is reported under the "embed" stage as well.
            current_stage = "embed"
            self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
            self._safe_append_status_event(
                doc_id=doc_id,
                run_id=run_id,
                from_status=DocumentStatus.STORED.value,
                to_status=DocumentStatus.PARSED.value,
                stage="parse",
                message="Document parsed",
                metadata={"artifact_count": len(artifact_keys)},
            )
            if self.parse_artifact_store:
                try:
                    self.parse_artifact_store.save(
                        doc_id,
                        parsed_document.structure_nodes,
                        parsed_document.semantic_blocks,
                    )
                except Exception:
                    logger.opt(exception=True).warning("ParseArtifactStore.save failed for doc_id={}", doc_id)

            # --- chunk + embed ---
            chunks = self.chunk_builder.build(
                parsed_document=parsed_document,
                regulation_type=regulation_type,
                version=version,
            )
            if not chunks:
                raise ValueError("解析完成但没有生成可入库的 chunks")
            vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])

            # --- index ---
            current_stage = "index"
            # NOTE(review): chunks from an earlier run of the same doc_id are not deleted
            # first; confirm upsert replaces them when chunk ids are stable across runs.
            inserted = self.vector_index.upsert(chunks, vectors)
            if inserted != len(chunks):
                logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
            index_name = self.vector_index.health().get("collection_name", "")
            self.document_repository.update_status(
                doc_id,
                DocumentStatus.INDEXED,
                chunk_count=len(chunks),
                summary="",
                summary_latency_ms=0,
                index_name=index_name,
                metadata={
                    "index_collection": index_name,
                    "processing_stage": "indexed",
                },
            )
            current_status = DocumentStatus.INDEXED
            self._safe_mark_run_indexed(
                doc_id=doc_id,
                run_id=run_id,
                chunk_count=len(chunks),
                index_name=index_name,
            )
            self._safe_append_status_event(
                doc_id=doc_id,
                run_id=run_id,
                from_status=DocumentStatus.PARSED.value,
                to_status=DocumentStatus.INDEXED.value,
                stage="index",
                message="Document indexed",
                metadata={"chunk_count": len(chunks), "index_name": index_name},
            )
            # Read back the record so the result reflects what was actually persisted.
            stored = self.document_repository.get(doc_id)
            return DocumentProcessResult(
                doc_id=doc_id,
                doc_name=final_doc_name,
                status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
                message="处理成功",
                num_chunks=len(chunks),
                summary=stored.summary if stored else "",
                summary_latency_ms=stored.summary_latency_ms if stored else 0,
            )
        except Exception as exc:
            logger.exception("文档处理失败: doc_id={}", doc_id)
            failure_stage = current_stage
            error_message = str(exc)
            try:
                self.document_repository.update_status(
                    doc_id,
                    DocumentStatus.FAILED,
                    error_message=error_message,
                    metadata={
                        "failure_reason": error_message,
                        "processing_stage": "failed",
                        "failure_stage": failure_stage,
                    },
                )
            except Exception:
                # A secondary repository error must not skip run bookkeeping or
                # replace the failed result this method promises to return.
                logger.opt(exception=True).warning("Failed to persist FAILED status for doc_id={}", doc_id)
            self._safe_mark_run_failed(
                doc_id=doc_id,
                run_id=run_id,
                failure_stage=failure_stage,
                error_message=error_message,
            )
            self._safe_append_status_event(
                doc_id=doc_id,
                run_id=run_id,
                from_status=current_status.value,
                to_status=DocumentStatus.FAILED.value,
                stage=failure_stage,
                message=error_message,
            )
            return DocumentProcessResult(
                doc_id=doc_id,
                doc_name=final_doc_name,
                status=DocumentStatus.FAILED.value,
                message=f"文档处理失败: {exc}",
            )
        finally:
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    logger.warning("临时文件清理失败: {}", temp_path)

    def delete(self, doc_id: str) -> bool:
        """Delete a document's record, source binary, parse artifacts, vector chunks and auxiliary records.

        Every step except removing the metadata record is best-effort: a failure
        is logged and the remaining steps still run.

        Returns:
            False when no record exists for ``doc_id``; True after deletion.
        """
        document = self.document_repository.get(doc_id)
        if not document:
            return False
        try:
            self.binary_store.delete(document.object_name)
        except Exception:
            logger.opt(exception=True).warning("Binary delete failed for doc_id={}", doc_id)
        # Parse artifacts are separate objects in the binary store (written by
        # _save_parse_artifacts). Remove the ones recorded in the document
        # metadata, when present, so they do not outlive the document.
        artifact_keys = (document.metadata or {}).get("artifact_keys")
        if isinstance(artifact_keys, dict):
            for artifact_object in artifact_keys.values():
                try:
                    self.binary_store.delete(artifact_object)
                except Exception:
                    logger.opt(exception=True).warning(
                        "Artifact binary delete failed for doc_id={}, object={}", doc_id, artifact_object
                    )
        try:
            self.vector_index.delete_by_document(doc_id)
        except Exception:
            logger.opt(exception=True).warning("Vector delete failed for doc_id={}", doc_id)
        if self.parse_artifact_store:
            try:
                self.parse_artifact_store.delete(doc_id)
            except Exception:
                logger.opt(exception=True).warning("ParseArtifactStore delete failed for doc_id={}", doc_id)
        if self.document_processing_store:
            try:
                self.document_processing_store.delete_by_document(doc_id)
            except Exception:
                logger.opt(exception=True).warning("DocumentProcessingStore delete failed for doc_id={}", doc_id)
        self.document_repository.delete(doc_id)
        return True

    def retry(self, doc_id: str) -> DocumentProcessResult:
        """Re-run the full pipeline for an existing document from its stored source file.

        The document's current status is not checked, so any existing document
        can be re-processed.

        Returns:
            A failed result (document-not-found message) when no record exists;
            otherwise the result of ``upload_and_process`` with
            ``trigger_type="retry"``.

        Raises:
            Whatever ``binary_store.read`` raises when the stored source file
            cannot be read.
        """
        document = self.document_repository.get(doc_id)
        if not document:
            return DocumentProcessResult(doc_id=doc_id, doc_name="", status="failed", message="文档不存在")
        content = self.binary_store.read(document.object_name)
        return self.upload_and_process(
            doc_id=doc_id,
            file_name=document.file_name,
            content=content,
            content_type=document.content_type,
            doc_name=document.doc_name,
            regulation_type=document.regulation_type,
            version=document.version,
            generate_summary=bool(document.metadata.get("generate_summary", False)),
            trigger_type="retry",
        )
class DocumentQueryService:
    """Read-side document use cases: lookup, reconciled listing and download."""

    def __init__(
        self,
        *,
        document_repository: DocumentRepository,
        binary_store: DocumentBinaryStore,
        vector_index: VectorIndex,
    ) -> None:
        """Store the repository, binary store and vector index used for reads."""
        self.document_repository = document_repository
        self.binary_store = binary_store
        self.vector_index = vector_index

    def get(self, doc_id: str) -> Document | None:
        """Return the repository record for ``doc_id``, or None when none exists."""
        return self.document_repository.get(doc_id)

    @staticmethod
    def _newest_first(docs: list[Document], limit: int | None) -> list[Document]:
        """Return ``docs`` sorted by ``updated_at`` descending, truncated to ``limit`` when given."""
        ordered = sorted(docs, key=lambda doc: doc.updated_at, reverse=True)
        return ordered[:limit] if limit is not None else ordered

    def list_documents(self, limit: int | None = None) -> list[Document]:
        """Return documents reconciled against the vector index (Milvus), newest first.

        Algorithm:
        1. Query the vector index for per-document metadata (doc_id, doc_name,
           chunk_count, ...).
        2. Load ALL metadata records (JSON/PG). ``limit`` is applied only at the
           end. Applying it here would make metadata-backed documents outside the
           first ``limit`` records look index-only and surface as duplicate
           synthetic entries.
        3. Merge: documents present in the index get status=INDEXED and the live
           chunk_count; metadata-only documents marked INDEXED are demoted to
           FAILED.
        4. Index-only documents (no metadata record) are surfaced as synthetic
           INDEXED entries so they never disappear from the management list.
        5. Sort by ``updated_at`` descending and truncate to ``limit``.

        If the vector index cannot be queried, reconciliation is skipped and the
        metadata records are returned as stored. An outage must not make every
        indexed document look failed.
        """
        try:
            index_rows = self.vector_index.list_document_metadata()
        except Exception:
            logger.opt(exception=True).warning(
                "Vector index metadata lookup failed; returning unreconciled document list"
            )
            index_rows = None

        meta_docs = self.document_repository.list(limit=None)
        if index_rows is None:
            return self._newest_first(meta_docs, limit)

        rows_by_id: dict[str, dict] = {row["doc_id"]: row for row in index_rows}
        known_ids = {doc.doc_id for doc in meta_docs}
        result: list[Document] = []

        # NOTE(review): the repository's Document objects are mutated in place; if the
        # repository hands out cached instances this leaks into its state — confirm.
        for doc in meta_docs:
            row = rows_by_id.get(doc.doc_id)
            if row is not None:
                doc.chunk_count = row["chunk_count"]
                doc.status = DocumentStatus.INDEXED
                # Backfill fields that may be missing from older JSON records.
                if not doc.doc_name and row.get("doc_name"):
                    doc.doc_name = row["doc_name"]
                if not doc.regulation_type and row.get("regulation_type"):
                    doc.regulation_type = row["regulation_type"]
                if not doc.version and row.get("version"):
                    doc.version = row["version"]
            elif doc.status == DocumentStatus.INDEXED:
                # Metadata claims indexed, but the index has no chunks for it.
                doc.status = DocumentStatus.FAILED
                doc.error_message = "向量数据库中未找到对应数据"
            result.append(doc)

        # Surface index-only documents that have no metadata record at all.
        for doc_id, row in rows_by_id.items():
            if doc_id in known_ids:
                continue
            # Fall back to the id when the index has no (or an empty) name.
            display_name = row.get("doc_name") or doc_id
            result.append(
                Document(
                    doc_id=doc_id,
                    doc_name=display_name,
                    file_name=display_name,
                    object_name="",
                    content_type="",
                    size_bytes=0,
                    status=DocumentStatus.INDEXED,
                    regulation_type=row.get("regulation_type", ""),
                    version=row.get("version", ""),
                    chunk_count=row["chunk_count"],
                )
            )
        return self._newest_first(result, limit)

    def download(self, doc_id: str) -> tuple[Document, bytes]:
        """Return the document record together with its stored source bytes.

        Raises:
            FileNotFoundError: When no metadata record exists for ``doc_id``.
            Any error raised by ``binary_store.read`` for the stored object.
        """
        document = self.document_repository.get(doc_id)
        if not document:
            raise FileNotFoundError(f"文档不存在: {doc_id}")
        return document, self.binary_store.read(document.object_name)