116 lines
5.3 KiB
Python
116 lines
5.3 KiB
Python
|
|
"""Rebuild the migrated Milvus collection from saved vector chunks."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
|
||
|
|
|
||
|
|
|
||
|
|
# Collection name used when --collection is not supplied on the command line.
DEFAULT_COLLECTION = "regulations_dense_1024_v2"
|
||
|
|
# Embedding dimension used when --dim is not supplied on the command line.
DEFAULT_DIM = 1024
|
||
|
|
|
||
|
|
|
||
|
|
def build_collection(name: str, dim: int) -> Collection:
    """Drop any collection called *name* and recreate it with a fresh index.

    Args:
        name: Name of the Milvus collection to (re)create.
        dim: Dimension of the ``embedding`` float-vector field.

    Returns:
        The newly created collection, with an IVF_FLAT / COSINE index
        already defined on ``embedding``.
    """

    def varchar(field: str, length: int, **extra) -> FieldSchema:
        # Shorthand for a VARCHAR column with the given max_length.
        return FieldSchema(name=field, dtype=DataType.VARCHAR, max_length=length, **extra)

    def int64(field: str) -> FieldSchema:
        # Shorthand for an INT64 column.
        return FieldSchema(name=field, dtype=DataType.INT64)

    # Start from a clean slate: an existing collection with this name is removed.
    if utility.has_collection(name):
        utility.drop_collection(name)

    # Field order matters: column-based inserts must supply data in this order.
    columns = [
        varchar("id", 128, is_primary=True, auto_id=False),
        varchar("doc_id", 64),
        varchar("doc_title", 256),
        varchar("chunk_id", 128),
        int64("chunk_index"),
        int64("piece_index"),
        varchar("text", 65535),
        varchar("embedding_text", 65535),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
        varchar("semantic_id", 128),
        varchar("chunk_type", 64),
        int64("page_start"),
        int64("page_end"),
        int64("section_level"),
        varchar("source_ids", 4096),
        varchar("section_path", 4096),
        varchar("section_title", 512),
        varchar("metadata_json", 65535),
        int64("created_at"),
    ]
    layout = CollectionSchema(
        fields=columns,
        description="Dense-only regulations index",
        enable_dynamic_field=False,
    )
    rebuilt = Collection(name=name, schema=layout)

    vector_index = {
        "metric_type": "COSINE",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 128},
    }
    rebuilt.create_index(field_name="embedding", index_params=vector_index)
    return rebuilt
|
||
|
|
|
||
|
|
|
||
|
|
def load_chunks(payload_path: Path) -> list[dict]:
    """Load vector chunks emitted by the Aliyun parser pipeline.

    The file must contain UTF-8 JSON that is either a list of chunk objects
    or an object whose ``vector_chunks`` key holds such a list. An object
    without that key yields an empty list.

    Args:
        payload_path: Path to the JSON payload file.

    Returns:
        The list of chunk dictionaries (possibly empty).

    Raises:
        ValueError: If the chunks container is not a list, or if any entry
            in it is not a JSON object.
    """
    payload = json.loads(payload_path.read_text(encoding="utf-8"))
    chunks = payload.get("vector_chunks", []) if isinstance(payload, dict) else payload
    if not isinstance(chunks, list):
        raise ValueError("vector chunk payload must be a list or a dict containing vector_chunks")

    # Reject malformed entries here with a clear message instead of letting a
    # later chunk["..."] lookup fail with an opaque TypeError.
    for position, chunk in enumerate(chunks):
        if not isinstance(chunk, dict):
            raise ValueError(
                f"vector chunk at index {position} must be a JSON object, "
                f"got {type(chunk).__name__}"
            )
    return chunks
|
||
|
|
|
||
|
|
|
||
|
|
def _truncate_utf8(value: object, max_bytes: int = 65535) -> str:
    """Return ``str(value)`` limited to *max_bytes* bytes of UTF-8.

    The limit is applied to the encoded size rather than the character
    count, so multi-byte text cannot exceed a byte-based column limit.
    """
    text = str(value)
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    # The cut may land inside a multi-byte character; "ignore" drops the
    # dangling partial sequence instead of raising.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


def _build_columns(chunks: list[dict], dim: int) -> list[list]:
    """Convert chunk dicts into column lists, one list per schema field.

    Raises:
        KeyError: If a chunk lacks ``chunk_id``, ``doc_id``, ``doc_title``
            or ``embedding``.
        ValueError: If an embedding's length differs from *dim*.
    """
    for position, chunk in enumerate(chunks):
        vector = chunk["embedding"]
        if len(vector) != dim:
            raise ValueError(
                f"chunk {position} ({chunk.get('chunk_id')!r}) has an embedding of "
                f"length {len(vector)}, expected {dim}"
            )

    def as_int(chunk: dict, key: str) -> int:
        # Missing, None and other falsy values all collapse to 0.
        return int(chunk.get(key, 0) or 0)

    # Column order mirrors the field order declared by build_collection().
    return [
        # NOTE(review): the primary key "id" reuses chunk_id; if several pieces
        # can share one chunk_id the keys will not be unique -- confirm upstream.
        [chunk["chunk_id"] for chunk in chunks],
        [chunk["doc_id"] for chunk in chunks],
        [chunk["doc_title"] for chunk in chunks],
        [chunk["chunk_id"] for chunk in chunks],
        [as_int(chunk, "chunk_index") for chunk in chunks],
        [as_int(chunk, "piece_index") for chunk in chunks],
        [_truncate_utf8(chunk.get("text", "")) for chunk in chunks],
        [_truncate_utf8(chunk.get("embedding_text", chunk.get("text", ""))) for chunk in chunks],
        [chunk["embedding"] for chunk in chunks],
        [str(chunk.get("semantic_id", "")) for chunk in chunks],
        [str(chunk.get("chunk_type", "")) for chunk in chunks],
        [as_int(chunk, "page_start") for chunk in chunks],
        [as_int(chunk, "page_end") for chunk in chunks],
        [as_int(chunk, "section_level") for chunk in chunks],
        [json.dumps(chunk.get("source_ids", []), ensure_ascii=False) for chunk in chunks],
        [json.dumps(chunk.get("section_path", []), ensure_ascii=False) for chunk in chunks],
        [str(chunk.get("section_title", "")) for chunk in chunks],
        # The whole chunk is serialized as-is; it is not truncated because a
        # cut JSON document would no longer parse.
        [json.dumps(chunk, ensure_ascii=False) for chunk in chunks],
        [as_int(chunk, "created_at") for chunk in chunks],
    ]


def main() -> None:
    """Rebuild the target collection from a vector chunk payload.

    Command-line options: ``--host``, ``--port``, ``--collection``, ``--dim``
    and the required ``--payload`` (path to the JSON chunk file).

    The payload is read and converted before Milvus is contacted, because
    build_collection() drops any existing collection with the same name: a
    bad path, invalid JSON or malformed chunk must fail while the previous
    collection is still intact.
    """
    parser = argparse.ArgumentParser(description="Rebuild the migrated Milvus collection.")
    parser.add_argument("--host", default="127.0.0.1", help="Milvus host")
    parser.add_argument("--port", default="19530", help="Milvus port")
    parser.add_argument("--collection", default=DEFAULT_COLLECTION, help="Milvus collection name")
    parser.add_argument("--dim", type=int, default=DEFAULT_DIM, help="Embedding dimension")
    parser.add_argument("--payload", required=True, help="Path to vector_chunks.json or a compatible JSON file")
    args = parser.parse_args()

    # Everything that can fail on bad input happens before the destructive step.
    chunks = load_chunks(Path(args.payload))
    data = _build_columns(chunks, args.dim) if chunks else []

    connections.connect("default", host=args.host, port=args.port)
    collection = build_collection(args.collection, args.dim)
    if not chunks:
        print("No vector chunks found; collection was created but remains empty.")
        return

    collection.insert(data)
    collection.flush()
    collection.load()
    print(f"Rebuilt collection {args.collection} with {len(chunks)} chunks.")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the rebuild only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|