Refactor chunk_app.py

2025-10-29 11:26:35 +08:00
parent 6f2f26be10
commit 7d0d65a0ac
2 changed files with 209 additions and 108 deletions

chunk_app.py

@@ -16,10 +16,10 @@
 import datetime
 import json
 import re
+from typing import Optional, List

 import xxhash
-from flask import request
-from flask_login import current_user, login_required
+from fastapi import APIRouter, Depends, Query, HTTPException

 from api import settings
 from api.db import LLMType, ParserType
@@ -29,7 +29,17 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
+from api.models.chunk_models import (
+    ListChunkRequest,
+    GetChunkRequest,
+    SetChunkRequest,
+    SwitchChunkRequest,
+    RemoveChunkRequest,
+    CreateChunkRequest,
+    RetrievalTestRequest,
+    KnowledgeGraphRequest
+)
+from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
@@ -37,18 +47,21 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace

+# Create the FastAPI router
+router = APIRouter()


-@manager.route('/list', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("doc_id")
-def list_chunk():
-    req = request.json
-    doc_id = req["doc_id"]
-    page = int(req.get("page", 1))
-    size = int(req.get("size", 30))
-    question = req.get("keywords", "")
+@router.post('/list')
+async def list_chunk(
+    request: ListChunkRequest,
+    current_user = Depends(get_current_user)
+):
+    doc_id = request.doc_id
+    page = request.page
+    size = request.size
+    question = request.keywords
     try:
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+        tenant_id = DocumentService.get_tenant_id(doc_id)
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
         e, doc = DocumentService.get_by_id(doc_id)
@@ -58,8 +71,8 @@ def list_chunk():
         query = {
             "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
         }
-        if "available_int" in req:
-            query["available_int"] = int(req["available_int"])
+        if request.available_int is not None:
+            query["available_int"] = int(request.available_int)
         sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
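Note on the optional-field check above: the Flask version tested key presence (`"available_int" in req`), while the Pydantic version tests `is not None`, so an explicit JSON null is now treated the same as an omitted field. That is likely fine here, but if presence semantics ever matter, Pydantic v2 can distinguish the two cases. A minimal sketch, assuming Pydantic v2 and a hypothetical model name:

from typing import Optional
from pydantic import BaseModel

class Example(BaseModel):
    # Hypothetical model mirroring ListChunkRequest's optional field
    available_int: Optional[int] = None

sent_null = Example.model_validate({"available_int": None})
omitted = Example.model_validate({})
print("available_int" in sent_null.model_fields_set)  # True: client sent an explicit null
print("available_int" in omitted.model_fields_set)    # False: field was never provided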
@@ -87,10 +100,11 @@ def list_chunk():
         return server_error_response(e)


-@manager.route('/get', methods=['GET'])  # noqa: F821
-@login_required
-def get():
-    chunk_id = request.args["chunk_id"]
+@router.get('/get')
+async def get(
+    chunk_id: str = Query(..., description="Chunk ID"),
+    current_user = Depends(get_current_user)
+):
     try:
         chunk = None
         tenants = UserTenantService.query(user_id=current_user.id)
@@ -119,42 +133,42 @@ def get():
         return server_error_response(e)


-@manager.route('/set', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("doc_id", "chunk_id", "content_with_weight")
-def set():
-    req = request.json
+@router.post('/set')
+async def set(
+    request: SetChunkRequest,
+    current_user = Depends(get_current_user)
+):
     d = {
-        "id": req["chunk_id"],
-        "content_with_weight": req["content_with_weight"]}
-    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+        "id": request.chunk_id,
+        "content_with_weight": request.content_with_weight}
+    d["content_ltks"] = rag_tokenizer.tokenize(request.content_with_weight)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    if "important_kwd" in req:
-        if not isinstance(req["important_kwd"], list):
+    if request.important_kwd is not None:
+        if not isinstance(request.important_kwd, list):
             return get_data_error_result(message="`important_kwd` should be a list")
-        d["important_kwd"] = req["important_kwd"]
-        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
-    if "question_kwd" in req:
-        if not isinstance(req["question_kwd"], list):
+        d["important_kwd"] = request.important_kwd
+        d["important_tks"] = rag_tokenizer.tokenize(" ".join(request.important_kwd))
+    if request.question_kwd is not None:
+        if not isinstance(request.question_kwd, list):
             return get_data_error_result(message="`question_kwd` should be a list")
-        d["question_kwd"] = req["question_kwd"]
-        d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
-    if "tag_kwd" in req:
-        d["tag_kwd"] = req["tag_kwd"]
-    if "tag_feas" in req:
-        d["tag_feas"] = req["tag_feas"]
-    if "available_int" in req:
-        d["available_int"] = req["available_int"]
+        d["question_kwd"] = request.question_kwd
+        d["question_tks"] = rag_tokenizer.tokenize("\n".join(request.question_kwd))
+    if request.tag_kwd is not None:
+        d["tag_kwd"] = request.tag_kwd
+    if request.tag_feas is not None:
+        d["tag_feas"] = request.tag_feas
+    if request.available_int is not None:
+        d["available_int"] = request.available_int
     try:
-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+        tenant_id = DocumentService.get_tenant_id(request.doc_id)
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-        embd_id = DocumentService.get_embd_id(req["doc_id"])
+        embd_id = DocumentService.get_embd_id(request.doc_id)
         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
-        e, doc = DocumentService.get_by_id(req["doc_id"])
+        e, doc = DocumentService.get_by_id(request.doc_id)
         if not e:
             return get_data_error_result(message="Document not found!")
@@ -162,33 +176,33 @@ def set():
             arr = [
                 t for t in re.split(
                     r"[\n\t]",
-                    req["content_with_weight"]) if len(t) > 1]
+                    request.content_with_weight) if len(t) > 1]
             q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
             d = beAdoc(d, q, a, not any(
                 [rag_tokenizer.is_chinese(t) for t in q + a]))

-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
+        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"id": request.chunk_id}, d, search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)


-@manager.route('/switch', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("chunk_ids", "available_int", "doc_id")
-def switch():
-    req = request.json
+@router.post('/switch')
+async def switch(
+    request: SwitchChunkRequest,
+    current_user = Depends(get_current_user)
+):
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
+        e, doc = DocumentService.get_by_id(request.doc_id)
         if not e:
             return get_data_error_result(message="Document not found!")
-        for cid in req["chunk_ids"]:
+        for cid in request.chunk_ids:
             if not settings.docStoreConn.update({"id": cid},
-                                                {"available_int": int(req["available_int"])},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                {"available_int": int(request.available_int)},
+                                                search.index_name(DocumentService.get_tenant_id(request.doc_id)),
                                                 doc.kb_id):
                 return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
@@ -196,21 +210,21 @@ def switch():
         return server_error_response(e)


-@manager.route('/rm', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("chunk_ids", "doc_id")
-def rm():
+@router.post('/rm')
+async def rm(
+    request: RemoveChunkRequest,
+    current_user = Depends(get_current_user)
+):
     from rag.utils.storage_factory import STORAGE_IMPL
-    req = request.json
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
+        e, doc = DocumentService.get_by_id(request.doc_id)
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                            search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+        if not settings.docStoreConn.delete({"id": request.chunk_ids},
+                                            search.index_name(DocumentService.get_tenant_id(request.doc_id)),
                                             doc.kb_id):
             return get_data_error_result(message="Chunk deleting failure")
-        deleted_chunk_ids = req["chunk_ids"]
+        deleted_chunk_ids = request.chunk_ids
         chunk_number = len(deleted_chunk_ids)
         DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
         for cid in deleted_chunk_ids:
@@ -221,32 +235,30 @@ def rm():
         return server_error_response(e)


-@manager.route('/create', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("doc_id", "content_with_weight")
-def create():
-    req = request.json
-    chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
-    d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
-         "content_with_weight": req["content_with_weight"]}
+@router.post('/create')
+async def create(
+    request: CreateChunkRequest,
+    current_user = Depends(get_current_user)
+):
+    chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
+    d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
+         "content_with_weight": request.content_with_weight}
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    d["important_kwd"] = req.get("important_kwd", [])
+    d["important_kwd"] = request.important_kwd
     if not isinstance(d["important_kwd"], list):
         return get_data_error_result(message="`important_kwd` is required to be a list")
     d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
-    d["question_kwd"] = req.get("question_kwd", [])
+    d["question_kwd"] = request.question_kwd
     if not isinstance(d["question_kwd"], list):
         return get_data_error_result(message="`question_kwd` is required to be a list")
     d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
     d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
-    if "tag_feas" in req:
-        d["tag_feas"] = req["tag_feas"]
-    if "tag_feas" in req:
-        d["tag_feas"] = req["tag_feas"]
+    if request.tag_feas is not None:
+        d["tag_feas"] = request.tag_feas
     try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
+        e, doc = DocumentService.get_by_id(request.doc_id)
         if not e:
             return get_data_error_result(message="Document not found!")
         d["kb_id"] = [doc.kb_id]
@@ -254,7 +266,7 @@ def create():
         d["title_tks"] = rag_tokenizer.tokenize(doc.name)
         d["doc_id"] = doc.id

-        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+        tenant_id = DocumentService.get_tenant_id(request.doc_id)
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
@@ -264,10 +276,10 @@ def create():
         if kb.pagerank:
             d[PAGERANK_FLD] = kb.pagerank

-        embd_id = DocumentService.get_embd_id(req["doc_id"])
+        embd_id = DocumentService.get_embd_id(request.doc_id)
         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d["question_kwd"] else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
         settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
@@ -279,29 +291,29 @@ def create():
         return server_error_response(e)


-@manager.route('/retrieval_test', methods=['POST'])  # noqa: F821
-@login_required
-@validate_request("kb_id", "question")
-def retrieval_test():
-    req = request.json
-    page = int(req.get("page", 1))
-    size = int(req.get("size", 30))
-    question = req["question"]
-    kb_ids = req["kb_id"]
+@router.post('/retrieval_test')
+async def retrieval_test(
+    request: RetrievalTestRequest,
+    current_user = Depends(get_current_user)
+):
+    page = request.page
+    size = request.size
+    question = request.question
+    kb_ids = request.kb_id
     if isinstance(kb_ids, str):
         kb_ids = [kb_ids]
     if not kb_ids:
         return get_json_result(data=False, message='Please specify dataset firstly.',
                                code=settings.RetCode.DATA_ERROR)

-    doc_ids = req.get("doc_ids", [])
-    use_kg = req.get("use_kg", False)
-    top = int(req.get("top_k", 1024))
-    langs = req.get("cross_languages", [])
+    doc_ids = request.doc_ids
+    use_kg = request.use_kg
+    top = request.top_k
+    langs = request.cross_languages
     tenant_ids = []

-    if req.get("search_id", ""):
-        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+    if request.search_id:
+        search_config = SearchService.get_detail(request.search_id).get("search_config", {})
         meta_data_filter = search_config.get("meta_data_filter", {})
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         if meta_data_filter.get("method") == "auto":
@@ -338,19 +350,19 @@ def retrieval_test():
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

         rerank_mdl = None
-        if req.get("rerank_id"):
-            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+        if request.rerank_id:
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=request.rerank_id)

-        if req.get("keyword", False):
+        if request.keyword:
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)

         labels = label_question(question, [kb])
         ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
-                                               float(req.get("similarity_threshold", 0.0)),
-                                               float(req.get("vector_similarity_weight", 0.3)),
+                                               float(request.similarity_threshold),
+                                               float(request.vector_similarity_weight),
                                                top,
-                                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
+                                               doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
                                                rank_feature=labels
                                                )
         if use_kg:
@@ -374,10 +386,11 @@ def retrieval_test():
         return server_error_response(e)


-@manager.route('/knowledge_graph', methods=['GET'])  # noqa: F821
-@login_required
-def knowledge_graph():
-    doc_id = request.args["doc_id"]
+@router.get('/knowledge_graph')
+async def knowledge_graph(
+    doc_id: str = Query(..., description="Document ID"),
+    current_user = Depends(get_current_user)
+):
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
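Two pieces this diff relies on are not shown in the commit: the get_current_user dependency imported from api.utils.api_utils, and the code that mounts router on the application. A minimal sketch of both, strictly an assumption since neither appears here (the User type, resolve_user_from_token, the module path, and the /v1/chunk prefix are all hypothetical):

from dataclasses import dataclass
from typing import Optional

from fastapi import FastAPI, HTTPException, Request


@dataclass
class User:
    id: str


def resolve_user_from_token(auth_header: Optional[str]) -> Optional[User]:
    """Hypothetical stand-in for the project's real token verification."""
    if not auth_header:
        return None
    return User(id="user-id-from-token")


async def get_current_user(request: Request) -> User:
    # Replaces Flask-Login's @login_required/current_user pair: FastAPI
    # dependencies signal auth failure by raising an HTTP error instead of
    # redirecting to a login view.
    user = resolve_user_from_token(request.headers.get("Authorization"))
    if user is None:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return user


# Mounting the router at application startup (module path assumed from the
# commit title; the prefix is a guess).
from api.apps.chunk_app import router as chunk_router

app = FastAPI()
app.include_router(chunk_router, prefix="/v1/chunk", tags=["chunk"])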

api/models/chunk_models.py (new file)

@@ -0,0 +1,88 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from typing import Optional, List, Dict, Any
+
+from pydantic import BaseModel, Field
+
+
+class ListChunkRequest(BaseModel):
+    """Request model for listing chunks"""
+    doc_id: str = Field(..., description="Document ID")
+    page: Optional[int] = Field(1, description="Page number")
+    size: Optional[int] = Field(30, description="Page size")
+    keywords: Optional[str] = Field("", description="Keywords")
+    available_int: Optional[int] = Field(None, description="Availability status")
+
+
+class GetChunkRequest(BaseModel):
+    """Request model for fetching a chunk"""
+    chunk_id: str = Field(..., description="Chunk ID")
+
+
+class SetChunkRequest(BaseModel):
+    """Request model for updating a chunk"""
+    doc_id: str = Field(..., description="Document ID")
+    chunk_id: str = Field(..., description="Chunk ID")
+    content_with_weight: str = Field(..., description="Content with weight")
+    important_kwd: Optional[List[str]] = Field(None, description="Important keywords")
+    question_kwd: Optional[List[str]] = Field(None, description="Question keywords")
+    tag_kwd: Optional[str] = Field(None, description="Tag keywords")
+    tag_feas: Optional[Any] = Field(None, description="Tag features")
+    available_int: Optional[int] = Field(None, description="Availability status")
+
+
+class SwitchChunkRequest(BaseModel):
+    """Request model for switching chunk availability"""
+    chunk_ids: List[str] = Field(..., description="List of chunk IDs")
+    available_int: int = Field(..., description="Availability status")
+    doc_id: str = Field(..., description="Document ID")
+
+
+class RemoveChunkRequest(BaseModel):
+    """Request model for removing chunks"""
+    chunk_ids: List[str] = Field(..., description="List of chunk IDs")
+    doc_id: str = Field(..., description="Document ID")
+
+
+class CreateChunkRequest(BaseModel):
+    """Request model for creating a chunk"""
+    doc_id: str = Field(..., description="Document ID")
+    content_with_weight: str = Field(..., description="Content with weight")
+    important_kwd: Optional[List[str]] = Field([], description="Important keywords")
+    question_kwd: Optional[List[str]] = Field([], description="Question keywords")
+    tag_feas: Optional[Any] = Field(None, description="Tag features")
+
+
+class RetrievalTestRequest(BaseModel):
+    """Request model for retrieval testing"""
+    kb_id: List[str] = Field(..., description="List of knowledge base IDs")
+    question: str = Field(..., description="Question")
+    page: Optional[int] = Field(1, description="Page number")
+    size: Optional[int] = Field(30, description="Page size")
+    doc_ids: Optional[List[str]] = Field([], description="List of document IDs")
+    use_kg: Optional[bool] = Field(False, description="Whether to use the knowledge graph")
+    top_k: Optional[int] = Field(1024, description="Number of results to return")
+    cross_languages: Optional[List[str]] = Field([], description="Cross-language list")
+    search_id: Optional[str] = Field("", description="Search ID")
+    rerank_id: Optional[str] = Field(None, description="Rerank model ID")
+    keyword: Optional[bool] = Field(False, description="Whether to apply keyword extraction")
+    similarity_threshold: Optional[float] = Field(0.0, description="Similarity threshold")
+    vector_similarity_weight: Optional[float] = Field(0.3, description="Vector similarity weight")
+    highlight: Optional[bool] = Field(None, description="Whether to highlight matches")
+
+
+class KnowledgeGraphRequest(BaseModel):
+    """Request model for the knowledge graph"""
+    doc_id: str = Field(..., description="Document ID")
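
Finally, a hedged usage sketch of what the migration buys: FastAPI validates the JSON body against ListChunkRequest before list_chunk ever runs (replacing the old @validate_request("doc_id") decorator), and dependency_overrides makes the auth dependency trivial to stub in tests. The router import path and the fake user are assumptions:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from api.apps.chunk_app import router            # module path assumed from the commit title
from api.utils.api_utils import get_current_user

app = FastAPI()
app.include_router(router)


class _FakeUser:
    # Hypothetical stand-in for the authenticated user object
    id = "test-user"


# Bypass real authentication by overriding the dependency used in the endpoints
app.dependency_overrides[get_current_user] = lambda: _FakeUser()

client = TestClient(app)

# A body missing the required doc_id is rejected with 422 before the handler runs
assert client.post("/list", json={}).status_code == 422

# A valid body is parsed into ListChunkRequest, with defaults applied for page/size
resp = client.post("/list", json={"doc_id": "some-doc-id"})
print(resp.status_code, resp.json())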