327 lines
11 KiB
Python
327 lines
11 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
将 vector_chunks.json 向量化并上传到 Milvus 和 PostgreSQL
|
|||
|
|
使用中转站的 OpenAI 兼容 API
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import argparse
import json
import os
import time
from pathlib import Path
from typing import List, Dict

from openai import OpenAI
import psycopg2
from psycopg2.extras import execute_values
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
|
|||
|
|
|
|||
|
|
# ===================== 配置 =====================
|
|||
|
|
# Relay (OpenAI-compatible gateway) settings.
# Every setting below can be overridden with an environment variable of the
# same name; the literals are only fallbacks that keep existing invocations
# working unchanged.
# NOTE(review): the fallback API key and database password are committed in
# plain text -- rotate them and supply the real values via the environment.
RELAY_BASE_URL = os.environ.get("RELAY_BASE_URL", "http://6.86.80.4:30080/v1")
RELAY_API_KEY = os.environ.get(
    "RELAY_API_KEY",
    "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY",
)
# Embedding model name sent to the relay with every embeddings request.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-v3")

# Milvus settings (the port stays a string, as in the CLI default).
MILVUS_HOST = os.environ.get("MILVUS_HOST", "localhost")
MILVUS_PORT = os.environ.get("MILVUS_PORT", "19530")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "regulation_chunks")

# PostgreSQL settings (the port is converted to int, matching the CLI type).
PG_HOST = os.environ.get("PG_HOST", "6.86.80.10")
PG_PORT = int(os.environ.get("PG_PORT", "5432"))
PG_USER = os.environ.get("PG_USER", "postgresql")
PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgresql123456")
PG_DATABASE = os.environ.get("PG_DATABASE", "postgres")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===================== Embedding =====================
|
|||
|
|
def get_openai_client(api_key: str, base_url: str) -> OpenAI:
    """Build an OpenAI SDK client that talks to the relay endpoint.

    Args:
        api_key: API key forwarded to the ``OpenAI`` constructor.
        base_url: Base URL forwarded to the ``OpenAI`` constructor.

    Returns:
        The newly constructed ``OpenAI`` client.
    """
    relay_client = OpenAI(base_url=base_url, api_key=api_key)
    return relay_client
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]:
    """Embed ``texts`` in consecutive batches and return one vector per text.

    Args:
        client: Client whose ``embeddings.create`` method is called once per
            batch with ``EMBEDDING_MODEL``.
        texts: Strings to embed; vectors are returned in the same order.
        batch_size: Maximum number of texts sent per request; must be >= 1.

    Returns:
        The embedding vectors collected from every batch, in request order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If a response carries a different number of embeddings
            than the batch that was sent.
    """
    # A zero step makes range() fail with an unrelated message and a negative
    # one silently yields no batches at all, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    total_batches = (len(texts) + batch_size - 1) // batch_size
    all_embeddings: List[List[float]] = []

    for batch_no, start in enumerate(range(0, len(texts), batch_size), start=1):
        batch = texts[start:start + batch_size]
        print(f"Embedding batch {batch_no}/{total_batches}...")

        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=batch,
        )

        # Callers pair vectors with chunks purely by position, so a short or
        # long response must fail loudly instead of shifting every later vector.
        if len(response.data) != len(batch):
            raise RuntimeError(
                f"Embedding batch {batch_no} returned {len(response.data)} "
                f"vectors for {len(batch)} inputs"
            )

        all_embeddings.extend(item.embedding for item in response.data)

    return all_embeddings
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===================== Milvus =====================
|
|||
|
|
def init_milvus(host: str, port: str):
    """Register the ``default`` pymilvus connection for the given server.

    Args:
        host: Milvus host passed to ``connections.connect``.
        port: Milvus port passed to ``connections.connect``.
    """
    endpoint = {"host": host, "port": port}
    connections.connect("default", **endpoint)
    print(f"已连接 Milvus: {host}:{port}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_collection(name: str, dim: int) -> Collection:
    """Create a fresh Milvus collection for the chunk records.

    If a collection called ``name`` already exists it is dropped first, so
    the returned collection is always newly created. An IVF_FLAT index with
    the COSINE metric is created on the ``embedding`` field before returning.

    Args:
        name: Name of the collection to (re)create.
        dim: Dimension of the ``embedding`` float-vector field.

    Returns:
        The newly created ``Collection``.
    """

    def varchar(field_name, max_length, **extra):
        # Shorthand for a VARCHAR field of the given maximum length.
        return FieldSchema(
            name=field_name, dtype=DataType.VARCHAR, max_length=max_length, **extra
        )

    def int64(field_name):
        # Shorthand for an INT64 field.
        return FieldSchema(name=field_name, dtype=DataType.INT64)

    already_exists = utility.has_collection(name)
    if already_exists:
        print(f"Collection '{name}' 已存在,删除重建")
        utility.drop_collection(name)

    # NOTE(review): confirm the VARCHAR limits below (2048 for ``text``) are
    # large enough for the longest chunk that will be inserted.
    fields = [
        varchar("chunk_id", 64, is_primary=True),
        varchar("doc_id", 128),
        varchar("doc_title", 512),
        int64("chunk_index"),
        varchar("semantic_id", 64),
        varchar("chunk_type", 32),
        int64("page_start"),
        int64("page_end"),
        varchar("section_title", 512),
        varchar("text", 2048),
        varchar("source_ids", 4096),  # stored as a JSON string
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
    ]

    collection = Collection(
        name,
        CollectionSchema(fields, description="法规文档检索 chunks"),
    )

    # Vector index: IVF_FLAT (the original author's note: suited to small and
    # medium data sets).
    index_params = dict(
        metric_type="COSINE",
        index_type="IVF_FLAT",
        params={"nlist": 128},
    )
    collection.create_index("embedding", index_params)
    print(f"Collection '{name}' 创建完成,索引已建立")

    return collection
|
|||
|
|
|
|||
|
|
|
|||
|
|
def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]]):
    """Insert the chunk records and their vectors into Milvus, then flush.

    The data is sent column-wise: one list per scalar field, then the
    JSON-encoded ``source_ids`` column, then ``embeddings`` as the last column.

    Args:
        collection: Target collection; ``insert`` and ``flush`` are called on it.
        chunks: Chunk dicts; every scalar field listed below is read with
            ``chunk[field]``, ``source_ids`` is optional and defaults to ``[]``.
        embeddings: Vectors passed through unchanged as the final column.
    """
    scalar_fields = (
        "chunk_id",
        "doc_id",
        "doc_title",
        "chunk_index",
        "semantic_id",
        "chunk_type",
        "page_start",
        "page_end",
        "section_title",
        "text",
    )

    columns = [[chunk[field] for chunk in chunks] for field in scalar_fields]
    # source_ids is serialised to a JSON string per chunk.
    columns.append([json.dumps(chunk.get("source_ids", [])) for chunk in chunks])
    columns.append(embeddings)

    collection.insert(columns)
    collection.flush()
    print(f"已插入 {len(chunks)} 个 chunks")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_collection(collection: Collection):
    """Call ``collection.load()`` and report it.

    The original author's note says this is required before searching.

    Args:
        collection: Collection to load.
    """
    collection.load()
    print("Collection 已加载到内存")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===================== PostgreSQL =====================
|
|||
|
|
def get_pg_connection(host: str, port: int, user: str, password: str, database: str):
    """Open a PostgreSQL connection with ``psycopg2.connect``.

    Args:
        host: Server host.
        port: Server port.
        user: Login user.
        password: Login password.
        database: Database name.

    Returns:
        The connection object returned by ``psycopg2.connect``.
    """
    connect_kwargs = dict(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
    )
    pg_conn = psycopg2.connect(**connect_kwargs)
    print(f"已连接 PostgreSQL: {host}:{port}/{database}")
    return pg_conn
|
|||
|
|
|
|||
|
|
|
|||
|
|
def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
    """Upsert the document, its semantic blocks and chunk metadata.

    All statements run on one cursor and are committed together; on any
    error the transaction is rolled back and the error is re-raised.

    On conflict the ``semantic_blocks`` and ``vector_chunks`` rows now have
    every inserted non-key column refreshed (previously only ``text`` was),
    so re-uploading a re-chunked document does not leave stale page ranges,
    section titles or indexes behind. The ``documents`` upsert is unchanged.

    Args:
        conn: Open database connection providing ``cursor``, ``commit`` and
            ``rollback``.
        chunks: Chunk dicts written to ``vector_chunks``.
        doc_data: Full document payload; ``doc_id`` and ``doc_title`` are
            required, ``standard_number`` and ``semantic_blocks`` are optional.

    Raises:
        Exception: Whatever the statements raise, re-raised unchanged after
            the rollback.
    """
    cursor = conn.cursor()

    try:
        _upsert_document(cursor, doc_data)
        _upsert_semantic_blocks(cursor, doc_data)
        _upsert_vector_chunks(cursor, doc_data["doc_id"], chunks)

        conn.commit()
        print("PostgreSQL 数据插入完成")

    except Exception:
        conn.rollback()
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        cursor.close()


def _upsert_document(cursor, doc_data: Dict) -> None:
    """Insert the document row, or refresh its title if it already exists."""
    cursor.execute(
        """
        INSERT INTO documents (doc_id, title, standard_number, upload_time)
        VALUES (%s, %s, %s, NOW())
        ON CONFLICT (doc_id) DO UPDATE SET title = EXCLUDED.title, updated_at = NOW()
        """,
        (doc_data["doc_id"], doc_data["doc_title"], doc_data.get("standard_number")),
    )


def _upsert_semantic_blocks(cursor, doc_data: Dict) -> None:
    """Upsert the optional ``semantic_blocks`` list; do nothing when it is empty."""
    semantic_blocks = doc_data.get("semantic_blocks", [])
    if not semantic_blocks:
        return

    block_rows = [
        (
            doc_data["doc_id"],
            block["semantic_id"],
            block["block_type"],
            block["page_start"],
            block["page_end"],
            block.get("section_title"),
            block.get("section_level"),
            json.dumps(block.get("source_ids", [])),
            block["text"],
        )
        for block in semantic_blocks
    ]
    execute_values(
        cursor,
        """
        INSERT INTO semantic_blocks
        (doc_id, semantic_id, block_type, page_start, page_end, section_title, section_level, source_ids, text)
        VALUES %s
        ON CONFLICT (doc_id, semantic_id) DO UPDATE SET
            block_type = EXCLUDED.block_type,
            page_start = EXCLUDED.page_start,
            page_end = EXCLUDED.page_end,
            section_title = EXCLUDED.section_title,
            section_level = EXCLUDED.section_level,
            source_ids = EXCLUDED.source_ids,
            text = EXCLUDED.text
        """,
        block_rows,
    )
    print(f"已插入 {len(semantic_blocks)} 个语义块")


def _upsert_vector_chunks(cursor, doc_id, chunks: List[Dict]) -> None:
    """Upsert one ``vector_chunks`` metadata row per chunk."""
    chunk_rows = [
        (
            doc_id,
            chunk["chunk_id"],
            chunk["semantic_id"],
            chunk["chunk_index"],
            chunk.get("piece_index"),
            chunk["page_start"],
            chunk["page_end"],
            chunk.get("section_title"),
            chunk["text"],
            json.dumps(chunk.get("source_ids", [])),
        )
        for chunk in chunks
    ]
    execute_values(
        cursor,
        """
        INSERT INTO vector_chunks
        (doc_id, chunk_id, semantic_id, chunk_index, piece_index, page_start, page_end, section_title, text, source_ids)
        VALUES %s
        ON CONFLICT (doc_id, chunk_id) DO UPDATE SET
            semantic_id = EXCLUDED.semantic_id,
            chunk_index = EXCLUDED.chunk_index,
            piece_index = EXCLUDED.piece_index,
            page_start = EXCLUDED.page_start,
            page_end = EXCLUDED.page_end,
            section_title = EXCLUDED.section_title,
            text = EXCLUDED.text,
            source_ids = EXCLUDED.source_ids
        """,
        chunk_rows,
    )
    print(f"已插入 {len(chunks)} 个向量块元数据")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===================== 主流程 =====================
|
|||
|
|
def load_data(file_path: Path) -> Dict:
    """Read a UTF-8 JSON file and return the parsed content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The value produced by parsing the file's JSON.
    """
    with file_path.open(encoding="utf-8") as handle:
        return json.load(handle)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def upload_to_milvus_and_pg(
    chunks_file: str,
    api_key: str,
    base_url: str,
    milvus_host: str,
    milvus_port: str,
    collection_name: str,
    batch_size: int,
    pg_host: str,
    pg_port: int,
    pg_user: str,
    pg_password: str,
    pg_database: str,
):
    """Embed the chunks of a JSON file and store them in Milvus and PostgreSQL.

    Steps: load the file, connect to the relay, Milvus and PostgreSQL, embed
    every chunk's ``embedding_text``, recreate the Milvus collection and
    insert the vectors, then write the metadata to PostgreSQL.

    Args:
        chunks_file: Path of the JSON file; ``~`` is expanded.
        api_key: Relay API key.
        base_url: Relay base URL.
        milvus_host: Milvus host.
        milvus_port: Milvus port.
        collection_name: Milvus collection to (re)create.
        batch_size: Number of texts per embedding request.
        pg_host: PostgreSQL host.
        pg_port: PostgreSQL port.
        pg_user: PostgreSQL user.
        pg_password: PostgreSQL password.
        pg_database: PostgreSQL database name.

    Raises:
        FileNotFoundError: If ``chunks_file`` does not exist.
        ValueError: If the file has no ``vector_chunks`` or the number of
            embeddings differs from the number of chunks.
    """
    # 1. Load the full payload.
    chunks_path = Path(chunks_file).expanduser().resolve()
    if not chunks_path.exists():
        raise FileNotFoundError(f"文件不存在: {chunks_path}")

    data = load_data(chunks_path)
    chunks = data.get("vector_chunks", [])
    if not chunks:
        raise ValueError("vector_chunks 为空")
    print(f"加载 {len(chunks)} 个 chunks")

    # 2. Open all connections before starting the embedding requests.
    client = get_openai_client(api_key, base_url)
    init_milvus(milvus_host, milvus_port)
    pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database)

    # From here on the PostgreSQL connection must be released even if the
    # embedding, Milvus or PostgreSQL work raises.
    try:
        # 3. Embed every chunk.
        texts = [c["embedding_text"] for c in chunks]
        embeddings = get_embeddings_batch(client, texts, batch_size)
        print(f"生成 {len(embeddings)} 个向量")

        # Vectors are matched to chunks by position, so the counts must agree
        # before the existing collection is dropped and rebuilt.
        if len(embeddings) != len(chunks):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(chunks)} chunks"
            )

        # 4. Derive the vector dimension from the first embedding.
        embedding_dim = len(embeddings[0])
        print(f"Embedding 维度: {embedding_dim}")

        # 5. Recreate the collection and insert into Milvus.
        collection = create_collection(collection_name, embedding_dim)
        insert_chunks(collection, chunks, embeddings)
        load_collection(collection)

        # 6. Write the metadata to PostgreSQL.
        insert_chunks_to_pg(pg_conn, chunks, data)
    finally:
        # 7. Always close the PostgreSQL connection.
        pg_conn.close()

    print("上传完成!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===================== CLI =====================
|
|||
|
|
def main():
    """Parse the command line and run the upload.

    Each option's destination name equals the matching keyword parameter of
    ``upload_to_milvus_and_pg``, so the parsed namespace is forwarded as-is.
    """
    parser = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL")
    add = parser.add_argument

    add("chunks_file", help="vector_chunks.json 文件路径")
    add("--api-key", default=RELAY_API_KEY, help="中转站 API Key")
    add("--base-url", default=RELAY_BASE_URL, help="中转站 Base URL")
    add("--milvus-host", default=MILVUS_HOST, help="Milvus host")
    add("--milvus-port", default=MILVUS_PORT, help="Milvus port")
    # Stored as ``collection_name`` so it lines up with the upload function.
    add("--collection", dest="collection_name", default=COLLECTION_NAME, help="Milvus collection 名称")
    add("--batch-size", type=int, default=10, help="Embedding 批量大小(中转站限制最大10)")
    add("--pg-host", default=PG_HOST, help="PostgreSQL host")
    add("--pg-port", type=int, default=PG_PORT, help="PostgreSQL port")
    add("--pg-user", default=PG_USER, help="PostgreSQL user")
    add("--pg-password", default=PG_PASSWORD, help="PostgreSQL password")
    add("--pg-database", default=PG_DATABASE, help="PostgreSQL database")

    options = vars(parser.parse_args())
    upload_to_milvus_and_pg(**options)


if __name__ == "__main__":
    main()
|