Fix centered content layout widths

This commit is contained in:
ash66
2026-05-26 12:34:12 +08:00
parent 34d72d7ce9
commit fec22a3a2c
16 changed files with 2102 additions and 140 deletions

View File

@@ -7,16 +7,22 @@ import tempfile
import uuid
import json
from dataclasses import dataclass
from datetime import UTC, datetime
from loguru import logger
from app.config.settings import settings
from app.domain.documents import (
ChunkBuilder,
Document,
DocumentArtifact,
DocumentBinaryStore,
DocumentParser,
DocumentProcessingRun,
DocumentProcessingStore,
DocumentRepository,
DocumentStatus,
DocumentStatusEvent,
ParseArtifactStore,
ParsedDocument,
)
@@ -39,6 +45,7 @@ class DocumentProcessResult:
class DocumentCommandService:
"""Provide the Document Command Service service."""
def __init__(
self,
*,
@@ -49,6 +56,7 @@ class DocumentCommandService:
embedding_provider: EmbeddingProvider,
vector_index: VectorIndex,
parse_artifact_store: ParseArtifactStore | None = None,
document_processing_store: DocumentProcessingStore | None = None,
) -> None:
"""Initialize the Document Command Service instance."""
self.document_repository = document_repository
@@ -58,6 +66,11 @@ class DocumentCommandService:
self.embedding_provider = embedding_provider
self.vector_index = vector_index
self.parse_artifact_store = parse_artifact_store
self.document_processing_store = document_processing_store
def _utcnow(self) -> datetime:
"""Return the current UTC timestamp for persisted processing metadata."""
return datetime.now(UTC)
def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]:
"""Persist parse artifacts so troubleshooting does not depend on provider retention windows."""
@@ -80,6 +93,143 @@ class DocumentCommandService:
artifact_keys[name] = object_name
return artifact_keys
def _safe_create_processing_run(self, *, doc_id: str, trigger_type: str, generate_summary: bool) -> str | None:
"""Create a processing run record when the optional store is available."""
if not self.document_processing_store:
return None
run = DocumentProcessingRun(
run_id=str(uuid.uuid4()),
doc_id=doc_id,
trigger_type=trigger_type,
run_status="running",
parser_backend=settings.parser_backend,
chunk_backend=settings.chunk_backend,
embedding_model=settings.embedding_model,
metadata={"generate_summary": generate_summary},
)
try:
created = self.document_processing_store.create_run(run)
return created.run_id
except Exception:
logger.warning("DocumentProcessingStore.create_run failed for doc_id={}", doc_id)
return None
def _safe_append_status_event(
self,
*,
doc_id: str,
run_id: str | None,
from_status: str,
to_status: str,
stage: str,
message: str = "",
metadata: dict | None = None,
) -> None:
"""Append a status event without allowing auxiliary persistence failures to abort processing."""
if not self.document_processing_store or not run_id:
return
event = DocumentStatusEvent(
event_id=str(uuid.uuid4()),
doc_id=doc_id,
run_id=run_id,
from_status=from_status,
to_status=to_status,
stage=stage,
message=message,
metadata=metadata or {},
)
try:
self.document_processing_store.append_status_event(event)
except Exception:
logger.warning(
"DocumentProcessingStore.append_status_event failed for doc_id={}, run_id={}",
doc_id,
run_id,
)
def _safe_mark_run_stored(self, *, doc_id: str, run_id: str | None) -> None:
"""Mark the processing run as stored without affecting the main workflow."""
if not self.document_processing_store or not run_id:
return
try:
self.document_processing_store.mark_run_stored(run_id, stored_at=self._utcnow())
except Exception:
logger.warning("DocumentProcessingStore.mark_run_stored failed for doc_id={}, run_id={}", doc_id, run_id)
def _safe_mark_run_parsed(self, *, doc_id: str, run_id: str | None, parsed_document: ParsedDocument) -> None:
"""Persist parse completion details without failing the document pipeline."""
if not self.document_processing_store or not run_id:
return
try:
self.document_processing_store.mark_run_parsed(
run_id,
parser_backend=parsed_document.parser_name,
layout_count=int(parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)) or 0),
structure_node_count=len(parsed_document.structure_nodes),
semantic_block_count=len(parsed_document.semantic_blocks),
vector_chunk_count=len(parsed_document.vector_chunks),
parsed_at=self._utcnow(),
metadata={"parse_task_id": parsed_document.metadata.get("task_id", "")},
)
except Exception:
logger.warning("DocumentProcessingStore.mark_run_parsed failed for doc_id={}, run_id={}", doc_id, run_id)
def _safe_replace_processing_artifacts(self, *, doc_id: str, run_id: str | None, artifact_keys: dict[str, str]) -> None:
"""Store artifact references without turning persistence drift into a user-visible failure."""
if not self.document_processing_store or not run_id:
return
artifacts = [
DocumentArtifact(
artifact_id=str(uuid.uuid4()),
doc_id=doc_id,
run_id=run_id,
artifact_type=artifact_type,
object_name=object_name,
content_type="application/json",
byte_size=0,
checksum="",
)
for artifact_type, object_name in artifact_keys.items()
]
try:
self.document_processing_store.replace_artifacts_for_run(run_id, artifacts)
except Exception:
logger.warning(
"DocumentProcessingStore.replace_artifacts_for_run failed for doc_id={}, run_id={}",
doc_id,
run_id,
)
def _safe_mark_run_indexed(self, *, doc_id: str, run_id: str | None, chunk_count: int, index_name: str) -> None:
"""Mark the processing run as indexed without affecting the success path."""
if not self.document_processing_store or not run_id:
return
now = self._utcnow()
try:
self.document_processing_store.mark_run_indexed(
run_id,
chunk_count=chunk_count,
index_name=index_name,
indexed_at=now,
finished_at=now,
)
except Exception:
logger.warning("DocumentProcessingStore.mark_run_indexed failed for doc_id={}, run_id={}", doc_id, run_id)
def _safe_mark_run_failed(self, *, doc_id: str, run_id: str | None, failure_stage: str, error_message: str) -> None:
"""Mark the processing run as failed without masking the original error handling path."""
if not self.document_processing_store or not run_id:
return
try:
self.document_processing_store.mark_run_failed(
run_id,
failure_stage=failure_stage,
error_message=error_message,
finished_at=self._utcnow(),
)
except Exception:
logger.warning("DocumentProcessingStore.mark_run_failed failed for doc_id={}, run_id={}", doc_id, run_id)
def upload_and_process(
self,
*,
@@ -91,11 +241,15 @@ class DocumentCommandService:
regulation_type: str,
version: str,
generate_summary: bool,
trigger_type: str = "upload",
) -> DocumentProcessResult:
"""Handle upload and process for the Document Command Service instance."""
doc_id = doc_id or str(uuid.uuid4())[:8]
final_doc_name = doc_name or file_name
object_name = f"{doc_id}/{file_name}"
run_id: str | None = None
current_status = DocumentStatus.PENDING
current_stage = "store"
document = Document(
doc_id=doc_id,
@@ -109,6 +263,19 @@ class DocumentCommandService:
metadata={"generate_summary": generate_summary},
)
self.document_repository.create(document)
run_id = self._safe_create_processing_run(
doc_id=doc_id,
trigger_type=trigger_type,
generate_summary=generate_summary,
)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status="",
to_status=DocumentStatus.PENDING.value,
stage="document_created",
message="Document record created",
)
temp_path = ""
try:
@@ -119,6 +286,17 @@ class DocumentCommandService:
metadata={"doc_id": doc_id},
)
self.document_repository.update_status(doc_id, DocumentStatus.STORED)
current_status = DocumentStatus.STORED
current_stage = "parse"
self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=DocumentStatus.PENDING.value,
to_status=DocumentStatus.STORED.value,
stage="store",
message="Source file stored",
)
suffix = os.path.splitext(file_name)[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
@@ -130,7 +308,13 @@ class DocumentCommandService:
doc_id=doc_id,
doc_name=final_doc_name,
)
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
artifact_keys: dict[str, str] = {}
try:
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
except Exception:
logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
self.document_repository.update_status(
doc_id,
DocumentStatus.PARSED,
@@ -146,6 +330,18 @@ class DocumentCommandService:
"processing_stage": "parsed",
},
)
current_status = DocumentStatus.PARSED
current_stage = "embed"
self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=DocumentStatus.STORED.value,
to_status=DocumentStatus.PARSED.value,
stage="parse",
message="Document parsed",
metadata={"artifact_count": len(artifact_keys)},
)
if self.parse_artifact_store:
try:
self.parse_artifact_store.save(
@@ -165,6 +361,7 @@ class DocumentCommandService:
raise ValueError("解析完成但没有生成可入库的 chunks")
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
current_stage = "index"
inserted = self.vector_index.upsert(chunks, vectors)
if inserted != len(chunks):
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
@@ -182,6 +379,23 @@ class DocumentCommandService:
"processing_stage": "indexed",
},
)
current_status = DocumentStatus.INDEXED
index_name = health.get("collection_name", "")
self._safe_mark_run_indexed(
doc_id=doc_id,
run_id=run_id,
chunk_count=len(chunks),
index_name=index_name,
)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=DocumentStatus.PARSED.value,
to_status=DocumentStatus.INDEXED.value,
stage="index",
message="Document indexed",
metadata={"chunk_count": len(chunks), "index_name": index_name},
)
stored = self.document_repository.get(doc_id)
return DocumentProcessResult(
doc_id=doc_id,
@@ -194,6 +408,7 @@ class DocumentCommandService:
)
except Exception as exc:
logger.exception("文档处理失败: doc_id={}", doc_id)
failure_stage = current_stage
self.document_repository.update_status(
doc_id,
DocumentStatus.FAILED,
@@ -201,8 +416,23 @@ class DocumentCommandService:
metadata={
"failure_reason": str(exc),
"processing_stage": "failed",
"failure_stage": failure_stage,
},
)
self._safe_mark_run_failed(
doc_id=doc_id,
run_id=run_id,
failure_stage=failure_stage,
error_message=str(exc),
)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=current_status.value,
to_status=DocumentStatus.FAILED.value,
stage=failure_stage,
message=str(exc),
)
return DocumentProcessResult(
doc_id=doc_id,
doc_name=final_doc_name,
@@ -235,6 +465,11 @@ class DocumentCommandService:
self.parse_artifact_store.delete(doc_id)
except Exception:
logger.warning("ParseArtifactStore delete failed for doc_id={}", doc_id)
if self.document_processing_store:
try:
self.document_processing_store.delete_by_document(doc_id)
except Exception:
logger.warning("DocumentProcessingStore delete failed for doc_id={}", doc_id)
self.document_repository.delete(doc_id)
return True
@@ -253,6 +488,7 @@ class DocumentCommandService:
regulation_type=document.regulation_type,
version=document.version,
generate_summary=bool(document.metadata.get("generate_summary", False)),
trigger_type="retry",
)

View File

@@ -78,6 +78,7 @@ class Settings(BaseSettings):
chunk_overlap: int = Field(default=50, description="分块重叠大小")
max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)")
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
document_processing_metadata_path: str = Field(default="backend/data/document_processing.json", description="文档处理历史存储路径")
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")

View File

@@ -1,18 +1,29 @@
"""Initialize the app.domain.documents package."""
from .models import Chunk, Document, DocumentStatus, ParsedDocument
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository, ParseArtifactStore
from .models import Chunk, Document, DocumentArtifact, DocumentProcessingRun, DocumentStatus, DocumentStatusEvent, ParsedDocument
from .ports import (
ChunkBuilder,
DocumentBinaryStore,
DocumentParser,
DocumentProcessingStore,
DocumentRepository,
ParseArtifactStore,
)
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [
"Chunk",
"Document",
"DocumentArtifact",
"DocumentProcessingRun",
"DocumentStatus",
"DocumentStatusEvent",
"ParsedDocument",
"ChunkBuilder",
"DocumentBinaryStore",
"DocumentParser",
"DocumentProcessingStore",
"DocumentRepository",
"ParseArtifactStore",
]

View File

@@ -76,3 +76,61 @@ class Chunk:
semantic_id: str = ""
block_type: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class DocumentProcessingRun:
"""Represent one processing attempt for a document."""
run_id: str
doc_id: str
trigger_type: str
run_status: str
parser_backend: str = ""
chunk_backend: str = ""
embedding_model: str = ""
index_name: str = ""
started_at: datetime = field(default_factory=utcnow)
stored_at: datetime | None = None
parsed_at: datetime | None = None
indexed_at: datetime | None = None
finished_at: datetime | None = None
layout_count: int = 0
structure_node_count: int = 0
semantic_block_count: int = 0
vector_chunk_count: int = 0
chunk_count: int = 0
failure_stage: str = ""
error_message: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class DocumentStatusEvent:
"""Represent a document lifecycle event emitted during processing."""
event_id: str
doc_id: str
run_id: str
from_status: str
to_status: str
stage: str
message: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
occurred_at: datetime = field(default_factory=utcnow)
@dataclass
class DocumentArtifact:
"""Represent a persisted artifact reference for one processing run."""
artifact_id: str
doc_id: str
run_id: str
artifact_type: str
object_name: str
content_type: str
byte_size: int = 0
checksum: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=utcnow)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from .models import Chunk, Document, DocumentStatus, ParsedDocument
from .models import Chunk, Document, DocumentArtifact, DocumentProcessingRun, DocumentStatus, DocumentStatusEvent, ParsedDocument
# Keep domain contracts explicit so adapters can swap implementations cleanly.
@@ -128,3 +128,111 @@ class ParseArtifactStore(ABC):
def get_structure_nodes(self, doc_id: str) -> list[dict]:
"""Return all structure nodes for a document."""
pass
class DocumentProcessingStore(ABC):
"""Persist document processing runs, events, and artifact references."""
@abstractmethod
def create_run(self, run: DocumentProcessingRun) -> DocumentProcessingRun:
"""Create a new processing run record."""
pass
@abstractmethod
def mark_run_stored(
self,
run_id: str,
*,
stored_at: object | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as having persisted the source file."""
pass
@abstractmethod
def mark_run_parsed(
self,
run_id: str,
*,
parser_backend: str,
layout_count: int,
structure_node_count: int,
semantic_block_count: int,
vector_chunk_count: int,
parsed_at: object | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Record parse completion details for a run."""
pass
@abstractmethod
def mark_run_indexed(
self,
run_id: str,
*,
chunk_count: int,
index_name: str,
indexed_at: object | None = None,
finished_at: object | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as successfully indexed."""
pass
@abstractmethod
def mark_run_failed(
self,
run_id: str,
*,
failure_stage: str,
error_message: str,
finished_at: object | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as failed."""
pass
@abstractmethod
def append_status_event(self, event: DocumentStatusEvent) -> DocumentStatusEvent:
"""Append a document status event."""
pass
@abstractmethod
def replace_artifacts_for_run(self, run_id: str, artifacts: list[DocumentArtifact]) -> list[DocumentArtifact]:
"""Replace all artifacts for a run with the provided list."""
pass
@abstractmethod
def delete_by_document(self, doc_id: str) -> None:
"""Delete all processing data for a document."""
pass
@abstractmethod
def list_runs_by_document(self, doc_id: str) -> list[DocumentProcessingRun]:
"""List all processing runs for a document."""
pass
@abstractmethod
def get_run(self, run_id: str) -> DocumentProcessingRun | None:
"""Return one processing run by identifier."""
pass
@abstractmethod
def list_status_events_by_document(self, doc_id: str) -> list[DocumentStatusEvent]:
"""List status events for a document."""
pass
@abstractmethod
def list_status_events_by_run(self, run_id: str) -> list[DocumentStatusEvent]:
"""List status events for a run."""
pass
@abstractmethod
def list_artifacts_by_document(self, doc_id: str) -> list[DocumentArtifact]:
"""List artifact references for a document."""
pass
@abstractmethod
def list_artifacts_by_run(self, run_id: str) -> list[DocumentArtifact]:
"""List artifact references for a run."""
pass

View File

@@ -0,0 +1,373 @@
"""Implement infrastructure support for json document processing history."""
from __future__ import annotations
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from app.domain.documents import DocumentArtifact, DocumentProcessingRun, DocumentProcessingStore, DocumentStatusEvent
# Keep JSON persistence behavior aligned with the lightweight document repository adapter.
class JsonDocumentProcessingStore(DocumentProcessingStore):
"""Persist processing history in a standalone JSON file."""
def __init__(self, file_path: str) -> None:
"""Initialize the JSON processing history store."""
self.file_path = Path(file_path)
self.file_path.parent.mkdir(parents=True, exist_ok=True)
if not self.file_path.exists():
self._save(self._empty_payload())
def _empty_payload(self) -> dict[str, dict[str, dict[str, Any]]]:
"""Return the canonical empty JSON structure for processing history."""
return {"runs": {}, "status_events": {}, "artifacts": {}}
def _load(self) -> dict[str, dict[str, dict[str, Any]]]:
"""Load the full JSON payload and normalize missing sections."""
if not self.file_path.exists():
return self._empty_payload()
payload = json.loads(self.file_path.read_text(encoding="utf-8") or "{}")
normalized = self._empty_payload()
for key in normalized:
section = payload.get(key, {})
normalized[key] = section if isinstance(section, dict) else {}
return normalized
def _save(self, payload: dict[str, dict[str, dict[str, Any]]]) -> None:
"""Persist the full JSON payload with stable formatting."""
self.file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def _serialize_datetime(self, value: datetime | None) -> str | None:
"""Serialize optional datetimes into ISO8601 strings."""
return value.isoformat() if value is not None else None
def _deserialize_datetime(self, value: str | None) -> datetime | None:
"""Deserialize optional ISO8601 strings into datetimes."""
return datetime.fromisoformat(value) if value else None
def _serialize_run(self, run: DocumentProcessingRun) -> dict[str, Any]:
"""Serialize one processing run to a JSON-compatible payload."""
return {
"run_id": run.run_id,
"doc_id": run.doc_id,
"trigger_type": run.trigger_type,
"run_status": run.run_status,
"parser_backend": run.parser_backend,
"chunk_backend": run.chunk_backend,
"embedding_model": run.embedding_model,
"index_name": run.index_name,
"started_at": self._serialize_datetime(run.started_at),
"stored_at": self._serialize_datetime(run.stored_at),
"parsed_at": self._serialize_datetime(run.parsed_at),
"indexed_at": self._serialize_datetime(run.indexed_at),
"finished_at": self._serialize_datetime(run.finished_at),
"layout_count": run.layout_count,
"structure_node_count": run.structure_node_count,
"semantic_block_count": run.semantic_block_count,
"vector_chunk_count": run.vector_chunk_count,
"chunk_count": run.chunk_count,
"failure_stage": run.failure_stage,
"error_message": run.error_message,
"metadata": run.metadata,
}
def _deserialize_run(self, payload: dict[str, Any]) -> DocumentProcessingRun:
"""Deserialize one JSON payload into a processing run dataclass."""
return DocumentProcessingRun(
run_id=payload["run_id"],
doc_id=payload["doc_id"],
trigger_type=payload["trigger_type"],
run_status=payload["run_status"],
parser_backend=payload.get("parser_backend", ""),
chunk_backend=payload.get("chunk_backend", ""),
embedding_model=payload.get("embedding_model", ""),
index_name=payload.get("index_name", ""),
started_at=self._deserialize_datetime(payload.get("started_at")) or datetime.now(UTC),
stored_at=self._deserialize_datetime(payload.get("stored_at")),
parsed_at=self._deserialize_datetime(payload.get("parsed_at")),
indexed_at=self._deserialize_datetime(payload.get("indexed_at")),
finished_at=self._deserialize_datetime(payload.get("finished_at")),
layout_count=int(payload.get("layout_count", 0) or 0),
structure_node_count=int(payload.get("structure_node_count", 0) or 0),
semantic_block_count=int(payload.get("semantic_block_count", 0) or 0),
vector_chunk_count=int(payload.get("vector_chunk_count", 0) or 0),
chunk_count=int(payload.get("chunk_count", 0) or 0),
failure_stage=payload.get("failure_stage", ""),
error_message=payload.get("error_message", ""),
metadata=payload.get("metadata", {}),
)
def _serialize_event(self, event: DocumentStatusEvent) -> dict[str, Any]:
"""Serialize one status event to a JSON-compatible payload."""
return {
"event_id": event.event_id,
"doc_id": event.doc_id,
"run_id": event.run_id,
"from_status": event.from_status,
"to_status": event.to_status,
"stage": event.stage,
"message": event.message,
"metadata": event.metadata,
"occurred_at": self._serialize_datetime(event.occurred_at),
}
def _deserialize_event(self, payload: dict[str, Any]) -> DocumentStatusEvent:
"""Deserialize one JSON payload into a status event dataclass."""
return DocumentStatusEvent(
event_id=payload["event_id"],
doc_id=payload["doc_id"],
run_id=payload["run_id"],
from_status=payload.get("from_status", ""),
to_status=payload["to_status"],
stage=payload.get("stage", ""),
message=payload.get("message", ""),
metadata=payload.get("metadata", {}),
occurred_at=self._deserialize_datetime(payload.get("occurred_at")) or datetime.now(UTC),
)
def _serialize_artifact(self, artifact: DocumentArtifact) -> dict[str, Any]:
"""Serialize one artifact reference to a JSON-compatible payload."""
return {
"artifact_id": artifact.artifact_id,
"doc_id": artifact.doc_id,
"run_id": artifact.run_id,
"artifact_type": artifact.artifact_type,
"object_name": artifact.object_name,
"content_type": artifact.content_type,
"byte_size": artifact.byte_size,
"checksum": artifact.checksum,
"metadata": artifact.metadata,
"created_at": self._serialize_datetime(artifact.created_at),
}
def _deserialize_artifact(self, payload: dict[str, Any]) -> DocumentArtifact:
"""Deserialize one JSON payload into an artifact dataclass."""
return DocumentArtifact(
artifact_id=payload["artifact_id"],
doc_id=payload["doc_id"],
run_id=payload["run_id"],
artifact_type=payload["artifact_type"],
object_name=payload["object_name"],
content_type=payload.get("content_type", ""),
byte_size=int(payload.get("byte_size", 0) or 0),
checksum=payload.get("checksum", ""),
metadata=payload.get("metadata", {}),
created_at=self._deserialize_datetime(payload.get("created_at")) or datetime.now(UTC),
)
def _merge_metadata(self, original: dict[str, Any], update: dict | None) -> dict[str, Any]:
"""Merge metadata updates onto an existing payload."""
merged = dict(original)
if update:
merged.update(update)
return merged
def create_run(self, run: DocumentProcessingRun) -> DocumentProcessingRun:
"""Create a new processing run record."""
payload = self._load()
payload["runs"][run.run_id] = self._serialize_run(run)
self._save(payload)
return run
def mark_run_stored(
self,
run_id: str,
*,
stored_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as having persisted the source file."""
payload = self._load()
run_payload = payload["runs"].get(run_id)
if not run_payload:
return None
run = self._deserialize_run(run_payload)
run.stored_at = stored_at or datetime.now(UTC)
run.metadata = self._merge_metadata(run.metadata, metadata)
payload["runs"][run_id] = self._serialize_run(run)
self._save(payload)
return run
def mark_run_parsed(
self,
run_id: str,
*,
parser_backend: str,
layout_count: int,
structure_node_count: int,
semantic_block_count: int,
vector_chunk_count: int,
parsed_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Record parse completion details for a run."""
payload = self._load()
run_payload = payload["runs"].get(run_id)
if not run_payload:
return None
run = self._deserialize_run(run_payload)
run.parser_backend = parser_backend
run.layout_count = layout_count
run.structure_node_count = structure_node_count
run.semantic_block_count = semantic_block_count
run.vector_chunk_count = vector_chunk_count
run.parsed_at = parsed_at or datetime.now(UTC)
run.metadata = self._merge_metadata(run.metadata, metadata)
payload["runs"][run_id] = self._serialize_run(run)
self._save(payload)
return run
def mark_run_indexed(
self,
run_id: str,
*,
chunk_count: int,
index_name: str,
indexed_at: datetime | None = None,
finished_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as successfully indexed."""
payload = self._load()
run_payload = payload["runs"].get(run_id)
if not run_payload:
return None
run = self._deserialize_run(run_payload)
now = datetime.now(UTC)
run.run_status = "succeeded"
run.chunk_count = chunk_count
run.index_name = index_name
run.indexed_at = indexed_at or now
run.finished_at = finished_at or now
run.metadata = self._merge_metadata(run.metadata, metadata)
payload["runs"][run_id] = self._serialize_run(run)
self._save(payload)
return run
def mark_run_failed(
self,
run_id: str,
*,
failure_stage: str,
error_message: str,
finished_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as failed."""
payload = self._load()
run_payload = payload["runs"].get(run_id)
if not run_payload:
return None
run = self._deserialize_run(run_payload)
run.run_status = "failed"
run.failure_stage = failure_stage
run.error_message = error_message
run.finished_at = finished_at or datetime.now(UTC)
run.metadata = self._merge_metadata(run.metadata, metadata)
payload["runs"][run_id] = self._serialize_run(run)
self._save(payload)
return run
def append_status_event(self, event: DocumentStatusEvent) -> DocumentStatusEvent:
"""Append a document status event."""
payload = self._load()
payload["status_events"][event.event_id] = self._serialize_event(event)
self._save(payload)
return event
def replace_artifacts_for_run(self, run_id: str, artifacts: list[DocumentArtifact]) -> list[DocumentArtifact]:
"""Replace all artifacts for a run with the provided list."""
payload = self._load()
payload["artifacts"] = {
artifact_id: artifact_payload
for artifact_id, artifact_payload in payload["artifacts"].items()
if artifact_payload.get("run_id") != run_id
}
for artifact in artifacts:
payload["artifacts"][artifact.artifact_id] = self._serialize_artifact(artifact)
self._save(payload)
return artifacts
def delete_by_document(self, doc_id: str) -> None:
"""Delete all processing data for a document."""
payload = self._load()
payload["runs"] = {
run_id: run_payload
for run_id, run_payload in payload["runs"].items()
if run_payload.get("doc_id") != doc_id
}
payload["status_events"] = {
event_id: event_payload
for event_id, event_payload in payload["status_events"].items()
if event_payload.get("doc_id") != doc_id
}
payload["artifacts"] = {
artifact_id: artifact_payload
for artifact_id, artifact_payload in payload["artifacts"].items()
if artifact_payload.get("doc_id") != doc_id
}
self._save(payload)
def list_runs_by_document(self, doc_id: str) -> list[DocumentProcessingRun]:
"""List all processing runs for a document."""
payload = self._load()
runs = [
self._deserialize_run(run_payload)
for run_payload in payload["runs"].values()
if run_payload.get("doc_id") == doc_id
]
runs.sort(key=lambda run: run.started_at)
return runs
def get_run(self, run_id: str) -> DocumentProcessingRun | None:
"""Return one processing run by identifier."""
payload = self._load()
run_payload = payload["runs"].get(run_id)
return self._deserialize_run(run_payload) if run_payload else None
def list_status_events_by_document(self, doc_id: str) -> list[DocumentStatusEvent]:
"""List status events for a document."""
payload = self._load()
events = [
self._deserialize_event(event_payload)
for event_payload in payload["status_events"].values()
if event_payload.get("doc_id") == doc_id
]
events.sort(key=lambda event: event.occurred_at)
return events
def list_status_events_by_run(self, run_id: str) -> list[DocumentStatusEvent]:
"""List status events for a run."""
payload = self._load()
events = [
self._deserialize_event(event_payload)
for event_payload in payload["status_events"].values()
if event_payload.get("run_id") == run_id
]
events.sort(key=lambda event: event.occurred_at)
return events
def list_artifacts_by_document(self, doc_id: str) -> list[DocumentArtifact]:
"""List artifact references for a document."""
payload = self._load()
artifacts = [
self._deserialize_artifact(artifact_payload)
for artifact_payload in payload["artifacts"].values()
if artifact_payload.get("doc_id") == doc_id
]
artifacts.sort(key=lambda artifact: artifact.created_at)
return artifacts
def list_artifacts_by_run(self, run_id: str) -> list[DocumentArtifact]:
"""List artifact references for a run."""
payload = self._load()
artifacts = [
self._deserialize_artifact(artifact_payload)
for artifact_payload in payload["artifacts"].values()
if artifact_payload.get("run_id") == run_id
]
artifacts.sort(key=lambda artifact: artifact.created_at)
return artifacts

View File

@@ -0,0 +1,466 @@
"""Implement infrastructure support for postgres document processing history."""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import UTC, datetime
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.domain.documents import DocumentArtifact, DocumentProcessingRun, DocumentProcessingStore, DocumentStatusEvent
# Keep SQL mapping local to this adapter so the domain stays storage-agnostic.
_CREATE_RUNS_TABLE = """
CREATE TABLE IF NOT EXISTS document_processing_runs (
run_id VARCHAR(128) PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
trigger_type VARCHAR(32) NOT NULL,
run_status VARCHAR(32) NOT NULL DEFAULT 'running',
parser_backend VARCHAR(128) NOT NULL DEFAULT '',
chunk_backend VARCHAR(128) NOT NULL DEFAULT '',
embedding_model VARCHAR(256) NOT NULL DEFAULT '',
index_name VARCHAR(128) NOT NULL DEFAULT '',
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
stored_at TIMESTAMPTZ,
parsed_at TIMESTAMPTZ,
indexed_at TIMESTAMPTZ,
finished_at TIMESTAMPTZ,
layout_count INTEGER NOT NULL DEFAULT 0,
structure_node_count INTEGER NOT NULL DEFAULT 0,
semantic_block_count INTEGER NOT NULL DEFAULT 0,
vector_chunk_count INTEGER NOT NULL DEFAULT 0,
chunk_count INTEGER NOT NULL DEFAULT 0,
failure_stage VARCHAR(64) NOT NULL DEFAULT '',
error_message TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_dpr_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_document_processing_runs_doc_id ON document_processing_runs(doc_id, started_at DESC);
"""
_CREATE_EVENTS_TABLE = """
CREATE TABLE IF NOT EXISTS document_status_history (
event_id VARCHAR(128) PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
run_id VARCHAR(128) NOT NULL,
from_status VARCHAR(32) NOT NULL DEFAULT '',
to_status VARCHAR(32) NOT NULL,
stage VARCHAR(64) NOT NULL DEFAULT '',
message TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}',
occurred_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_dsh_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
CONSTRAINT fk_dsh_run FOREIGN KEY (run_id) REFERENCES document_processing_runs(run_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_document_status_history_doc_id ON document_status_history(doc_id, occurred_at ASC);
CREATE INDEX IF NOT EXISTS idx_document_status_history_run_id ON document_status_history(run_id, occurred_at ASC);
"""
_CREATE_ARTIFACTS_TABLE = """
CREATE TABLE IF NOT EXISTS document_artifacts (
artifact_id VARCHAR(128) PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
run_id VARCHAR(128) NOT NULL,
artifact_type VARCHAR(64) NOT NULL,
object_name VARCHAR(1024) NOT NULL,
content_type VARCHAR(128) NOT NULL DEFAULT '',
byte_size BIGINT NOT NULL DEFAULT 0,
checksum VARCHAR(256) NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_da_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
CONSTRAINT fk_da_run FOREIGN KEY (run_id) REFERENCES document_processing_runs(run_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_document_artifacts_doc_id ON document_artifacts(doc_id, created_at ASC);
CREATE INDEX IF NOT EXISTS idx_document_artifacts_run_id ON document_artifacts(run_id, created_at ASC);
"""
class PostgresDocumentProcessingStore(DocumentProcessingStore):
"""Persist processing history in PostgreSQL using handwritten SQL."""
def __init__(self) -> None:
"""Initialize the store and ensure the required tables exist."""
self._pool = ThreadedConnectionPool(
minconn=1,
maxconn=5,
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
)
self._ensure_schema()
def _ensure_schema(self) -> None:
"""Create processing history tables and indexes if they are missing."""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(_CREATE_RUNS_TABLE)
cur.execute(_CREATE_EVENTS_TABLE)
cur.execute(_CREATE_ARTIFACTS_TABLE)
conn.commit()
@contextmanager
def _conn(self):
"""Borrow one connection from the pool and return it afterwards."""
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
def _normalize_metadata(self, value: Any) -> dict[str, Any]:
"""Return a JSON-object payload regardless of the row representation."""
if isinstance(value, dict):
return value
if not value:
return {}
return json.loads(value)
def _row_to_run(self, row: dict[str, Any]) -> DocumentProcessingRun:
"""Map one run row into the domain dataclass."""
return DocumentProcessingRun(
run_id=row["run_id"],
doc_id=row["doc_id"],
trigger_type=row["trigger_type"],
run_status=row["run_status"],
parser_backend=row["parser_backend"],
chunk_backend=row["chunk_backend"],
embedding_model=row["embedding_model"],
index_name=row["index_name"],
started_at=row["started_at"],
stored_at=row["stored_at"],
parsed_at=row["parsed_at"],
indexed_at=row["indexed_at"],
finished_at=row["finished_at"],
layout_count=row["layout_count"],
structure_node_count=row["structure_node_count"],
semantic_block_count=row["semantic_block_count"],
vector_chunk_count=row["vector_chunk_count"],
chunk_count=row["chunk_count"],
failure_stage=row["failure_stage"],
error_message=row["error_message"],
metadata=self._normalize_metadata(row["metadata"]),
)
def _row_to_event(self, row: dict[str, Any]) -> DocumentStatusEvent:
"""Map one event row into the domain dataclass."""
return DocumentStatusEvent(
event_id=row["event_id"],
doc_id=row["doc_id"],
run_id=row["run_id"],
from_status=row["from_status"],
to_status=row["to_status"],
stage=row["stage"],
message=row["message"],
metadata=self._normalize_metadata(row["metadata"]),
occurred_at=row["occurred_at"],
)
def _row_to_artifact(self, row: dict[str, Any]) -> DocumentArtifact:
"""Map one artifact row into the domain dataclass."""
return DocumentArtifact(
artifact_id=row["artifact_id"],
doc_id=row["doc_id"],
run_id=row["run_id"],
artifact_type=row["artifact_type"],
object_name=row["object_name"],
content_type=row["content_type"],
byte_size=row["byte_size"],
checksum=row["checksum"],
metadata=self._normalize_metadata(row["metadata"]),
created_at=row["created_at"],
)
def _update_run(
self,
run_id: str,
*,
assignments: dict[str, Any],
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Update one run row and return the latest stored state."""
set_clauses = []
params: dict[str, Any] = {"run_id": run_id, "updated_at": datetime.now(UTC)}
for key, value in assignments.items():
set_clauses.append(f"{key} = %({key})s")
params[key] = value
set_clauses.append("updated_at = %(updated_at)s")
if metadata is not None:
set_clauses.append("metadata = COALESCE(metadata, '{}'::jsonb) || %(metadata)s::jsonb")
params["metadata"] = json.dumps(metadata, ensure_ascii=False)
sql = f"""
UPDATE document_processing_runs
SET {", ".join(set_clauses)}
WHERE run_id = %(run_id)s
RETURNING *
"""
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, params)
row = cur.fetchone()
conn.commit()
return self._row_to_run(dict(row)) if row else None
def create_run(self, run: DocumentProcessingRun) -> DocumentProcessingRun:
"""Create a new processing run record."""
sql = """
INSERT INTO document_processing_runs
(run_id, doc_id, trigger_type, run_status, parser_backend, chunk_backend,
embedding_model, index_name, started_at, stored_at, parsed_at, indexed_at,
finished_at, layout_count, structure_node_count, semantic_block_count,
vector_chunk_count, chunk_count, failure_stage, error_message, metadata)
VALUES
(%(run_id)s, %(doc_id)s, %(trigger_type)s, %(run_status)s, %(parser_backend)s,
%(chunk_backend)s, %(embedding_model)s, %(index_name)s, %(started_at)s,
%(stored_at)s, %(parsed_at)s, %(indexed_at)s, %(finished_at)s, %(layout_count)s,
%(structure_node_count)s, %(semantic_block_count)s, %(vector_chunk_count)s,
%(chunk_count)s, %(failure_stage)s, %(error_message)s, %(metadata)s)
"""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(
sql,
{
"run_id": run.run_id,
"doc_id": run.doc_id,
"trigger_type": run.trigger_type,
"run_status": run.run_status,
"parser_backend": run.parser_backend,
"chunk_backend": run.chunk_backend,
"embedding_model": run.embedding_model,
"index_name": run.index_name,
"started_at": run.started_at,
"stored_at": run.stored_at,
"parsed_at": run.parsed_at,
"indexed_at": run.indexed_at,
"finished_at": run.finished_at,
"layout_count": run.layout_count,
"structure_node_count": run.structure_node_count,
"semantic_block_count": run.semantic_block_count,
"vector_chunk_count": run.vector_chunk_count,
"chunk_count": run.chunk_count,
"failure_stage": run.failure_stage,
"error_message": run.error_message,
"metadata": json.dumps(run.metadata, ensure_ascii=False),
},
)
conn.commit()
return run
def mark_run_stored(
self,
run_id: str,
*,
stored_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as having persisted its source file."""
return self._update_run(
run_id,
assignments={"stored_at": stored_at or datetime.now(UTC)},
metadata=metadata,
)
def mark_run_parsed(
self,
run_id: str,
*,
parser_backend: str,
layout_count: int,
structure_node_count: int,
semantic_block_count: int,
vector_chunk_count: int,
parsed_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Record parse completion metrics for a run."""
return self._update_run(
run_id,
assignments={
"parser_backend": parser_backend,
"parsed_at": parsed_at or datetime.now(UTC),
"layout_count": layout_count,
"structure_node_count": structure_node_count,
"semantic_block_count": semantic_block_count,
"vector_chunk_count": vector_chunk_count,
},
metadata=metadata,
)
def mark_run_indexed(
self,
run_id: str,
*,
chunk_count: int,
index_name: str,
indexed_at: datetime | None = None,
finished_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as successfully indexed."""
now = datetime.now(UTC)
return self._update_run(
run_id,
assignments={
"run_status": "succeeded",
"chunk_count": chunk_count,
"index_name": index_name,
"indexed_at": indexed_at or now,
"finished_at": finished_at or now,
},
metadata=metadata,
)
def mark_run_failed(
self,
run_id: str,
*,
failure_stage: str,
error_message: str,
finished_at: datetime | None = None,
metadata: dict | None = None,
) -> DocumentProcessingRun | None:
"""Mark a run as failed and persist the terminal error details."""
return self._update_run(
run_id,
assignments={
"run_status": "failed",
"failure_stage": failure_stage,
"error_message": error_message,
"finished_at": finished_at or datetime.now(UTC),
},
metadata=metadata,
)
def append_status_event(self, event: DocumentStatusEvent) -> DocumentStatusEvent:
"""Append a document status event."""
sql = """
INSERT INTO document_status_history
(event_id, doc_id, run_id, from_status, to_status, stage, message, metadata, occurred_at)
VALUES
(%(event_id)s, %(doc_id)s, %(run_id)s, %(from_status)s, %(to_status)s,
%(stage)s, %(message)s, %(metadata)s, %(occurred_at)s)
"""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(
sql,
{
"event_id": event.event_id,
"doc_id": event.doc_id,
"run_id": event.run_id,
"from_status": event.from_status,
"to_status": event.to_status,
"stage": event.stage,
"message": event.message,
"metadata": json.dumps(event.metadata, ensure_ascii=False),
"occurred_at": event.occurred_at,
},
)
conn.commit()
return event
def replace_artifacts_for_run(self, run_id: str, artifacts: list[DocumentArtifact]) -> list[DocumentArtifact]:
"""Replace all artifact references for one run using a delete-then-insert strategy."""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute("DELETE FROM document_artifacts WHERE run_id = %s", (run_id,))
if artifacts:
psycopg2.extras.execute_values(
cur,
"""
INSERT INTO document_artifacts
(artifact_id, doc_id, run_id, artifact_type, object_name,
content_type, byte_size, checksum, metadata, created_at)
VALUES %s
""",
[
(
artifact.artifact_id,
artifact.doc_id,
artifact.run_id,
artifact.artifact_type,
artifact.object_name,
artifact.content_type,
artifact.byte_size,
artifact.checksum,
json.dumps(artifact.metadata, ensure_ascii=False),
artifact.created_at,
)
for artifact in artifacts
],
)
conn.commit()
return artifacts
def delete_by_document(self, doc_id: str) -> None:
"""Delete all processing rows for a document explicitly."""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute("DELETE FROM document_status_history WHERE doc_id = %s", (doc_id,))
cur.execute("DELETE FROM document_artifacts WHERE doc_id = %s", (doc_id,))
cur.execute("DELETE FROM document_processing_runs WHERE doc_id = %s", (doc_id,))
conn.commit()
def list_runs_by_document(self, doc_id: str) -> list[DocumentProcessingRun]:
"""List processing runs for a document in chronological order."""
sql = "SELECT * FROM document_processing_runs WHERE doc_id = %s ORDER BY started_at ASC"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
rows = cur.fetchall()
return [self._row_to_run(dict(row)) for row in rows]
def get_run(self, run_id: str) -> DocumentProcessingRun | None:
"""Return one processing run by identifier."""
sql = "SELECT * FROM document_processing_runs WHERE run_id = %s"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (run_id,))
row = cur.fetchone()
return self._row_to_run(dict(row)) if row else None
def list_status_events_by_document(self, doc_id: str) -> list[DocumentStatusEvent]:
"""List all status events for a document."""
sql = "SELECT * FROM document_status_history WHERE doc_id = %s ORDER BY occurred_at ASC"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
rows = cur.fetchall()
return [self._row_to_event(dict(row)) for row in rows]
def list_status_events_by_run(self, run_id: str) -> list[DocumentStatusEvent]:
"""List all status events for a run."""
sql = "SELECT * FROM document_status_history WHERE run_id = %s ORDER BY occurred_at ASC"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (run_id,))
rows = cur.fetchall()
return [self._row_to_event(dict(row)) for row in rows]
def list_artifacts_by_document(self, doc_id: str) -> list[DocumentArtifact]:
"""List all artifact references for a document."""
sql = "SELECT * FROM document_artifacts WHERE doc_id = %s ORDER BY created_at ASC"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
rows = cur.fetchall()
return [self._row_to_artifact(dict(row)) for row in rows]
def list_artifacts_by_run(self, run_id: str) -> list[DocumentArtifact]:
"""List all artifact references for a run."""
sql = "SELECT * FROM document_artifacts WHERE run_id = %s ORDER BY created_at ASC"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (run_id,))
rows = cur.fetchall()
return [self._row_to_artifact(dict(row)) for row in rows]

View File

@@ -20,8 +20,10 @@ from app.infrastructure.parser.local_document_parser import LocalDocumentParser
from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuilder
from app.infrastructure.perception.mock_event_store import MockEventStore
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
from app.infrastructure.storage.json_document_processing_store import JsonDocumentProcessingStore
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
from app.infrastructure.storage.postgres_document_processing_store import PostgresDocumentProcessingStore
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
from app.infrastructure.vectorstore.bm25_retriever import BM25Retriever
@@ -148,6 +150,14 @@ def get_parse_artifact_store():
return None
@lru_cache
def get_document_processing_store():
"""Return document processing store for the active repository backend."""
if settings.document_repository_backend == "postgres":
return PostgresDocumentProcessingStore()
return JsonDocumentProcessingStore(settings.document_processing_metadata_path)
@lru_cache
def get_binary_store() -> DocumentBinaryStore:
"""Return binary store."""
@@ -226,6 +236,7 @@ def get_document_command_service() -> DocumentCommandService:
embedding_provider=get_embedding_provider(),
vector_index=get_vector_index(),
parse_artifact_store=get_parse_artifact_store(),
document_processing_store=get_document_processing_store(),
)