Files
AIRegulation-DocAnalysis/backend/app/shared/bootstrap.py

405 lines
15 KiB
Python

"""Share backend wiring for bootstrap."""
from __future__ import annotations
from functools import lru_cache
from typing import Callable
from app.application.agent import AgentConversationService, AgentSessionService
from app.application.documents import DocumentCommandService, DocumentQueryService
from app.application.knowledge import KnowledgeRetrievalService
from app.application.perception.services import PerceptionService
from app.config.settings import settings
from app.domain.documents import DocumentBinaryStore
from app.domain.retrieval import VectorIndex
from app.infrastructure.embedding.openai_compatible_embedding_provider import OpenAICompatibleEmbeddingProvider
from app.infrastructure.llm.openai_compatible_answer_generator import OpenAICompatibleAnswerGenerator
from app.infrastructure.parser.aliyun_document_parser import AliyunDocumentParser
from app.infrastructure.parser.local_chunk_builder import LocalRegulationChunkBuilder
from app.infrastructure.parser.local_document_parser import LocalDocumentParser
from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuilder
from app.infrastructure.perception.mock_event_store import MockEventStore
from app.application.perception.crawl_service import CrawlService
from app.infrastructure.perception.base_event_store import BaseEventStore
from app.infrastructure.perception.crawlers.catarc_crawler import CatarcCrawler
from app.infrastructure.perception.crawlers.guobiao_crawler import (
GuobiaoMandatoryCrawler,
GuobiaoRecommendedCrawler,
)
from app.infrastructure.perception.crawlers.eurlex_crawler import EurlexCrawler
from app.infrastructure.perception.llm_pipeline import LlmPipeline
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
from app.infrastructure.storage.json_document_processing_store import JsonDocumentProcessingStore
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
from app.infrastructure.storage.postgres_document_processing_store import PostgresDocumentProcessingStore
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
from app.infrastructure.vectorstore.bm25_retriever import BM25Retriever
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
from app.services.llm.llm_factory import LLMFactory
from app.domain.compliance.ports import ComplianceRepository
from app.infrastructure.compliance.repository import PostgresComplianceRepository
# Keep shared wiring centralized so dependency construction remains consistent.
class LazyBinaryStore(DocumentBinaryStore):
    """Proxy that defers building the real binary store until first access.

    The wrapped store is produced by calling ``factory`` the first time any
    operation (or the ``client`` property) is used; the result is kept and
    reused for every later call.
    """

    def __init__(self, factory: Callable[[], DocumentBinaryStore]) -> None:
        """Remember ``factory``; no store is constructed at this point."""
        self._store: DocumentBinaryStore | None = None
        self._factory = factory

    def _get_store(self) -> DocumentBinaryStore:
        """Return the wrapped store, invoking the factory only while none is held."""
        store = self._store
        if store is None:
            store = self._store = self._factory()
        return store

    @property
    def client(self):
        """Return the wrapped store's ``client`` (kept for health-endpoint compatibility)."""
        backend = self._get_store()
        return backend.client

    def save(
        self,
        *,
        object_name: str,
        data: bytes,
        content_type: str,
        metadata: dict[str, str] | None = None,
    ) -> None:
        """Forward a save request, unchanged, to the wrapped store."""
        backend = self._get_store()
        backend.save(
            object_name=object_name,
            data=data,
            content_type=content_type,
            metadata=metadata,
        )

    def read(self, object_name: str) -> bytes:
        """Return whatever the wrapped store reads for ``object_name``."""
        backend = self._get_store()
        return backend.read(object_name)

    def delete(self, object_name: str) -> None:
        """Ask the wrapped store to delete ``object_name``."""
        backend = self._get_store()
        backend.delete(object_name)
class LazyVectorIndex(VectorIndex):
    """Proxy that postpones building the real vector index until it is used.

    The wrapped index is produced by calling ``factory`` the first time any
    operation (or the ``collection`` property) is used; the result is kept
    and reused for every later call.
    """

    def __init__(self, factory: Callable[[], VectorIndex]) -> None:
        """Remember ``factory``; no index is constructed at this point."""
        self._index: VectorIndex | None = None
        self._factory = factory

    def _get_index(self) -> VectorIndex:
        """Return the wrapped index, invoking the factory only while none is held."""
        index = self._index
        if index is None:
            index = self._index = self._factory()
        return index

    @property
    def collection(self):
        """Return the wrapped index's ``collection`` (kept for compatibility adapters)."""
        index = self._get_index()
        return index.collection

    def upsert(self, chunks, vectors) -> int:
        """Forward an upsert of ``chunks``/``vectors`` and return the wrapped result."""
        index = self._get_index()
        return index.upsert(chunks, vectors)

    def delete_by_document(self, doc_id: str) -> int:
        """Forward deletion of the vectors belonging to ``doc_id``."""
        index = self._get_index()
        return index.delete_by_document(doc_id)

    def search(self, query_vector: list[float], top_k: int, filters: str | None = None):
        """Forward a vector search and return the wrapped index's result."""
        index = self._get_index()
        return index.search(query_vector, top_k, filters)

    def count_by_document(self) -> dict[str, int]:
        """Return the wrapped index's per-document counts."""
        index = self._get_index()
        return index.count_by_document()

    def list_document_metadata(self) -> list[dict]:
        """Return the wrapped index's document metadata listing."""
        index = self._get_index()
        return index.list_document_metadata()

    def health(self) -> dict:
        """Return the wrapped index's health report."""
        index = self._get_index()
        return index.health()
@lru_cache
def _build_binary_store() -> MinioDocumentBinaryStore:
    """Construct the MinIO-backed binary store (memoized, so it is built once)."""
    minio_store = MinioDocumentBinaryStore()
    return minio_store
@lru_cache
def _build_vector_index() -> MilvusVectorIndex:
    """Construct the Milvus-backed vector index (memoized, so it is built once)."""
    milvus_index = MilvusVectorIndex()
    return milvus_index
@lru_cache
def get_document_repository():
    """Return the document repository chosen by ``settings.document_repository_backend``.

    ``"postgres"`` selects the PostgreSQL repository; any other value selects
    the JSON repository stored at ``settings.document_metadata_path``.
    """
    use_postgres = settings.document_repository_backend == "postgres"
    if not use_postgres:
        return JsonDocumentRepository(settings.document_metadata_path)
    return PostgresDocumentRepository()
@lru_cache
def get_parse_artifact_store():
    """Return the PostgreSQL parse artifact store, or ``None`` for non-postgres backends."""
    postgres_enabled = settings.document_repository_backend == "postgres"
    return PostgresParseArtifactStore() if postgres_enabled else None
@lru_cache
def get_document_processing_store():
    """Return the processing store that matches the active repository backend.

    ``"postgres"`` selects the PostgreSQL store; any other value selects the
    JSON store at ``settings.document_processing_metadata_path``.
    """
    if settings.document_repository_backend != "postgres":
        return JsonDocumentProcessingStore(settings.document_processing_metadata_path)
    return PostgresDocumentProcessingStore()
@lru_cache
def get_binary_store() -> DocumentBinaryStore:
    """Return the shared binary store, wrapped so the real store is built on first use."""
    lazy_store = LazyBinaryStore(_build_binary_store)
    return lazy_store
@lru_cache
def get_parser():
    """Return the document parser: Aliyun when ``settings.parser_backend == "aliyun"``, else local."""
    if settings.parser_backend == "aliyun":
        parser_cls = AliyunDocumentParser
    else:
        parser_cls = LocalDocumentParser
    return parser_cls()
@lru_cache
def get_chunk_builder():
    """Return the chunk builder selected by ``settings.chunk_backend``.

    ``"aliyun"`` selects the Aliyun vector chunk builder; any other value
    selects the local regulation chunk builder configured with
    ``settings.chunk_size`` and ``settings.chunk_overlap``.
    """
    if settings.chunk_backend != "aliyun":
        return LocalRegulationChunkBuilder(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
    return AliyunVectorChunkBuilder()
@lru_cache
def get_embedding_provider() -> OpenAICompatibleEmbeddingProvider:
    """Return the shared OpenAI-compatible embedding provider (memoized)."""
    provider = OpenAICompatibleEmbeddingProvider()
    return provider
@lru_cache
def get_vector_index() -> VectorIndex:
    """Return the shared vector index, wrapped so the real index is built on first use."""
    lazy_index = LazyVectorIndex(_build_vector_index)
    return lazy_index
@lru_cache
def get_reranker():
    """Return the reranker, or ``None`` unless it is enabled and a base URL is configured."""
    # Both settings must be truthy; the URL is only consulted when reranking is enabled.
    if not settings.reranker_enabled or not settings.reranker_base_url:
        return None
    return OpenAICompatibleReranker()
@lru_cache
def get_bm25_retriever() -> BM25Retriever | None:
    """Return a BM25 retriever over the shared vector index, or ``None`` if unavailable.

    NOTE(review): ``available`` is understood to reflect whether the optional
    rank_bm25 and jieba packages are installed — confirm in ``BM25Retriever``.
    """
    candidate = BM25Retriever(vector_index=get_vector_index())
    if not candidate.available:
        return None
    return candidate
@lru_cache
def get_retrieval_service() -> KnowledgeRetrievalService:
    """Return the shared knowledge retrieval service.

    Combines a dense retriever (embedding provider + vector index) with the
    optional BM25 retriever and optional reranker; either optional part may
    be ``None``.
    """
    embedding_provider = get_embedding_provider()
    vector_index = get_vector_index()
    dense_retriever = DenseRetriever(
        embedding_provider=embedding_provider,
        vector_index=vector_index,
    )
    sparse_retriever = get_bm25_retriever()
    reranker = get_reranker()
    return KnowledgeRetrievalService(
        retriever=dense_retriever,
        bm25_retriever=sparse_retriever,
        reranker=reranker,
        reranker_top_k=settings.reranker_top_k,
    )
@lru_cache
def get_document_command_service() -> DocumentCommandService:
    """Return the shared document command service wired with all its collaborators."""
    # Dict literals evaluate in order, so collaborators are resolved in the
    # same sequence as a direct keyword call would resolve them.
    dependencies = {
        "document_repository": get_document_repository(),
        "binary_store": get_binary_store(),
        "parser": get_parser(),
        "chunk_builder": get_chunk_builder(),
        "embedding_provider": get_embedding_provider(),
        "vector_index": get_vector_index(),
        "parse_artifact_store": get_parse_artifact_store(),
        "document_processing_store": get_document_processing_store(),
    }
    return DocumentCommandService(**dependencies)
@lru_cache
def get_document_query_service() -> DocumentQueryService:
    """Return the shared document query service (repository, binary store, vector index)."""
    repository = get_document_repository()
    binary_store = get_binary_store()
    vector_index = get_vector_index()
    return DocumentQueryService(
        document_repository=repository,
        binary_store=binary_store,
        vector_index=vector_index,
    )
@lru_cache
def get_conversation_store() -> InMemoryConversationStore:
    """Return the conversation store selected by ``settings.session_backend``.

    ``"redis"`` builds a Redis-backed store, so session data lives in Redis
    instead of this process. Any other value yields the in-memory store, whose
    sessions are local to the current process.

    NOTE(review): the return annotation only names the in-memory store; the
    Redis branch returns a different class, hence the ``type: ignore`` below.
    """
    if settings.session_backend != "redis":
        return InMemoryConversationStore(
            max_sessions=settings.session_max_sessions,
            timeout_minutes=settings.session_timeout_minutes,
        )

    # Imported here so the redis package is only required for the redis backend.
    import redis as redis_lib
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore

    # Build the Redis client from the same connection settings used by Celery.
    connection_options: dict = dict(
        host=settings.redis_host,
        port=settings.redis_port,
        db=settings.redis_db,
        decode_responses=False,
    )
    # Only pass a password when one is configured.
    if settings.redis_password:
        connection_options["password"] = settings.redis_password
    client = redis_lib.Redis(**connection_options)

    # The Redis store takes its timeout in seconds; the setting is in minutes.
    timeout_seconds = settings.session_timeout_minutes * 60
    return RedisConversationStore(  # type: ignore[return-value]
        redis_client=client,
        timeout_seconds=timeout_seconds,
    )
@lru_cache
def get_agent_conversation_service() -> AgentConversationService:
    """Return the shared agent conversation service.

    Wires together the retrieval service, a new OpenAI-compatible answer
    generator and the active conversation store.
    """
    retrieval_service = get_retrieval_service()
    answer_generator = OpenAICompatibleAnswerGenerator()
    conversation_store = get_conversation_store()
    return AgentConversationService(
        retrieval_service=retrieval_service,
        answer_generator=answer_generator,
        conversation_store=conversation_store,
    )
@lru_cache
def get_event_store() -> BaseEventStore:
    """Return the event store for the backend in ``settings.document_repository_backend``.

    ``"postgres"`` selects the PostgreSQL event store (imported on demand);
    any other value selects the mock event store.
    """
    if settings.document_repository_backend != "postgres":
        return MockEventStore()

    from app.infrastructure.perception.postgres_event_store import PostgresEventStore

    return PostgresEventStore()
@lru_cache
def get_compliance_repository() -> ComplianceRepository:
    """Return the PostgreSQL-backed compliance analysis repository.

    The repository is built from the ``postgres_*`` settings and is only
    available when ``settings.document_repository_backend`` is ``"postgres"``.

    Raises:
        NotImplementedError: if any other backend value is configured.
    """
    backend = settings.document_repository_backend
    if backend == "postgres":
        return PostgresComplianceRepository(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )

    # No non-postgres implementation exists, so fail with setup guidance.
    raise NotImplementedError(
        "ComplianceRepository requires document_repository_backend=postgres, "
        f"got '{backend}'. "
        "Set DOCUMENT_REPOSITORY_BACKEND=postgres in your .env file."
    )
@lru_cache
def get_perception_service() -> PerceptionService:
    """Return the shared perception service (event store + retrieval service)."""
    event_store = get_event_store()
    retrieval_service = get_retrieval_service()
    return PerceptionService(
        event_store=event_store,
        retrieval_service=retrieval_service,
    )
@lru_cache
def get_crawl_service() -> CrawlService:
    """Return the shared crawl service.

    Registers one crawler per source label and wires in the event store, a
    new LLM pipeline and the retrieval service.
    """
    return CrawlService(
        # Keys are the source labels the crawlers are registered under.
        crawlers={
            "CATARC": CatarcCrawler(),
            "国标委·强制性": GuobiaoMandatoryCrawler(),
            "国标委·推荐性": GuobiaoRecommendedCrawler(),
            "EUR-Lex": EurlexCrawler(),
        },
        event_store=get_event_store(),
        llm_pipeline=LlmPipeline(),
        retrieval_service=get_retrieval_service(),
    )
@lru_cache
def get_agent_session_service() -> AgentSessionService:
    """Return the shared agent session service backed by the active conversation store."""
    conversation_store = get_conversation_store()
    return AgentSessionService(conversation_store=conversation_store)
@lru_cache
def get_celery_app():
    """Return the shared Celery application instance.

    The import happens inside the function so that merely importing this
    module does not require Celery (e.g. tests that mock bootstrap, or
    development without Redis).
    """
    from app.infrastructure.tasks.celery_app import celery_app as shared_celery_app

    return shared_celery_app
@lru_cache
def get_jwt_handler():
    """Return the shared ``JWTHandler`` configured from the ``auth_*`` settings."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    handler = JWTHandler(
        secret_key=settings.auth_secret_key,
        algorithm=settings.auth_algorithm,
        expire_minutes=settings.auth_token_expire_minutes,
    )
    return handler
@lru_cache
def get_user_store():
    """Return the shared PostgreSQL user store.

    The store class is imported and instantiated on the first call only.
    NOTE(review): whether the store also defers its database connection is
    not visible here — confirm in ``PostgresUserStore``.
    """
    from app.infrastructure.auth.user_store import PostgresUserStore

    user_store = PostgresUserStore()
    return user_store
def preload_runtime_dependencies() -> None:
    """Warm up, at startup, the dependencies that are safe to preload."""
    # LLM clients to create ahead of the first request.
    provider_names = ["qwen", "deepseek"]
    LLMFactory.preload_clients(provider_names)
def cleanup_runtime_dependencies() -> None:
    """Release runtime dependencies by invoking their explicit cleanup hooks."""
    # LLMFactory is the only dependency cleaned up here.
    LLMFactory.cleanup()