55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
|
|
from functools import lru_cache
|
||
|
|
from typing import AsyncGenerator
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
from neo4j import AsyncGraphDatabase
|
||
|
|
from pymilvus import connections, Collection
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||
|
|
|
||
|
|
from .config import settings
|
||
|
|
|
||
|
|
# ── PostgreSQL ──────────────────────────────────
# Shared async engine: a pool of 10 connections plus up to 20 overflow ones.
engine = create_async_engine(
    settings.database_url,
    pool_size=10,
    max_overflow=20,
)

# Session factory bound to the engine. expire_on_commit=False keeps loaded ORM
# attributes readable after commit instead of expiring them.
AsyncSessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one database session for a single unit of work.

    After the consumer resumes the generator without raising, the session is
    committed. If an ``Exception`` is raised while the session is in use (or
    by the commit itself), the session is rolled back and the exception is
    re-raised. The ``async with`` block closes the session in every case.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            # Reached only when the consumer finished without an exception.
            await db_session.commit()
        except Exception:
            # Discard the partial unit of work, then propagate the original error.
            await db_session.rollback()
            raise
|
||
|
|
|
||
|
|
|
||
|
|
# ── Milvus ──────────────────────────────────────
def get_milvus_collection(name: str) -> Collection:
    """Return a ``Collection`` handle for the Milvus collection *name*.

    The pymilvus "default" connection alias is opened with the host/port from
    ``settings`` the first time it is needed. After that it is reused, and it
    is reopened only if no live connection exists under that alias.

    Args:
        name: Collection name. No schema is passed, so the collection
            presumably must already exist on the server — verify with callers.

    Returns:
        A ``Collection`` bound to the default connection alias.
    """
    # Connections are process-global per alias in pymilvus, so connect once
    # instead of calling connect() (and re-validating its config) on every call.
    if not connections.has_connection("default"):
        connections.connect(host=settings.milvus_host, port=settings.milvus_port)
    return Collection(name)
|
||
|
|
|
||
|
|
|
||
|
|
# ── Neo4j ───────────────────────────────────────
# Process-wide async driver, created lazily by get_neo4j() and then reused.
_neo4j_driver = None


def get_neo4j():
    """Return the shared Neo4j async driver, building it on first use.

    The driver is created from ``settings.neo4j_uri`` with basic
    (user, password) auth taken from ``settings``. Every later call returns
    the same driver instance.
    """
    global _neo4j_driver
    if _neo4j_driver is not None:
        return _neo4j_driver

    uri = settings.neo4j_uri
    credentials = (settings.neo4j_user, settings.neo4j_password)
    _neo4j_driver = AsyncGraphDatabase.driver(uri, auth=credentials)
    return _neo4j_driver
|
||
|
|
|
||
|
|
|
||
|
|
# ── HTTP client (reuses one connection pool) ────
# Process-wide httpx client, created lazily by get_http_client().
_http_client = None


def get_http_client() -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient``, creating it when needed.

    Reusing a single client lets all requests share its connection pool.
    A fresh client (``timeout=120.0`` seconds) is created on first use, and
    also when the cached client has already been closed (e.g. via ``aclose()``
    during shutdown). Callers therefore never get a client that rejects
    requests.
    """
    global _http_client
    # A closed AsyncClient cannot send requests, so replace it instead of
    # handing it out again.
    if _http_client is None or _http_client.is_closed:
        _http_client = httpx.AsyncClient(timeout=120.0)
    return _http_client
|