"""Shared client factories for PostgreSQL, Milvus, Neo4j and outbound HTTP.

Each backend is reached through a small accessor so callers never build
clients themselves. Long-lived clients (Neo4j driver, httpx client, Milvus
connection) are created lazily on first use and reused afterwards;
``close_clients`` releases them (e.g. on application shutdown).
"""

from functools import lru_cache
from typing import TYPE_CHECKING, AsyncGenerator, Optional

import httpx
from neo4j import AsyncGraphDatabase
from pymilvus import Collection, connections
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from .config import settings

if TYPE_CHECKING:  # only needed for annotations; avoids a runtime import
    from neo4j import AsyncDriver


# ── PostgreSQL ──────────────────────────────────
engine = create_async_engine(settings.database_url, pool_size=10, max_overflow=20)
# expire_on_commit=False keeps ORM objects readable after the commit in get_db.
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` wrapped in a unit of work.

    The session is committed once the consumer resumes the generator without
    error. If the consumer (or the commit itself) raises, the session is
    rolled back and the exception is re-raised. The session is always closed
    by the ``async with`` block.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


# ── Milvus ──────────────────────────────────────
# Alias shared by the connection and every Collection handle ("default" is
# pymilvus' own default alias, so behaviour matches the unaliased calls).
_MILVUS_ALIAS = "default"


def get_milvus_collection(name: str) -> Collection:
    """Return a handle to the Milvus collection ``name``.

    Connects to ``settings.milvus_host:settings.milvus_port`` only if no
    connection is registered under the alias yet, instead of reconnecting on
    every call.
    """
    if not connections.has_connection(_MILVUS_ALIAS):
        connections.connect(
            alias=_MILVUS_ALIAS,
            host=settings.milvus_host,
            port=settings.milvus_port,
        )
    return Collection(name, using=_MILVUS_ALIAS)


# ── Neo4j ───────────────────────────────────────
_neo4j_driver: "Optional[AsyncDriver]" = None


def get_neo4j() -> "AsyncDriver":
    """Return the process-wide async Neo4j driver, creating it on first use."""
    global _neo4j_driver
    if _neo4j_driver is None:
        _neo4j_driver = AsyncGraphDatabase.driver(
            settings.neo4j_uri,
            auth=(settings.neo4j_user, settings.neo4j_password),
        )
    return _neo4j_driver


# ── HTTP client (reuses one connection pool) ────
_http_client: Optional[httpx.AsyncClient] = None


def get_http_client() -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient`` (120 s timeout).

    A new client is created on first use, and also if the previous one has
    been closed, so callers never receive a client that cannot send requests.
    """
    global _http_client
    if _http_client is None or _http_client.is_closed:
        _http_client = httpx.AsyncClient(timeout=120.0)
    return _http_client


async def close_clients() -> None:
    """Release every shared client created by this module.

    Intended for application shutdown. It is safe to call when some clients
    were never created. Afterwards the accessors above lazily create fresh
    clients on next use, and ``engine`` rebuilds its pool on demand.
    """
    global _neo4j_driver, _http_client
    if _http_client is not None:
        await _http_client.aclose()
        _http_client = None
    if _neo4j_driver is not None:
        await _neo4j_driver.close()
        _neo4j_driver = None
    if connections.has_connection(_MILVUS_ALIAS):
        connections.disconnect(_MILVUS_ALIAS)
    await engine.dispose()