Files
AIRegulation-DocAnalysis/aliyun_parser/upload_to_milvus.py
wangwei dcda7e0423 @
chore: delete old layout/common/tabs components before redesign
@
2026-06-03 16:58:35 +08:00

327 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
将 vector_chunks.json 向量化并上传到 Milvus 和 PostgreSQL
使用中转站的 OpenAI 兼容 API
"""
import argparse
import json
import os
import time
from pathlib import Path
from typing import List, Dict

import psycopg2
from psycopg2.extras import execute_values
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
from openai import OpenAI
# ===================== Configuration =====================
# Every setting below can be overridden by an environment variable with the
# same name; the literal is only the fallback used when that variable is unset.
# NOTE(review): the fallbacks include what appears to be a real relay API key
# and a database password committed to source control. Rotate both and remove
# the literals so secrets come only from the environment (or the CLI flags).

# Relay (OpenAI-compatible API) settings
RELAY_BASE_URL = os.getenv("RELAY_BASE_URL", "http://6.86.80.4:30080/v1")
RELAY_API_KEY = os.getenv("RELAY_API_KEY", "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-v3")  # embedding model served by the relay

# Milvus settings
MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost")
MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")  # kept as a string, as before
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "regulation_chunks")

# PostgreSQL settings
PG_HOST = os.getenv("PG_HOST", "6.86.80.10")
PG_PORT = int(os.getenv("PG_PORT", "5432"))  # environment values are strings; keep the int type
PG_USER = os.getenv("PG_USER", "postgresql")
PG_PASSWORD = os.getenv("PG_PASSWORD", "postgresql123456")
PG_DATABASE = os.getenv("PG_DATABASE", "postgres")
# ===================== Embedding =====================
def get_openai_client(api_key: str, base_url: str) -> OpenAI:
    """Return an OpenAI SDK client that sends its requests to the relay.

    Args:
        api_key: key presented to the relay.
        base_url: relay endpoint (an OpenAI-compatible ``/v1`` URL).

    Returns:
        A configured ``OpenAI`` client instance.
    """
    relay_client = OpenAI(base_url=base_url, api_key=api_key)
    return relay_client
def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]:
    """Embed ``texts`` through the relay in batches, preserving input order.

    Args:
        client: OpenAI-compatible client used for ``embeddings.create``.
        texts: strings to embed. An empty list returns ``[]`` without any request.
        batch_size: maximum number of texts per request. The CLI help notes the
            relay caps this at 10.

    Returns:
        One embedding vector per input text, positionally aligned with ``texts``.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if a response does not
            contain exactly one vector per text sent.
    """
    if batch_size <= 0:
        # range() would either raise an opaque error (0) or silently yield
        # nothing (negative), returning no vectors at all.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    total_batches = (len(texts) + batch_size - 1) // batch_size
    all_embeddings: List[List[float]] = []
    for batch_no, start in enumerate(range(0, len(texts), batch_size), start=1):
        batch = texts[start:start + batch_size]
        print(f"Embedding batch {batch_no}/{total_batches}...")
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=batch,
        )
        items = list(response.data)
        if len(items) != len(batch):
            raise ValueError(
                f"embedding service returned {len(items)} vectors for {len(batch)} texts"
            )
        # Each embedding object carries the position of its input text. Order by
        # it (when present) instead of trusting response order, so every vector
        # stays aligned with its chunk.
        if all(getattr(item, "index", None) is not None for item in items):
            items.sort(key=lambda item: item.index)
        all_embeddings.extend(item.embedding for item in items)
    return all_embeddings
# ===================== Milvus =====================
def init_milvus(host: str, port: str):
    """Register the pymilvus ``"default"`` connection alias for ``host:port``.

    The other Milvus helpers in this module pass no ``using=`` argument, so they
    rely on this default alias being registered first.
    """
    alias = "default"
    connections.connect(alias, host=host, port=port)
    print(f"已连接 Milvus: {host}:{port}")
def create_collection(name: str, dim: int) -> Collection:
    """Create a fresh chunk collection named ``name``, replacing any existing one.

    An existing collection with the same name is always dropped first, so every
    vector previously stored under ``name`` is discarded.

    Args:
        name: Milvus collection name.
        dim: dimensionality of the ``embedding`` float-vector field.

    Returns:
        The new ``Collection``, with an IVF_FLAT / COSINE index on ``embedding``.
    """
    if utility.has_collection(name):
        print(f"Collection '{name}' 已存在,删除重建")
        utility.drop_collection(name)

    varchar = DataType.VARCHAR
    int64 = DataType.INT64
    schema = CollectionSchema(
        [
            FieldSchema(name="chunk_id", dtype=varchar, max_length=64, is_primary=True),
            FieldSchema(name="doc_id", dtype=varchar, max_length=128),
            FieldSchema(name="doc_title", dtype=varchar, max_length=512),
            FieldSchema(name="chunk_index", dtype=int64),
            FieldSchema(name="semantic_id", dtype=varchar, max_length=64),
            FieldSchema(name="chunk_type", dtype=varchar, max_length=32),
            FieldSchema(name="page_start", dtype=int64),
            FieldSchema(name="page_end", dtype=int64),
            FieldSchema(name="section_title", dtype=varchar, max_length=512),
            FieldSchema(name="text", dtype=varchar, max_length=2048),
            # list of source ids, stored as a JSON-encoded string
            FieldSchema(name="source_ids", dtype=varchar, max_length=4096),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ],
        description="法规文档检索 chunks",
    )
    collection = Collection(name, schema)

    # IVF_FLAT with 128 clusters; the author's note says it suits small-to-medium data sets.
    collection.create_index(
        "embedding",
        {"metric_type": "COSINE", "index_type": "IVF_FLAT", "params": {"nlist": 128}},
    )
    print(f"Collection '{name}' 创建完成,索引已建立")
    return collection
def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]]):
    """Insert ``chunks`` and their vectors into ``collection``, then flush.

    The column lists are in the same order as the fields declared in
    ``create_collection``; keep the two in sync.

    Args:
        collection: target Milvus collection.
        chunks: chunk dicts. ``section_title`` may be absent or None.
        embeddings: one vector per chunk, aligned by position with ``chunks``.

    Raises:
        ValueError: if ``chunks`` and ``embeddings`` differ in length.
    """
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"chunks ({len(chunks)}) 与 embeddings ({len(embeddings)}) 数量不一致"
        )
    data = [
        [c["chunk_id"] for c in chunks],
        [c["doc_id"] for c in chunks],
        [c["doc_title"] for c in chunks],
        [c["chunk_index"] for c in chunks],
        [c["semantic_id"] for c in chunks],
        [c["chunk_type"] for c in chunks],
        [c["page_start"] for c in chunks],
        [c["page_end"] for c in chunks],
        # Optional key: the PostgreSQL path reads it with .get(). The VARCHAR
        # field is not declared nullable, so a missing/None title becomes "".
        [c.get("section_title") or "" for c in chunks],
        [c["text"] for c in chunks],
        # source_ids is stored as a JSON-encoded string
        [json.dumps(c.get("source_ids", [])) for c in chunks],
        embeddings,
    ]
    collection.insert(data)
    collection.flush()
    print(f"已插入 {len(chunks)} 个 chunks")
def load_collection(collection: Collection):
    """Load ``collection`` into Milvus memory; this must happen before searching it."""
    collection.load()
    print("Collection 已加载到内存")
# ===================== PostgreSQL =====================
def get_pg_connection(host: str, port: int, user: str, password: str, database: str):
    """Open a psycopg2 connection to ``database`` on ``host:port`` and return it.

    The caller owns the returned connection and is responsible for closing it.
    """
    connect_kwargs = {
        "host": host,
        "port": port,
        "user": user,
        "password": password,
        "database": database,
    }
    connection = psycopg2.connect(**connect_kwargs)
    print(f"已连接 PostgreSQL: {host}:{port}/{database}")
    return connection
def _upsert_document(cursor, doc_data: Dict):
    """Insert or refresh the ``documents`` row described by ``doc_data``."""
    # COALESCE keeps a previously stored standard number when this upload has none.
    cursor.execute(
        """
        INSERT INTO documents (doc_id, title, standard_number, upload_time)
        VALUES (%s, %s, %s, NOW())
        ON CONFLICT (doc_id) DO UPDATE SET
            title = EXCLUDED.title,
            standard_number = COALESCE(EXCLUDED.standard_number, documents.standard_number),
            updated_at = NOW()
        """,
        (doc_data["doc_id"], doc_data["doc_title"], doc_data.get("standard_number")),
    )


def _upsert_semantic_blocks(cursor, doc_id, semantic_blocks: List[Dict]):
    """Upsert ``semantic_blocks`` rows for ``doc_id``, refreshing every column on conflict."""
    block_rows = [
        (
            doc_id,
            block["semantic_id"],
            block["block_type"],
            block["page_start"],
            block["page_end"],
            block.get("section_title"),
            block.get("section_level"),
            json.dumps(block.get("source_ids", [])),
            block["text"],
        )
        for block in semantic_blocks
    ]
    execute_values(
        cursor,
        """
        INSERT INTO semantic_blocks
            (doc_id, semantic_id, block_type, page_start, page_end, section_title, section_level, source_ids, text)
        VALUES %s
        ON CONFLICT (doc_id, semantic_id) DO UPDATE SET
            block_type = EXCLUDED.block_type,
            page_start = EXCLUDED.page_start,
            page_end = EXCLUDED.page_end,
            section_title = EXCLUDED.section_title,
            section_level = EXCLUDED.section_level,
            source_ids = EXCLUDED.source_ids,
            text = EXCLUDED.text
        """,
        block_rows,
    )


def _upsert_vector_chunks(cursor, doc_id, chunks: List[Dict]):
    """Upsert ``vector_chunks`` metadata rows for ``doc_id``, refreshing every column on conflict."""
    chunk_rows = [
        (
            doc_id,
            chunk["chunk_id"],
            chunk["semantic_id"],
            chunk["chunk_index"],
            chunk.get("piece_index"),
            chunk["page_start"],
            chunk["page_end"],
            chunk.get("section_title"),
            chunk["text"],
            json.dumps(chunk.get("source_ids", [])),
        )
        for chunk in chunks
    ]
    execute_values(
        cursor,
        """
        INSERT INTO vector_chunks
            (doc_id, chunk_id, semantic_id, chunk_index, piece_index, page_start, page_end, section_title, text, source_ids)
        VALUES %s
        ON CONFLICT (doc_id, chunk_id) DO UPDATE SET
            semantic_id = EXCLUDED.semantic_id,
            chunk_index = EXCLUDED.chunk_index,
            piece_index = EXCLUDED.piece_index,
            page_start = EXCLUDED.page_start,
            page_end = EXCLUDED.page_end,
            section_title = EXCLUDED.section_title,
            text = EXCLUDED.text,
            source_ids = EXCLUDED.source_ids
        """,
        chunk_rows,
    )


def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
    """Upsert the document row, its semantic blocks and its vector chunks in one transaction.

    Re-running for the same document refreshes every stored column, not just the
    text. This matches the Milvus side, which is rebuilt from scratch on each upload.

    Args:
        conn: open psycopg2 connection. It is committed on success and rolled
            back on any error; it is NOT closed here.
        chunks: vector-chunk dicts (the ``vector_chunks`` list of the input JSON).
        doc_data: the full parsed input JSON. Reads ``doc_id`` and ``doc_title``,
            plus the optional ``standard_number`` and ``semantic_blocks``.

    Raises:
        Whatever the database or a missing key raises, re-raised after rollback.
    """
    with conn.cursor() as cursor:
        try:
            _upsert_document(cursor, doc_data)

            semantic_blocks = doc_data.get("semantic_blocks", [])
            if semantic_blocks:
                _upsert_semantic_blocks(cursor, doc_data["doc_id"], semantic_blocks)
                print(f"已插入 {len(semantic_blocks)} 个语义块")

            _upsert_vector_chunks(cursor, doc_data["doc_id"], chunks)
            print(f"已插入 {len(chunks)} 个向量块元数据")

            conn.commit()
            print("PostgreSQL 数据插入完成")
        except Exception:
            # Undo the partial upsert so the tables never hold half a document.
            conn.rollback()
            raise
# ===================== Main flow =====================
def load_data(file_path: Path) -> Dict:
    """Decode the UTF-8 JSON file at ``file_path`` and return the whole object.

    The full document is returned (not only ``vector_chunks``) because the
    PostgreSQL step also reads document-level fields from it.
    """
    with file_path.open(encoding="utf-8") as fh:
        return json.load(fh)
def upload_to_milvus_and_pg(
    chunks_file: str,
    api_key: str,
    base_url: str,
    milvus_host: str,
    milvus_port: str,
    collection_name: str,
    batch_size: int,
    pg_host: str,
    pg_port: int,
    pg_user: str,
    pg_password: str,
    pg_database: str,
):
    """Embed every chunk in ``chunks_file`` and store it in Milvus and PostgreSQL.

    Steps:
    1. Load the JSON file.
    2. Connect to the relay, Milvus and PostgreSQL.
    3. Embed each chunk's ``embedding_text``.
    4. Drop and recreate the Milvus collection, then insert and load it.
    5. Upsert the metadata into PostgreSQL.

    Raises:
        FileNotFoundError: ``chunks_file`` does not exist.
        ValueError: the file has no ``vector_chunks``, or the embedding service
            returned a different number of vectors than there are chunks.
    """
    # 1. Load the full input document
    chunks_path = Path(chunks_file).expanduser().resolve()
    if not chunks_path.exists():
        raise FileNotFoundError(f"文件不存在: {chunks_path}")
    data = load_data(chunks_path)
    chunks = data.get("vector_chunks", [])
    if not chunks:
        raise ValueError("vector_chunks 为空")
    print(f"加载 {len(chunks)} 个 chunks")

    # 2. Open connections. PostgreSQL is connected up front so bad credentials
    #    fail before the slow embedding calls; the finally below closes it on
    #    every path, including failures in later steps.
    client = get_openai_client(api_key, base_url)
    init_milvus(milvus_host, milvus_port)
    pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database)
    try:
        # 3. Embed
        texts = [c["embedding_text"] for c in chunks]
        embeddings = get_embeddings_batch(client, texts, batch_size)
        print(f"生成 {len(embeddings)} 个向量")

        # Validate before create_collection drops the existing collection,
        # so a bad embedding run cannot wipe the current index.
        if len(embeddings) != len(chunks):
            raise ValueError(
                f"向量数量 {len(embeddings)} 与 chunks 数量 {len(chunks)} 不一致"
            )

        # 4. Vector dimension comes from the model output
        embedding_dim = len(embeddings[0])
        print(f"Embedding 维度: {embedding_dim}")

        # 5. (Re)create the Milvus collection and insert
        collection = create_collection(collection_name, embedding_dim)
        insert_chunks(collection, chunks, embeddings)
        load_collection(collection)

        # 6. PostgreSQL metadata
        insert_chunks_to_pg(pg_conn, chunks, data)
    finally:
        # 7. Always release the PostgreSQL connection
        pg_conn.close()
    print("上传完成!")
# ===================== CLI =====================
def main():
    """Command-line entry point: parse the options, then run ``upload_to_milvus_and_pg``.

    Options default to the module-level settings. ``--batch-size`` is the
    exception and defaults to 10.
    """
    cli = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL")
    cli.add_argument("chunks_file", help="vector_chunks.json 文件路径")

    # (flag, default, help text, type converter or None). The order here is the
    # order the options appear in --help.
    option_table = (
        ("--api-key", RELAY_API_KEY, "中转站 API Key", None),
        ("--base-url", RELAY_BASE_URL, "中转站 Base URL", None),
        ("--milvus-host", MILVUS_HOST, "Milvus host", None),
        ("--milvus-port", MILVUS_PORT, "Milvus port", None),
        ("--collection", COLLECTION_NAME, "Milvus collection 名称", None),
        ("--batch-size", 10, "Embedding 批量大小中转站限制最大10", int),
        ("--pg-host", PG_HOST, "PostgreSQL host", None),
        ("--pg-port", PG_PORT, "PostgreSQL port", int),
        ("--pg-user", PG_USER, "PostgreSQL user", None),
        ("--pg-password", PG_PASSWORD, "PostgreSQL password", None),
        ("--pg-database", PG_DATABASE, "PostgreSQL database", None),
    )
    for flag, default, help_text, converter in option_table:
        extra = {} if converter is None else {"type": converter}
        cli.add_argument(flag, default=default, help=help_text, **extra)

    opts = cli.parse_args()
    upload_to_milvus_and_pg(
        chunks_file=opts.chunks_file,
        api_key=opts.api_key,
        base_url=opts.base_url,
        milvus_host=opts.milvus_host,
        milvus_port=opts.milvus_port,
        collection_name=opts.collection,
        batch_size=opts.batch_size,
        pg_host=opts.pg_host,
        pg_port=opts.pg_port,
        pg_user=opts.pg_user,
        pg_password=opts.pg_password,
        pg_database=opts.pg_database,
    )


if __name__ == "__main__":
    main()