118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
|
|
"""Test runtime recovery and API error serialization for the Milvus vector index."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from fastapi.encoders import jsonable_encoder
|
||
|
|
from pymilvus import MilvusException
|
||
|
|
|
||
|
|
from app.api.models import ErrorResponse
|
||
|
|
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||
|
|
from app.shared.errors import VectorStoreSchemaError
|
||
|
|
|
||
|
|
|
||
|
|
class FakeField:
    """Stand-in for one Milvus schema field; only its ``name`` is modelled."""

    def __init__(self, name: str) -> None:
        """Remember the field name that the code under test will read."""
        self.name: str = name
|
||
|
|
|
||
|
|
|
||
|
|
class FakeSchema:
    """Stand-in for a Milvus collection schema exposing a ``fields`` list."""

    def __init__(self, field_names: list[str]) -> None:
        """Create one FakeField per name, preserving the given order."""
        built_fields = []
        for field_name in field_names:
            built_fields.append(FakeField(field_name))
        self.fields = built_fields
|
||
|
|
|
||
|
|
|
||
|
|
class FakeCollection:
    """Represent a minimal collection object for runtime recovery tests.

    Attributes are plain values that tests inspect after exercising the index:
    ``schema`` (a FakeSchema built from the field names), ``responses`` (the
    remaining queued search results), ``num_entities`` (always 0 here) and
    ``search_calls`` (how many times ``search`` was invoked).
    """

    def __init__(self, field_names: list[str], responses: list[object]) -> None:
        """Initialize the fake collection with schema fields and queued responses.

        Args:
            field_names: Names used to build the fake schema's fields.
            responses: Results handed out by ``search`` in order; an entry that
                is an ``Exception`` instance is raised instead of returned.
        """
        self.schema = FakeSchema(field_names)
        # Copy so that consuming the queue never mutates the caller's list.
        self.responses = list(responses)
        self.num_entities = 0
        self.search_calls = 0

    def search(self, **kwargs):
        """Return the next queued response or raise the next queued exception.

        Raises:
            AssertionError: If ``search`` is called after every queued response
                has been consumed, so unexpected extra calls fail with a clear
                message instead of a bare ``IndexError``.
        """
        self.search_calls += 1
        if not self.responses:
            raise AssertionError(
                f"FakeCollection.search called {self.search_calls} time(s) "
                "but no queued responses remain"
            )
        response = self.responses.pop(0)
        if isinstance(response, Exception):
            raise response
        return response
|
||
|
|
|
||
|
|
|
||
|
|
def _build_index_for_test(*, collection: FakeCollection) -> MilvusVectorIndex:
    """Return a MilvusVectorIndex bound to *collection* without real Milvus I/O.

    ``__init__`` is skipped by allocating through ``__new__``, so no real
    connection is opened; the attributes listed below are assigned fixed
    test values instead.
    """
    preset_attributes = {
        "collection_name": "regulations_dense_1024_v2",
        "db_name": "default",
        "host": "6.86.80.8",
        "port": 19530,
        "alias": "vector-index::test",
        "collection": collection,
    }
    index = MilvusVectorIndex.__new__(MilvusVectorIndex)
    for attribute_name, attribute_value in preset_attributes.items():
        setattr(index, attribute_name, attribute_value)
    return index
|
||
|
|
|
||
|
|
|
||
|
|
def test_search_rebinds_and_retries_after_stale_schema_error(monkeypatch):
    """Refresh the bound collection once when Milvus reports a stale schema field.

    The first search against the stale collection raises a ``MilvusException``
    naming a missing field. The test expects exactly one forced rebind, one
    retried search on the refreshed collection, and the refreshed collection
    to remain bound on the index afterwards.
    """
    schema_fields = [
        "id",
        "doc_id",
        "doc_title",
        "chunk_id",
        "text",
        "embedding",
        "section_title",
        "metadata_json",
    ]
    stale_collection = FakeCollection(
        schema_fields,
        [MilvusException(code=65535, message="field doc_title not exist")],
    )
    refreshed_collection = FakeCollection(schema_fields, [[]])
    index = _build_index_for_test(collection=stale_collection)

    # Record every rebind request so the "once" in the docstring is verified.
    bind_calls: list[bool] = []

    def fake_bind_collection(*, force_refresh: bool = False):
        """Return the refreshed collection on forced rebinding."""
        bind_calls.append(force_refresh)
        assert force_refresh is True
        return refreshed_collection

    monkeypatch.setattr(index, "_bind_collection", fake_bind_collection)

    results = index.search([0.0] * 1024, 1)

    assert results == []
    assert bind_calls == [True]
    assert stale_collection.search_calls == 1
    assert refreshed_collection.search_calls == 1
    assert index.collection is refreshed_collection
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_schema_raises_detailed_vector_store_schema_error():
    """Expect a typed schema error naming missing and actual Milvus fields.

    The collection below lacks ``doc_title``; validation must raise
    ``VectorStoreSchemaError`` whose message mentions that field and lists the
    fields the collection actually has.
    """
    legacy_fields = ["id", "doc_id", "doc_name", "content", "dense_vector"]
    invalid_collection = FakeCollection(legacy_fields, [[]])
    index = _build_index_for_test(collection=invalid_collection)

    caught = None
    try:
        index._validate_schema(invalid_collection)
    except VectorStoreSchemaError as exc:
        # ``exc`` is unbound once the except clause ends, so keep a reference.
        caught = exc

    if caught is None:
        raise AssertionError("VectorStoreSchemaError was not raised")

    error_text = str(caught)
    assert "doc_title" in error_text
    assert "actual_fields=['id', 'doc_id', 'doc_name', 'content', 'dense_vector']" in error_text
|
||
|
|
|
||
|
|
|
||
|
|
def test_error_response_is_json_serializable():
    """Ensure shared API error responses encode datetime fields safely.

    ``jsonable_encoder`` must turn the response into JSON-compatible data: the
    timestamp becomes a string and the whole payload survives ``json.dumps``.
    """
    import json

    payload = jsonable_encoder(ErrorResponse(error="InternalServerError", message="boom"))

    assert payload["error"] == "InternalServerError"
    assert payload["message"] == "boom"
    assert isinstance(payload["timestamp"], str)
    # The test name promises JSON serializability; prove it directly instead of
    # inferring it from the timestamp's type alone (json.dumps raises TypeError
    # on any non-serializable value).
    encoded = json.dumps(payload)
    assert json.loads(encoded)["message"] == "boom"
|