Refactor document handling and update Milvus collection settings

- Removed multiple failed document entries from `documents.json`.
- Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`.
- Updated architecture documentation to reflect changes in the Milvus collection name.
- Adjusted requirements by removing the sqlalchemy dependency.
- Modified test cases to align with new document structure and naming conventions.
- Introduced a new test file for Milvus vector index runtime recovery and error handling.
- Updated assertions in various test files to ensure compatibility with the new schema.
This commit is contained in:
ash66
2026-05-26 20:21:31 +08:00
parent fec22a3a2c
commit 30c7bda389
42 changed files with 7482 additions and 569 deletions

View File

@@ -3,6 +3,7 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
@@ -12,6 +13,7 @@ from app.api.routes import api_router
from app.config.logging import setup_logging
from app.config.settings import settings
from app.shared.bootstrap import cleanup_runtime_dependencies, preload_runtime_dependencies
from app.shared.errors import VectorStoreSchemaError
# Keep module behavior explicit so the backend flow stays easy to audit.
@@ -55,16 +57,33 @@ app.add_middleware(
app.include_router(api_router, prefix="/api/v1")
@app.exception_handler(VectorStoreSchemaError)
async def vector_store_schema_exception_handler(request: Request, exc: VectorStoreSchemaError):
    """Translate a vector store schema/runtime error into a fixed-shape HTTP 500 JSON body.

    The error is logged, then wrapped in an ``ErrorResponse`` whose ``error``
    field is always ``"VectorStoreSchemaError"`` and whose ``message`` is the
    string form of the exception.
    """
    logger.error(f"向量库 schema 异常: {exc}")
    payload = ErrorResponse(
        error="VectorStoreSchemaError",
        message=str(exc),
    )
    # jsonable_encoder converts the response model into plain JSON-compatible data.
    return JSONResponse(status_code=500, content=jsonable_encoder(payload))
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Return a generic HTTP 500 JSON body for exceptions no other handler claimed.

    The body is an ``ErrorResponse`` with ``error="InternalServerError"`` and the
    string form of the exception as ``message``, encoded the same way as the
    vector store schema handler so both error paths share one response shape.
    """
    # Attach the exception object so the traceback is recorded, not just its message.
    logger.opt(exception=exc).error(f"未处理的异常: {exc}")
    # NOTE(review): str(exc) is echoed to the client; confirm that exposing raw
    # exception text is acceptable for this API.
    return JSONResponse(
        status_code=500,
        # Exactly one ``content`` argument: passing it twice is a SyntaxError.
        content=jsonable_encoder(
            ErrorResponse(
                error="InternalServerError",
                message=str(exc),
            )
        ),
    )

View File

@@ -7,6 +7,7 @@ from .knowledge import router as knowledge_router
from .agent import router as agent_router
from .status import router as status_router
from .perception import router as perception_router
from .rag import router as rag_router
# Keep package boundaries explicit so backend imports stay predictable.
@@ -20,6 +21,7 @@ api_router.include_router(agent_router)
api_router.include_router(compliance_router)
api_router.include_router(status_router)
api_router.include_router(perception_router)
api_router.include_router(rag_router)
__all__ = [
"api_router",
@@ -29,4 +31,5 @@ __all__ = [
"compliance_router",
"status_router",
"perception_router",
"rag_router",
]

View File

@@ -29,14 +29,19 @@ async def search_knowledge(request: SearchRequest):
results=[
SearchResultItem(
id=index + 1,
content=item.content,
content=item.text,
score=item.score,
metadata={
"doc_id": item.doc_id,
"doc_name": item.doc_name,
"doc_title": item.doc_title,
"chunk_id": item.chunk_id,
"chunk_type": item.chunk_type,
"section_title": item.section_title,
"page_number": item.page_number,
"page_start": item.page_start,
"page_end": item.page_end,
"section_level": item.section_level,
"chunk_index": item.chunk_index,
"piece_index": item.piece_index,
**item.metadata,
},
)

View File

@@ -50,8 +50,8 @@ async def rag_chat(request: RagChatRequest):
{
"id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1),
"score": s.get("score", 0),
"preview": s.get("content", "")[:200],
"doc_name": s.get("doc_name", ""),
"preview": s.get("text", s.get("content", ""))[:200],
"doc_name": s.get("doc_title", s.get("doc_name", "")),
"clause": s.get("section_title", "法规片段"),
"doc_id": s.get("doc_id"),
"download_url": (

View File

@@ -508,7 +508,7 @@ class DocumentQueryService:
"""Return documents with real-time state from Milvus as the authoritative source.
Algorithm:
1. Query Milvus for all doc metadata (doc_id, doc_name, chunk_count, …).
1. Query Milvus for all doc metadata (doc_id, doc_title, chunk_count, …).
2. Load JSON/PG metadata records and index them by doc_id.
3. Merge: Milvus-present docs get status=INDEXED and live chunk_count;
metadata-only docs with status=INDEXED are demoted to FAILED.
@@ -536,8 +536,8 @@ class DocumentQueryService:
doc.chunk_count = row["chunk_count"]
doc.status = DocumentStatus.INDEXED
# Backfill fields that may be missing from older JSON records.
if not doc.doc_name and row.get("doc_name"):
doc.doc_name = row["doc_name"]
if not doc.doc_name and row.get("doc_title"):
doc.doc_name = row["doc_title"]
if not doc.regulation_type and row.get("regulation_type"):
doc.regulation_type = row["regulation_type"]
if not doc.version and row.get("version"):
@@ -553,8 +553,8 @@ class DocumentQueryService:
if doc_id not in meta_by_id:
synthetic = Document(
doc_id=doc_id,
doc_name=row.get("doc_name", doc_id),
file_name=row.get("doc_name", doc_id),
doc_name=row.get("doc_title", doc_id),
file_name=row.get("doc_title", doc_id),
object_name="",
content_type="",
size_bytes=0,

View File

@@ -29,11 +29,16 @@ def _reciprocal_rank_fusion(
RetrievedChunk(
chunk_id=chunk_map[ck].chunk_id,
doc_id=chunk_map[ck].doc_id,
doc_name=chunk_map[ck].doc_name,
content=chunk_map[ck].content,
doc_title=chunk_map[ck].doc_title,
text=chunk_map[ck].text,
score=scores[ck],
chunk_type=chunk_map[ck].chunk_type,
section_title=chunk_map[ck].section_title,
page_number=chunk_map[ck].page_number,
page_start=chunk_map[ck].page_start,
page_end=chunk_map[ck].page_end,
section_level=chunk_map[ck].section_level,
chunk_index=chunk_map[ck].chunk_index,
piece_index=chunk_map[ck].piece_index,
metadata=chunk_map[ck].metadata,
)
for ck in sorted_keys

View File

@@ -71,9 +71,9 @@ class PerceptionService:
affected_docs.append(
{
"doc_id": chunk.doc_id,
"doc_name": chunk.doc_name,
"doc_title": chunk.doc_title,
"score": round(float(chunk.score), 4),
"snippet": (chunk.content or "")[:180],
"snippet": (chunk.text or "")[:180],
"clause": getattr(chunk, "section_title", "") or "",
}
)
@@ -84,7 +84,7 @@ class PerceptionService:
# --- 2. Build context from retrieved chunks ---
context_parts = [
f"[文档{i}: {c.doc_name}]\n{(c.content or '')[:400]}"
f"[文档{i}: {c.doc_title}]\n{(c.text or '')[:400]}"
for i, c in enumerate(chunks[:5], 1)
]
context = "\n\n".join(context_parts) if context_parts else "(知识库中暂无相关文档)"

View File

@@ -33,7 +33,7 @@ class Settings(BaseSettings):
# Keep configuration setup explicit so runtime behavior is easy to reason about.
milvus_host: str = Field(default="6.86.80.8", description="Milvus服务地址")
milvus_port: int = Field(default=19530, description="Milvus服务端口")
milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称")
milvus_collection: str = Field(default="regulations_dense_1024_v2", description="法规向量集合名称")
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
# Keep configuration setup explicit so runtime behavior is easy to reason about.

View File

@@ -27,7 +27,7 @@ class Settings(BaseSettings):
# Milvus
milvus_host: str = "6.86.80.8"
milvus_port: int = 19530
milvus_collection: str = "regulations_dense_1024_v1"
milvus_collection: str = "regulations_dense_1024_v2"
# LLM / embedding defaults aligned with the migrated backend path.
llm_model: str = "qwen-max"
@@ -47,7 +47,7 @@ class Settings(BaseSettings):
api_port: int = 8000
# Legacy aliases retained for old utility modules.
regulations_collection: str = "regulations_dense_1024_v1"
regulations_collection: str = "regulations_dense_1024_v2"
compliance_collection: str = "compliance_cache"
# Preserve the legacy module API while keeping env resolution centralized at the repo root.

View File

@@ -8,18 +8,91 @@ from typing import Any
@dataclass
@dataclass(init=False)
class AnswerSource:
"""Represent answer source data."""
"""Represent answer source data with legacy aliases."""
doc_id: str
doc_name: str
doc_title: str
chunk_id: str
chunk_type: str
section_title: str
page_number: int
page_start: int
page_end: int
section_level: int
chunk_index: int
piece_index: int
score: float
content: str
text: str
metadata: dict[str, Any] = field(default_factory=dict)
def __init__(
self,
*,
doc_id: str,
doc_title: str | None = None,
chunk_id: str,
chunk_type: str = "",
section_title: str = "",
page_start: int = 0,
page_end: int = 0,
section_level: int = 0,
chunk_index: int = 0,
piece_index: int = 0,
score: float = 0.0,
text: str | None = None,
metadata: dict[str, Any] | None = None,
doc_name: str | None = None,
content: str | None = None,
page_number: int | None = None,
**_: Any,
) -> None:
"""Initialize the answer source while accepting legacy field names."""
self.doc_id = doc_id
self.doc_title = doc_title if doc_title is not None else (doc_name or "")
self.chunk_id = chunk_id
self.chunk_type = chunk_type
self.section_title = section_title
self.page_start = int(page_start or page_number or 0)
self.page_end = int(page_end or self.page_start)
self.section_level = int(section_level or 0)
self.chunk_index = int(chunk_index or 0)
self.piece_index = int(piece_index or 0)
self.score = float(score)
self.text = text if text is not None else (content or "")
self.metadata = dict(metadata or {})
@property
def doc_name(self) -> str:
"""Return the legacy document name alias."""
return self.doc_title
@doc_name.setter
def doc_name(self, value: str) -> None:
"""Update the legacy document name alias."""
self.doc_title = value
@property
def content(self) -> str:
"""Return the legacy content alias."""
return self.text
@content.setter
def content(self, value: str) -> None:
"""Update the legacy content alias."""
self.text = value
@property
def page_number(self) -> int:
"""Return the legacy page number alias."""
return self.page_start
@page_number.setter
def page_number(self, value: int) -> None:
"""Update the legacy page number alias."""
self.page_start = value
self.page_end = max(self.page_end, value)
@dataclass
class ConversationMessage:

View File

@@ -60,23 +60,117 @@ class ParsedDocument:
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
@dataclass(init=False)
class Chunk:
"""Represent the Chunk type."""
"""Represent one retrieval chunk with backward-compatible aliases."""
chunk_id: str
doc_id: str
doc_name: str
content: str
doc_title: str
text: str
embedding_text: str
chunk_type: str = ""
chunk_index: int = 0
piece_index: int = 0
page_start: int = 0
page_end: int = 0
section_title: str = ""
section_path: list[str] = field(default_factory=list)
page_number: int = 0
section_level: int = 0
source_ids: list[str] = field(default_factory=list)
regulation_type: str = ""
version: str = ""
semantic_id: str = ""
block_type: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
def __init__(
self,
*,
chunk_id: str,
doc_id: str,
doc_title: str | None = None,
text: str | None = None,
embedding_text: str = "",
chunk_type: str = "",
chunk_index: int = 0,
piece_index: int = 0,
page_start: int = 0,
page_end: int = 0,
section_title: str = "",
section_path: list[str] | None = None,
section_level: int = 0,
source_ids: list[str] | None = None,
regulation_type: str = "",
version: str = "",
semantic_id: str = "",
metadata: dict[str, Any] | None = None,
doc_name: str | None = None,
content: str | None = None,
page_number: int | None = None,
block_type: str | None = None,
**_: Any,
) -> None:
"""Initialize the chunk while accepting legacy field names."""
self.chunk_id = chunk_id
self.doc_id = doc_id
self.doc_title = doc_title if doc_title is not None else (doc_name or "")
self.text = text if text is not None else (content or "")
self.embedding_text = embedding_text or self.text
self.chunk_type = chunk_type or (block_type or "")
self.chunk_index = int(chunk_index or 0)
self.piece_index = int(piece_index or 0)
self.page_start = int(page_start or page_number or 0)
self.page_end = int(page_end or self.page_start)
self.section_title = section_title
self.section_path = list(section_path or [])
self.section_level = int(section_level or 0)
self.source_ids = list(source_ids or [])
self.regulation_type = regulation_type
self.version = version
self.semantic_id = semantic_id
self.metadata = dict(metadata or {})
@property
def doc_name(self) -> str:
"""Return the legacy document name alias."""
return self.doc_title
@doc_name.setter
def doc_name(self, value: str) -> None:
"""Update the legacy document name alias."""
self.doc_title = value
@property
def content(self) -> str:
"""Return the legacy content alias."""
return self.text
@content.setter
def content(self, value: str) -> None:
"""Update the legacy content alias."""
self.text = value
@property
def page_number(self) -> int:
"""Return the legacy page number alias."""
return self.page_start
@page_number.setter
def page_number(self, value: int) -> None:
"""Update the legacy page number alias."""
self.page_start = value
self.page_end = max(self.page_end, value)
@property
def block_type(self) -> str:
"""Return the legacy block type alias."""
return self.chunk_type
@block_type.setter
def block_type(self, value: str) -> None:
"""Update the legacy block type alias."""
self.chunk_type = value
@dataclass
class DocumentProcessingRun:

View File

@@ -16,14 +16,88 @@ class RetrievalQuery:
filters: str | None = None
@dataclass
@dataclass(init=False)
class RetrievedChunk:
"""Represent the Retrieved Chunk type."""
"""Represent the retrieved chunk payload with legacy aliases."""
chunk_id: str
doc_id: str
doc_name: str
content: str
doc_title: str
text: str
score: float
chunk_type: str = ""
section_title: str = ""
page_number: int = 0
page_start: int = 0
page_end: int = 0
section_level: int = 0
chunk_index: int = 0
piece_index: int = 0
metadata: dict[str, Any] = field(default_factory=dict)
def __init__(
self,
*,
chunk_id: str,
doc_id: str,
doc_title: str | None = None,
text: str | None = None,
score: float = 0.0,
chunk_type: str = "",
section_title: str = "",
page_start: int = 0,
page_end: int = 0,
section_level: int = 0,
chunk_index: int = 0,
piece_index: int = 0,
metadata: dict[str, Any] | None = None,
doc_name: str | None = None,
content: str | None = None,
page_number: int | None = None,
block_type: str | None = None,
**_: Any,
) -> None:
"""Initialize the retrieved chunk while accepting legacy field names."""
self.chunk_id = chunk_id
self.doc_id = doc_id
self.doc_title = doc_title if doc_title is not None else (doc_name or "")
self.text = text if text is not None else (content or "")
self.score = float(score)
self.chunk_type = chunk_type or (block_type or "")
self.section_title = section_title
self.page_start = int(page_start or page_number or 0)
self.page_end = int(page_end or self.page_start)
self.section_level = int(section_level or 0)
self.chunk_index = int(chunk_index or 0)
self.piece_index = int(piece_index or 0)
self.metadata = dict(metadata or {})
@property
def doc_name(self) -> str:
"""Return the legacy document name alias."""
return self.doc_title
@doc_name.setter
def doc_name(self, value: str) -> None:
"""Update the legacy document name alias."""
self.doc_title = value
@property
def content(self) -> str:
"""Return the legacy content alias."""
return self.text
@content.setter
def content(self, value: str) -> None:
"""Update the legacy content alias."""
self.text = value
@property
def page_number(self) -> int:
"""Return the legacy page number alias."""
return self.page_start
@page_number.setter
def page_number(self, value: int) -> None:
"""Update the legacy page number alias."""
self.page_start = value
self.page_end = max(self.page_end, value)

View File

@@ -45,10 +45,10 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
context_tokens = 0
for idx, chunk in enumerate(retrieved_chunks, start=1):
block = (
f"[{idx}] 文档: {chunk.doc_name}\n"
f"[{idx}] 文档: {chunk.doc_title}\n"
f"章节: {chunk.section_title or '未标注'}\n"
f"页码: {chunk.page_number}\n"
f"内容: {chunk.content}"
f"页码: {chunk.page_start}" + (f"-{chunk.page_end}" if chunk.page_end and chunk.page_end != chunk.page_start else "") + "\n"
f"内容: {chunk.text}"
)
block_tokens = self._estimate_tokens(block)
if context_tokens + block_tokens > settings.rag_max_context_tokens:
@@ -73,10 +73,10 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
return False
estimated_total_tokens = sum(
self._estimate_tokens(
f"[{idx}] 文档: {chunk.doc_name}\n"
f"[{idx}] 文档: {chunk.doc_title}\n"
f"章节: {chunk.section_title or '未标注'}\n"
f"页码: {chunk.page_number}\n"
f"内容: {chunk.content}"
f"页码: {chunk.page_start}" + (f"-{chunk.page_end}" if chunk.page_end and chunk.page_end != chunk.page_start else "") + "\n"
f"内容: {chunk.text}"
)
for idx, chunk in enumerate(retrieved_chunks, start=1)
)
@@ -87,12 +87,17 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
return [
AnswerSource(
doc_id=chunk.doc_id,
doc_name=chunk.doc_name,
doc_title=chunk.doc_title,
chunk_id=chunk.chunk_id,
chunk_type=chunk.chunk_type,
section_title=chunk.section_title,
page_number=chunk.page_number,
page_start=chunk.page_start,
page_end=chunk.page_end,
section_level=chunk.section_level,
chunk_index=chunk.chunk_index,
piece_index=chunk.piece_index,
score=chunk.score,
content=chunk.content,
text=chunk.text,
metadata=chunk.metadata,
)
for chunk in chunks

View File

@@ -10,6 +10,7 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
"""Adapt the existing markdown chunker to the new chunk builder port."""
def __init__(self, *, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
"""Initialize the local markdown chunk builder."""
self.chunker = RegulationChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
@@ -22,6 +23,7 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
regulation_type: str,
version: str,
) -> list[Chunk]:
"""Build migrated chunk objects from the legacy markdown chunker output."""
markdown_text = parsed_document.raw_text.strip()
if not markdown_text:
return []
@@ -50,16 +52,18 @@ class LocalRegulationChunkBuilder(ChunkBuilder):
Chunk(
chunk_id=item.metadata.chunk_id,
doc_id=parsed_document.doc_id,
doc_name=parsed_document.doc_name,
content=item.content,
doc_title=parsed_document.doc_name,
text=item.content,
embedding_text=item.content,
chunk_type="local_markdown_chunk",
section_title=item.metadata.section_title or item.metadata.section_number,
section_path=section_path,
page_number=item.metadata.page_number,
page_start=item.metadata.page_number,
page_end=item.metadata.page_number,
section_level=len(section_path),
regulation_type=regulation_type,
version=version,
semantic_id=item.metadata.clause_number,
block_type="local_markdown_chunk",
metadata=metadata,
)
)

View File

@@ -19,29 +19,35 @@ class AliyunVectorChunkBuilder(ChunkBuilder):
"""Handle build for the Aliyun Vector Chunk Builder instance."""
chunks: list[Chunk] = []
for index, item in enumerate(parsed_document.vector_chunks):
content = item.get("content") or item.get("text") or ""
embedding_text = item.get("embedding_text") or content
text = item.get("text") or ""
embedding_text = item.get("embedding_text") or text
if not embedding_text.strip():
continue
section_path = item.get("section_path") or []
section_title = item.get("section_title") or (section_path[-1] if section_path else "")
page_number = item.get("page_start") or item.get("page") or 0
chunk_id = item.get("chunk_id") or f"{parsed_document.doc_id}-chunk-{index}"
metadata = {k: v for k, v in item.items() if k not in {"content", "embedding_text"}}
metadata = dict(item)
metadata["regulation_type"] = regulation_type
metadata["version"] = version
chunks.append(
Chunk(
chunk_id=str(chunk_id),
doc_id=parsed_document.doc_id,
doc_name=parsed_document.doc_name,
content=content,
doc_title=str(item.get("doc_title") or parsed_document.doc_name),
text=text,
embedding_text=embedding_text,
chunk_type=str(item.get("chunk_type", item.get("block_type", ""))),
chunk_index=int(item.get("chunk_index") or 0),
piece_index=int(item.get("piece_index") or 0),
page_start=int(item.get("page_start") or 0),
page_end=int(item.get("page_end") or 0),
section_title=section_title,
section_path=section_path,
page_number=int(page_number or 0),
section_level=int(item.get("section_level") or len(section_path)),
source_ids=[str(v) for v in item.get("source_ids", [])],
regulation_type=regulation_type,
version=version,
semantic_id=item.get("semantic_id", ""),
block_type=item.get("block_type", ""),
metadata=metadata,
)
)

View File

@@ -56,7 +56,21 @@ class BM25Retriever:
try:
rows = self._vector_index.collection.query(
expr='doc_id != ""',
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
output_fields=[
"id",
"chunk_id",
"doc_id",
"doc_title",
"text",
"chunk_type",
"section_title",
"page_start",
"page_end",
"section_level",
"chunk_index",
"piece_index",
"metadata_json",
],
limit=16384,
)
except Exception:
@@ -64,19 +78,33 @@ class BM25Retriever:
return []
return [
RetrievedChunk(
chunk_id=str(row.get("id", "")),
chunk_id=str(row.get("chunk_id") or row.get("id", "")),
doc_id=str(row.get("doc_id", "")),
doc_name=str(row.get("doc_name", "")),
content=str(row.get("content", "")),
doc_title=str(row.get("doc_title", "")),
text=str(row.get("text", "")),
score=0.0,
chunk_type=str(row.get("chunk_type", "")),
section_title=str(row.get("section_title", "")),
page_number=int(row.get("page_number") or 0),
metadata={},
page_start=int(row.get("page_start") or 0),
page_end=int(row.get("page_end") or 0),
section_level=int(row.get("section_level") or 0),
chunk_index=int(row.get("chunk_index") or 0),
piece_index=int(row.get("piece_index") or 0),
metadata=self._parse_metadata_json(row.get("metadata_json", "")),
)
for row in rows
if row.get("content")
if row.get("text")
]
def _parse_metadata_json(self, raw_metadata: str) -> dict:
"""Parse metadata_json into a dict for BM25-side filtering."""
if not raw_metadata:
return {}
try:
return dict(__import__("json").loads(raw_metadata))
except Exception:
return {}
def _ensure_built(self) -> None:
if self._index is not None:
return
@@ -93,7 +121,7 @@ class BM25Retriever:
self._chunks = []
self._index = BM25Okapi([[]])
return
tokenized = [_tokenize(c.content) for c in chunks]
tokenized = [_tokenize(c.text) for c in chunks]
self._chunks = chunks
self._index = BM25Okapi(tokenized)
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
@@ -127,20 +155,26 @@ class BM25Retriever:
for score, chunk in ranked[: top_k * 2]:
if score <= 0:
break
# Apply simple regulation_type filter if provided
if filters and chunk.metadata.get("regulation_type"):
types = [t.strip() for t in filters.split(",")]
if chunk.metadata.get("regulation_type") not in types:
continue
if filters:
normalized_filter = filters.replace("doc_name", "doc_title").strip()
if normalized_filter.startswith('doc_title == "'):
expected_title = normalized_filter[len('doc_title == "'):-1]
if chunk.doc_title != expected_title:
continue
results.append(
RetrievedChunk(
chunk_id=chunk.chunk_id,
doc_id=chunk.doc_id,
doc_name=chunk.doc_name,
content=chunk.content,
doc_title=chunk.doc_title,
text=chunk.text,
score=score,
chunk_type=chunk.chunk_type,
section_title=chunk.section_title,
page_number=chunk.page_number,
page_start=chunk.page_start,
page_end=chunk.page_end,
section_level=chunk.section_level,
chunk_index=chunk.chunk_index,
piece_index=chunk.piece_index,
metadata=chunk.metadata,
)
)

View File

@@ -31,7 +31,7 @@ class OpenAICompatibleReranker(Reranker):
if not chunks:
return []
texts = [chunk.content for chunk in chunks]
texts = [chunk.text for chunk in chunks]
start = time.time()
try:
scores = self._call_reranker(query, texts)

View File

@@ -4,57 +4,150 @@ from __future__ import annotations
import json
import time
from typing import Iterable
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
from loguru import logger
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, connections, utility
from app.config.settings import settings
from app.domain.documents import Chunk
from app.domain.retrieval import RetrievedChunk, VectorIndex
from app.shared.errors import VectorStoreSchemaError
# Keep adapter behavior explicit so integration details remain easy to audit.
_REQUIRED_SCHEMA_FIELDS = (
"doc_id",
"doc_title",
"chunk_id",
"text",
"embedding",
"section_title",
"metadata_json",
)
_SCHEMA_RECOVERY_TOKENS = (
"field doc_title not exist",
"field text not exist",
"field embedding not exist",
"collection not loaded",
"can't find collection",
"not found[collection",
)
class MilvusVectorIndex(VectorIndex):
"""Provide the Milvus Vector Index index implementation."""
def __init__(self) -> None:
    """Capture the Milvus settings, open the connection, and bind the collection.

    Connection parameters are copied from ``settings`` onto the instance, then
    ``_connect`` and ``_bind_collection`` are called in that order.
    """
    self.host, self.port = settings.milvus_host, settings.milvus_port
    self.db_name = settings.milvus_db_name
    self.collection_name = settings.milvus_collection
    # Use an adapter-specific alias so this index never reuses unrelated global Milvus state.
    self.alias = f"vector-index::{self.host}:{self.port}/{self.db_name}/{self.collection_name}"
    self._connect()
    self.collection = self._bind_collection()
def _connect(self, *, refresh: bool = False) -> None:
    """Establish the Milvus connection for this adapter.

    Args:
        refresh: When true, first disconnect ``self.alias`` so the connect
            call below starts from a clean state.
    """
    if refresh:
        try:
            connections.disconnect(self.alias)
        except Exception:
            # Best-effort disconnect keeps refresh idempotent when no alias is active yet.
            pass
    # Connect under the adapter-specific alias only. Each keyword is passed once,
    # and the collection is bound by the caller via _bind_collection so the
    # schema is always validated.
    connections.connect(
        alias=self.alias,
        host=self.host,
        port=self.port,
        db_name=self.db_name,
    )
def _schema_field_names(self, collection: Collection) -> list[str]:
    """List the names of the fields in ``collection.schema.fields``, in schema order."""
    names = []
    for schema_field in collection.schema.fields:
        names.append(schema_field.name)
    return names
def _raise_schema_error(self, *, message: str, actual_fields: Iterable[str]) -> None:
    """Raise ``VectorStoreSchemaError`` describing the active collection.

    The error carries this adapter's host, database and collection name, the
    required field names, and the field names passed in ``actual_fields``.
    This method always raises.
    """
    expected = list(_REQUIRED_SCHEMA_FIELDS)
    # Materialize the iterable so the error holds a stable list.
    actual = list(actual_fields)
    raise VectorStoreSchemaError(
        message=message,
        host=self.host,
        db_name=self.db_name,
        collection_name=self.collection_name,
        expected_fields=expected,
        actual_fields=actual,
    )
def _validate_schema(self, collection: Collection) -> None:
    """Check that every required schema field exists on the collection.

    Raises a schema error (via ``_raise_schema_error``) listing the missing
    required fields; returns silently when none are missing.
    """
    present = self._schema_field_names(collection)
    absent = []
    # Walk the required fields in declaration order so the message is deterministic.
    for required_name in _REQUIRED_SCHEMA_FIELDS:
        if required_name not in present:
            absent.append(required_name)
    if not absent:
        return
    self._raise_schema_error(
        message=f"Milvus collection schema mismatch; missing required fields: {absent}",
        actual_fields=present,
    )
def _log_collection_binding(self, collection: Collection, *, event: str) -> None:
    """Log one info line describing the collection this adapter is bound to.

    The line includes the event label, alias, host, database, collection name,
    schema field names, and the entity count.
    """
    try:
        entity_count = collection.num_entities
    except Exception:
        # The count is diagnostic only; use a placeholder instead of failing the binding.
        entity_count = "unknown"
    field_names = self._schema_field_names(collection)
    logger.info(
        "Milvus binding {} alias={} host={} db={} collection={} fields={} num_entities={}",
        event,
        self.alias,
        self.host,
        self.db_name,
        self.collection_name,
        field_names,
        entity_count,
    )
def _bind_collection(self, *, force_refresh: bool = False) -> Collection:
    """Obtain the configured collection, validate its schema, log it, and return it.

    Args:
        force_refresh: When true, reconnect with ``refresh=True`` before
            obtaining the collection, and log the event as "refreshed"
            instead of "initialized".
    """
    if force_refresh:
        self._connect(refresh=True)
    bound = self._ensure_collection()
    self._validate_schema(bound)
    if force_refresh:
        event_name = "refreshed"
    else:
        event_name = "initialized"
    self._log_collection_binding(bound, event=event_name)
    return bound
def _ensure_collection(self) -> Collection:
"""Handle ensure collection for this module for the Milvus Vector Index instance."""
if utility.has_collection(self.collection_name):
collection = Collection(self.collection_name)
if utility.has_collection(self.collection_name, using=self.alias):
collection = Collection(self.collection_name, using=self.alias)
collection.load()
return collection
schema = CollectionSchema(
fields=[
FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=128, is_primary=True, auto_id=False),
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="doc_title", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="chunk_index", dtype=DataType.INT64),
FieldSchema(name="piece_index", dtype=DataType.INT64),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="embedding_text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="page_number", dtype=DataType.INT64),
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="block_type", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="page_start", dtype=DataType.INT64),
FieldSchema(name="page_end", dtype=DataType.INT64),
FieldSchema(name="section_level", dtype=DataType.INT64),
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="metadata_json", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="created_at", dtype=DataType.INT64),
],
description="Dense-only regulations index",
enable_dynamic_field=False,
)
collection = Collection(name=self.collection_name, schema=schema)
collection = Collection(name=self.collection_name, schema=schema, using=self.alias)
collection.create_index(
field_name="embedding",
index_params={
@@ -73,21 +166,34 @@ class MilvusVectorIndex(VectorIndex):
data = []
now = int(time.time())
for chunk, vector in zip(chunks, vectors):
metadata = dict(chunk.metadata)
doc_title = str(metadata.get("doc_title", chunk.doc_title))
text = str(metadata.get("text", chunk.text))
embedding_text = str(metadata.get("embedding_text", chunk.embedding_text))
page_start = int(metadata.get("page_start", 0) or 0)
page_end = int(metadata.get("page_end", 0) or 0)
section_path = metadata.get("section_path", chunk.section_path)
source_ids = metadata.get("source_ids", [])
data.append(
{
"id": chunk.chunk_id,
"doc_id": chunk.doc_id,
"doc_name": chunk.doc_name,
"content": chunk.content[:65535],
"doc_title": doc_title[:256],
"chunk_id": chunk.chunk_id[:128],
"chunk_index": int(metadata.get("chunk_index", chunk.chunk_index) or 0),
"piece_index": int(metadata.get("piece_index", chunk.piece_index) or 0),
"text": text[:65535],
"embedding_text": embedding_text[:65535],
"embedding": vector,
"section_title": chunk.section_title[:512],
"section_path": json.dumps(chunk.section_path, ensure_ascii=False)[:4096],
"page_number": chunk.page_number,
"regulation_type": chunk.regulation_type[:128],
"version": chunk.version[:64],
"semantic_id": chunk.semantic_id[:128],
"block_type": chunk.block_type[:64],
"metadata_json": json.dumps(chunk.metadata, ensure_ascii=False)[:65535],
"semantic_id": str(metadata.get("semantic_id", chunk.semantic_id))[:128],
"chunk_type": str(metadata.get("chunk_type", chunk.chunk_type))[:64],
"page_start": page_start,
"page_end": page_end,
"section_level": int(metadata.get("section_level", chunk.section_level) or 0),
"source_ids": json.dumps(source_ids, ensure_ascii=False)[:4096],
"section_path": json.dumps(section_path, ensure_ascii=False)[:4096],
"section_title": str(metadata.get("section_title", chunk.section_title))[:512],
"metadata_json": json.dumps(metadata, ensure_ascii=False)[:65535],
"created_at": now,
}
)
@@ -107,47 +213,97 @@ class MilvusVectorIndex(VectorIndex):
filters = filters.strip()
# Normalize legacy field names so callers can keep older filter payloads.
replacements = {
"doc_name": "doc_title",
"content": "text",
"page_number": "page_start",
"block_type": "chunk_type",
}
for legacy_name, new_name in replacements.items():
filters = filters.replace(legacy_name, new_name)
# Check if already a Milvus expression (contains operators)
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
return filters
# Parse simple regulation_type filter
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
types = [t.strip() for t in filters.split(",") if t.strip()]
# Parse simple document-title filter.
titles = [title.strip() for title in filters.split(",") if title.strip()]
if not types:
if not titles:
return None
if len(types) == 1:
# Single value: regulation_type == "GB"
return f'regulation_type == "{types[0]}"'
else:
# Multiple values: regulation_type in ["GB", "UN-ECE"]
quoted_types = [f'"{t}"' for t in types]
return f'regulation_type in [{", ".join(quoted_types)}]'
if len(titles) == 1:
return f'doc_title == "{titles[0]}"'
quoted_titles = [f'"{title}"' for title in titles]
return f'doc_title in [{", ".join(quoted_titles)}]'
def _should_refresh_after_exception(self, exc: Exception) -> bool:
    """Decide whether a failed Milvus call is worth a forced reconnect and retry.

    Only ``MilvusException`` instances qualify; every other exception type
    yields ``False``. For a Milvus error, the lower-cased exception text is
    scanned for any entry of ``_SCHEMA_RECOVERY_TOKENS`` (defined outside
    this method).
    """
    if isinstance(exc, MilvusException):
        # Compare case-insensitively by lower-casing the error text once.
        lowered_message = str(exc).lower()
        for recovery_token in _SCHEMA_RECOVERY_TOKENS:
            if recovery_token in lowered_message:
                return True
    return False
def _run_with_refresh(self, operation):
    """Call ``operation`` and retry it once after rebinding the collection.

    ``operation`` is a zero-argument callable. It should look up
    ``self.collection`` at call time so the retry sees the rebound
    collection.

    Flow:
      * ``VectorStoreSchemaError`` always propagates untouched.
      * Any other first failure is re-raised unless
        ``_should_refresh_after_exception`` approves a retry.
      * On an approved retry a warning is logged, ``self.collection`` is
        replaced by ``_bind_collection(force_refresh=True)`` and
        ``operation`` is invoked a second time.
      * A ``MilvusException`` from the second attempt is handed to
        ``_raise_schema_error`` (presumably raising
        ``VectorStoreSchemaError`` -- confirm in that helper); anything
        else from the second attempt propagates as-is.
    """
    try:
        return operation()
    except VectorStoreSchemaError:
        raise
    except Exception as first_error:
        retry_allowed = self._should_refresh_after_exception(first_error)
        if not retry_allowed:
            raise
        logger.warning(
            "Milvus operation failed for alias={} collection={}; forcing reconnect and retry: {}",
            self.alias,
            self.collection_name,
            first_error,
        )
        # Rebind before the second attempt so it runs against a fresh handle.
        self.collection = self._bind_collection(force_refresh=True)
        try:
            return operation()
        except VectorStoreSchemaError:
            raise
        except MilvusException as second_error:
            observed_fields = self._schema_field_names(self.collection)
            self._raise_schema_error(
                message=f"Milvus operation failed after refresh: {second_error}",
                actual_fields=observed_fields,
            )
            # Fallback in case the helper returns instead of raising.
            raise
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Milvus Vector Index instance."""
milvus_expr = self._parse_filters(filters)
results = self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
expr=milvus_expr,
output_fields=[
"doc_id",
"doc_name",
"content",
"section_title",
"page_number",
"regulation_type",
"version",
"semantic_id",
"block_type",
"metadata_json",
],
results = self._run_with_refresh(
lambda: self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
expr=milvus_expr,
output_fields=[
"doc_id",
"doc_title",
"chunk_id",
"chunk_index",
"piece_index",
"text",
"embedding_text",
"section_title",
"semantic_id",
"chunk_type",
"page_start",
"page_end",
"section_level",
"source_ids",
"section_path",
"metadata_json",
],
)
)
payload: list[RetrievedChunk] = []
for hits in results:
@@ -161,13 +317,18 @@ class MilvusVectorIndex(VectorIndex):
metadata = {"raw_metadata": raw_metadata}
payload.append(
RetrievedChunk(
chunk_id=str(hit.id),
chunk_id=str(hit.entity.get("chunk_id", hit.id)),
doc_id=hit.entity.get("doc_id", ""),
doc_name=hit.entity.get("doc_name", ""),
content=hit.entity.get("content", ""),
doc_title=hit.entity.get("doc_title", ""),
text=hit.entity.get("text", ""),
score=float(hit.score),
chunk_type=hit.entity.get("chunk_type", ""),
section_title=hit.entity.get("section_title", ""),
page_number=int(hit.entity.get("page_number", 0) or 0),
page_start=int(hit.entity.get("page_start", 0) or 0),
page_end=int(hit.entity.get("page_end", 0) or 0),
section_level=int(hit.entity.get("section_level", 0) or 0),
chunk_index=int(hit.entity.get("chunk_index", 0) or 0),
piece_index=int(hit.entity.get("piece_index", 0) or 0),
metadata=metadata,
)
)
@@ -176,7 +337,9 @@ class MilvusVectorIndex(VectorIndex):
def count_by_document(self) -> dict[str, int]:
"""Return doc_id -> chunk count from Milvus."""
try:
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
rows = self._run_with_refresh(
lambda: self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id", "doc_title"])
)
except Exception:
return {}
counts: dict[str, int] = {}
@@ -189,9 +352,11 @@ class MilvusVectorIndex(VectorIndex):
def list_document_metadata(self) -> list[dict]:
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
try:
rows = self.collection.query(
expr="doc_id != \"\"",
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
rows = self._run_with_refresh(
lambda: self.collection.query(
expr="doc_id != \"\"",
output_fields=["doc_id", "doc_title", "metadata_json"],
)
)
except Exception:
return []
@@ -204,15 +369,26 @@ class MilvusVectorIndex(VectorIndex):
continue
counts[doc_id] = counts.get(doc_id, 0) + 1
if doc_id not in seen:
metadata: dict[str, object] = {}
raw_metadata = row.get("metadata_json", "")
if raw_metadata:
try:
metadata = json.loads(raw_metadata)
except json.JSONDecodeError:
metadata = {}
seen[doc_id] = {
"doc_id": doc_id,
"doc_name": row.get("doc_name", ""),
"regulation_type": row.get("regulation_type", ""),
"version": row.get("version", ""),
"doc_title": row.get("doc_title", ""),
"regulation_type": str(metadata.get("regulation_type", "")),
"version": str(metadata.get("version", "")),
}
return [
{**meta, "chunk_count": counts[meta["doc_id"]]}
{
**meta,
"doc_name": meta.get("doc_title", ""),
"chunk_count": counts[meta["doc_id"]],
}
for meta in seen.values()
]

View File

@@ -67,14 +67,14 @@ class DocumentProcessor:
return [
{
"id": item.chunk_id,
"content": item.content,
"content": item.text,
"score": item.score,
"metadata": {
"doc_id": item.doc_id,
"doc_name": item.doc_name,
"doc_name": item.doc_title,
"chunk_id": item.chunk_id,
"section_title": item.section_title,
"page_number": item.page_number,
"page_number": item.page_start,
**item.metadata,
},
}

View File

@@ -0,0 +1,30 @@
"""Define shared backend exception types."""
from __future__ import annotations
class VectorStoreSchemaError(RuntimeError):
    """Signal that the active vector store schema does not match backend expectations.

    Every constructor argument is keyword-only and is kept as an attribute of
    the same name. ``str(error)`` is a single self-contained line holding the
    message plus the host, database, collection and both field lists.

    The class defines ``__reduce__`` so instances survive ``pickle`` and
    ``copy``. The default ``BaseException`` reduction rebuilds an exception
    as ``cls(*args)``, which cannot satisfy a keyword-only ``__init__`` and
    fails with ``TypeError``.
    """

    def __init__(
        self,
        *,
        message: str,
        host: str,
        db_name: str,
        collection_name: str,
        expected_fields: list[str],
        actual_fields: list[str],
    ) -> None:
        """Initialize the vector store schema error details.

        Args:
            message: Short description of what went wrong.
            host: Vector store host the backend was talking to.
            db_name: Database name in use.
            collection_name: Collection whose schema was checked.
            expected_fields: Field names the backend expects.
            actual_fields: Field names reported for the collection.
        """
        # Keep the raw message so the error can be rebuilt without parsing str(self).
        self.message = message
        self.host = host
        self.db_name = db_name
        self.collection_name = collection_name
        self.expected_fields = expected_fields
        self.actual_fields = actual_fields
        # Keep the message self-contained so runtime logs show the full mismatch context.
        details = (
            f"{message} | host={host} db={db_name} collection={collection_name} "
            f"expected_fields={expected_fields} actual_fields={actual_fields}"
        )
        super().__init__(details)

    @classmethod
    def _rebuild(cls, state: dict[str, object]) -> "VectorStoreSchemaError":
        """Recreate an instance from the keyword arguments captured by ``__reduce__``."""
        return cls(**state)

    def __reduce__(self):
        """Describe how to rebuild this error for ``pickle`` and ``copy``.

        Returns:
            A ``(callable, args, state)`` tuple: the rebuild classmethod, the
            original keyword arguments, and a copy of the instance
            ``__dict__`` so attributes added after construction are kept.
        """
        init_kwargs = {
            "message": self.message,
            "host": self.host,
            "db_name": self.db_name,
            "collection_name": self.collection_name,
            "expected_fields": self.expected_fields,
            "actual_fields": self.actual_fields,
        }
        return (self._rebuild, (init_kwargs,), dict(self.__dict__))