#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Embed the chunks in vector_chunks.json and upload them to Milvus and PostgreSQL.

Embeddings are requested from an OpenAI-compatible relay API.
"""
import argparse
import json
import os
import time
from pathlib import Path
from typing import List, Dict

import psycopg2
from psycopg2.extras import execute_values
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
from openai import OpenAI

# ===================== Configuration =====================
# Connection settings and credentials may be overridden by environment variables of
# the same name; the CLI flags defined in main() override both.
# NOTE(review): a live API key and a database password are committed below as
# fallback values. They should be rotated and removed from source control; supply
# them through RELAY_API_KEY / PG_PASSWORD (or --api-key / --pg-password) instead.

# Relay (OpenAI-compatible API) settings
RELAY_BASE_URL = os.environ.get("RELAY_BASE_URL", "http://6.86.80.4:30080/v1")
RELAY_API_KEY = os.environ.get(
    "RELAY_API_KEY", "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY"
)
EMBEDDING_MODEL = "text-embedding-v3"  # embedding model exposed by the relay

# Milvus settings
MILVUS_HOST = os.environ.get("MILVUS_HOST", "localhost")
MILVUS_PORT = os.environ.get("MILVUS_PORT", "19530")
COLLECTION_NAME = "regulation_chunks"

# PostgreSQL settings
PG_HOST = os.environ.get("PG_HOST", "6.86.80.10")
PG_PORT = int(os.environ.get("PG_PORT", "5432"))
PG_USER = os.environ.get("PG_USER", "postgresql")
PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgresql123456")
PG_DATABASE = os.environ.get("PG_DATABASE", "postgres")

# Number of rows sent to Milvus per insert request.
MILVUS_INSERT_BATCH_SIZE = 1000


# ===================== Embedding =====================

def get_openai_client(api_key: str, base_url: str) -> OpenAI:
    """Create an OpenAI client pointed at the relay endpoint *base_url*."""
    return OpenAI(api_key=api_key, base_url=base_url)


def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]:
    """Embed *texts* with EMBEDDING_MODEL, sending *batch_size* texts per request.

    Returns:
        One embedding vector per input text, in the same order as *texts*.

    Raises:
        ValueError: if *batch_size* is not positive, or a response contains a
            different number of vectors than texts were sent.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_embeddings: List[List[float]] = []
    total_batches = (len(texts) - 1) // batch_size + 1
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        print(f"Embedding batch {i // batch_size + 1}/{total_batches}...")
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=batch,
        )
        items = list(response.data)
        # Each item reports its position in the request; order by it instead of
        # trusting the relay to preserve order, so a vector can never be paired
        # with the wrong chunk.
        if all(getattr(item, "index", None) is not None for item in items):
            items.sort(key=lambda item: item.index)
        if len(items) != len(batch):
            raise ValueError(
                f"Embedding API returned {len(items)} vectors for {len(batch)} texts "
                f"(batch starting at text {i})"
            )
        all_embeddings.extend(item.embedding for item in items)
    return all_embeddings


# ===================== Milvus =====================

def init_milvus(host: str, port: str):
    """Open the "default" pymilvus connection to host:port."""
    connections.connect("default", host=host, port=port)
    print(f"已连接 Milvus: {host}:{port}")


def create_collection(name: str, dim: int) -> Collection:
    """Create collection *name* with a *dim*-dimensional COSINE / IVF_FLAT vector index.

    WARNING: if a collection called *name* already exists it is DROPPED first,
    which removes the vectors of every previously uploaded document, not only
    the one being uploaded now.
    """
    if utility.has_collection(name):
        print(f"Collection '{name}' 已存在,删除重建")
        utility.drop_collection(name)

    # NOTE(review): Milvus rejects VARCHAR values longer than max_length; confirm
    # whether the deployed version counts bytes or characters, since UTF-8 Chinese
    # text takes ~3 bytes per character.
    fields = [
        FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=64, is_primary=True),
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=128),
        FieldSchema(name="doc_title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="chunk_index", dtype=DataType.INT64),
        FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="page_start", dtype=DataType.INT64),
        FieldSchema(name="page_end", dtype=DataType.INT64),
        FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048),
        FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096),  # JSON-encoded list
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
    ]
    schema = CollectionSchema(fields, description="法规文档检索 chunks")
    collection = Collection(name, schema)

    # IVF_FLAT with 128 clusters; the original author chose it for small-to-medium data sizes.
    index_params = {
        "metric_type": "COSINE",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 128},
    }
    collection.create_index("embedding", index_params)
    print(f"Collection '{name}' 创建完成,索引已建立")
    return collection


def _milvus_columns(chunks: List[Dict], embeddings: List[List[float]]) -> list:
    """Build column-ordered insert data matching the field order in create_collection."""
    return [
        [c["chunk_id"] for c in chunks],
        [c["doc_id"] for c in chunks],
        [c["doc_title"] for c in chunks],
        [c["chunk_index"] for c in chunks],
        [c["semantic_id"] for c in chunks],
        [c["chunk_type"] for c in chunks],
        [c["page_start"] for c in chunks],
        [c["page_end"] for c in chunks],
        # section_title is optional (PostgreSQL stores it via .get); the Milvus
        # VARCHAR column cannot hold None, so a missing title becomes "".
        [c.get("section_title") or "" for c in chunks],
        [c["text"] for c in chunks],
        [json.dumps(c.get("source_ids", [])) for c in chunks],  # JSON string
        embeddings,
    ]


def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]],
                  insert_batch_size: int = MILVUS_INSERT_BATCH_SIZE):
    """Insert *chunks* with their *embeddings* into *collection*, then flush once.

    Rows are sent in slices of *insert_batch_size* to keep each request bounded
    in size.

    Raises:
        ValueError: if the chunk and embedding counts differ, or
            *insert_batch_size* is not positive.
    """
    if len(chunks) != len(embeddings):
        raise ValueError(f"chunks ({len(chunks)}) 与 embeddings ({len(embeddings)}) 数量不一致")
    if insert_batch_size <= 0:
        raise ValueError(f"insert_batch_size must be positive, got {insert_batch_size}")

    for start in range(0, len(chunks), insert_batch_size):
        stop = start + insert_batch_size
        collection.insert(_milvus_columns(chunks[start:stop], embeddings[start:stop]))
    collection.flush()
    print(f"已插入 {len(chunks)} 个 chunks")


def load_collection(collection: Collection):
    """Load *collection* into memory; Milvus requires this before searching."""
    collection.load()
    print("Collection 已加载到内存")


# ===================== PostgreSQL =====================

def get_pg_connection(host: str, port: int, user: str, password: str, database: str):
    """Open and return a psycopg2 connection to the given PostgreSQL database."""
    conn = psycopg2.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        dbname=database,  # "database" is only a deprecated alias in psycopg2
    )
    print(f"已连接 PostgreSQL: {host}:{port}/{database}")
    return conn


def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
    """Upsert the document, its semantic blocks and its vector-chunk metadata in one transaction.

    Existing rows (matched on their ON CONFLICT keys) are overwritten with the
    values from *doc_data* / *chunks*, so a re-upload does not leave stale
    columns behind. The transaction is rolled back and the error re-raised on
    any failure.
    """
    doc_id = doc_data["doc_id"]
    try:
        with conn.cursor() as cursor:
            # 1. Document row. upload_time is only set on first insert.
            cursor.execute("""
                INSERT INTO documents (doc_id, title, standard_number, upload_time)
                VALUES (%s, %s, %s, NOW())
                ON CONFLICT (doc_id) DO UPDATE SET
                    title = EXCLUDED.title,
                    standard_number = EXCLUDED.standard_number,
                    updated_at = NOW()
            """, (doc_id, doc_data["doc_title"], doc_data.get("standard_number")))

            # 2. Semantic blocks (optional in the input file).
            semantic_blocks = doc_data.get("semantic_blocks", [])
            if semantic_blocks:
                block_rows = [
                    (
                        doc_id,
                        block["semantic_id"],
                        block["block_type"],
                        block["page_start"],
                        block["page_end"],
                        block.get("section_title"),
                        block.get("section_level"),
                        json.dumps(block.get("source_ids", [])),
                        block["text"],
                    )
                    for block in semantic_blocks
                ]
                execute_values(
                    cursor,
                    """
                    INSERT INTO semantic_blocks
                        (doc_id, semantic_id, block_type, page_start, page_end,
                         section_title, section_level, source_ids, text)
                    VALUES %s
                    ON CONFLICT (doc_id, semantic_id) DO UPDATE SET
                        block_type = EXCLUDED.block_type,
                        page_start = EXCLUDED.page_start,
                        page_end = EXCLUDED.page_end,
                        section_title = EXCLUDED.section_title,
                        section_level = EXCLUDED.section_level,
                        source_ids = EXCLUDED.source_ids,
                        text = EXCLUDED.text
                    """,
                    block_rows,
                )
                print(f"已插入 {len(semantic_blocks)} 个语义块")

            # 3. Vector-chunk metadata (the vectors themselves live in Milvus).
            chunk_rows = [
                (
                    doc_id,
                    chunk["chunk_id"],
                    chunk["semantic_id"],
                    chunk["chunk_index"],
                    chunk.get("piece_index"),
                    chunk["page_start"],
                    chunk["page_end"],
                    chunk.get("section_title"),
                    chunk["text"],
                    json.dumps(chunk.get("source_ids", [])),
                )
                for chunk in chunks
            ]
            execute_values(
                cursor,
                """
                INSERT INTO vector_chunks
                    (doc_id, chunk_id, semantic_id, chunk_index, piece_index,
                     page_start, page_end, section_title, text, source_ids)
                VALUES %s
                ON CONFLICT (doc_id, chunk_id) DO UPDATE SET
                    semantic_id = EXCLUDED.semantic_id,
                    chunk_index = EXCLUDED.chunk_index,
                    piece_index = EXCLUDED.piece_index,
                    page_start = EXCLUDED.page_start,
                    page_end = EXCLUDED.page_end,
                    section_title = EXCLUDED.section_title,
                    text = EXCLUDED.text,
                    source_ids = EXCLUDED.source_ids
                """,
                chunk_rows,
            )
            print(f"已插入 {len(chunks)} 个向量块元数据")

        conn.commit()
        print("PostgreSQL 数据插入完成")
    except Exception:
        conn.rollback()
        raise


# ===================== Main flow =====================

def load_data(file_path: Path) -> Dict:
    """Parse *file_path* (vector_chunks.json) as UTF-8 JSON and return the whole document."""
    return json.loads(file_path.read_text(encoding="utf-8"))


def upload_to_milvus_and_pg(
    chunks_file: str,
    api_key: str,
    base_url: str,
    milvus_host: str,
    milvus_port: str,
    collection_name: str,
    batch_size: int,
    pg_host: str,
    pg_port: int,
    pg_user: str,
    pg_password: str,
    pg_database: str,
):
    """Embed the chunks of *chunks_file* and upload them to Milvus and PostgreSQL.

    Steps: load the JSON file, open connections, embed every chunk's
    ``embedding_text``, (re)create the Milvus collection and insert the vectors,
    then upsert document / block / chunk metadata into PostgreSQL.

    WARNING: the Milvus collection is dropped and rebuilt on every run (see
    create_collection), whereas PostgreSQL rows of other documents are kept.

    Raises:
        FileNotFoundError: if *chunks_file* does not exist.
        ValueError: if the file contains no ``vector_chunks``.
    """
    # 1. Load the input file.
    chunks_path = Path(chunks_file).expanduser().resolve()
    if not chunks_path.exists():
        raise FileNotFoundError(f"文件不存在: {chunks_path}")

    data = load_data(chunks_path)
    chunks = data.get("vector_chunks", [])
    if not chunks:
        raise ValueError("vector_chunks 为空")
    print(f"加载 {len(chunks)} 个 chunks")

    # 2. Open connections. PostgreSQL is connected before embedding so that bad
    #    credentials fail fast, before any embedding requests are made.
    client = get_openai_client(api_key, base_url)
    init_milvus(milvus_host, milvus_port)
    pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database)
    try:
        # 3. Embed every chunk.
        texts = [c["embedding_text"] for c in chunks]
        embeddings = get_embeddings_batch(client, texts, batch_size)
        print(f"生成 {len(embeddings)} 个向量")

        # 4. The collection's vector dimension comes from the model output.
        embedding_dim = len(embeddings[0])
        print(f"Embedding 维度: {embedding_dim}")

        # 5. Rebuild the Milvus collection and insert the vectors.
        collection = create_collection(collection_name, embedding_dim)
        insert_chunks(collection, chunks, embeddings)
        load_collection(collection)

        # 6. Upsert metadata into PostgreSQL.
        insert_chunks_to_pg(pg_conn, chunks, data)
    finally:
        # 7. Release the PostgreSQL connection even when an earlier step fails.
        pg_conn.close()

    print("上传完成!")


# ===================== CLI =====================

def main():
    """Parse command-line arguments and run upload_to_milvus_and_pg."""
    parser = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL")
    parser.add_argument("chunks_file", help="vector_chunks.json 文件路径")
    parser.add_argument("--api-key", default=RELAY_API_KEY, help="中转站 API Key")
    parser.add_argument("--base-url", default=RELAY_BASE_URL, help="中转站 Base URL")
    parser.add_argument("--milvus-host", default=MILVUS_HOST, help="Milvus host")
    parser.add_argument("--milvus-port", default=MILVUS_PORT, help="Milvus port")
    parser.add_argument("--collection", default=COLLECTION_NAME, help="Milvus collection 名称")
    parser.add_argument("--batch-size", type=int, default=10, help="Embedding 批量大小(中转站限制最大10)")
    parser.add_argument("--pg-host", default=PG_HOST, help="PostgreSQL host")
    parser.add_argument("--pg-port", type=int, default=PG_PORT, help="PostgreSQL port")
    parser.add_argument("--pg-user", default=PG_USER, help="PostgreSQL user")
    parser.add_argument("--pg-password", default=PG_PASSWORD, help="PostgreSQL password")
    parser.add_argument("--pg-database", default=PG_DATABASE, help="PostgreSQL database")
    args = parser.parse_args()

    upload_to_milvus_and_pg(
        chunks_file=args.chunks_file,
        api_key=args.api_key,
        base_url=args.base_url,
        milvus_host=args.milvus_host,
        milvus_port=args.milvus_port,
        collection_name=args.collection,
        batch_size=args.batch_size,
        pg_host=args.pg_host,
        pg_port=args.pg_port,
        pg_user=args.pg_user,
        pg_password=args.pg_password,
        pg_database=args.pg_database,
    )


if __name__ == "__main__":
    main()