# TERES_fastapi_backend/api/apps/chunk_app.py
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import re
import xxhash
from fastapi import APIRouter, Depends, Query, HTTPException
from api import settings
from api.db import LLMType, ParserType
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.models.chunk_models import (
    ListChunkRequest,
    GetChunkRequest,
    SetChunkRequest,
    SwitchChunkRequest,
    RemoveChunkRequest,
    CreateChunkRequest,
    RetrievalTestRequest,
    KnowledgeGraphRequest,
)
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace

# Create the FastAPI router for chunk management endpoints.
router = APIRouter()
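
# A minimal sketch of how this router might be mounted. The application
# factory and the "/v1/chunk" prefix are assumptions for illustration, not
# taken from this file:
#
#     from fastapi import FastAPI
#     from api.apps.chunk_app import router as chunk_router
#
#     app = FastAPI()
#     app.include_router(chunk_router, prefix="/v1/chunk")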


@router.post('/list')
async def list_chunk(
    request: ListChunkRequest,
    current_user=Depends(get_current_user)
):
    """Paginate the chunks of one document, optionally filtered by keywords."""
    doc_id = request.doc_id
    page = request.page
    size = request.size
    question = request.keywords
    try:
        tenant_id = DocumentService.get_tenant_id(doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
        query = {
            "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
        }
        if request.available_int is not None:
            query["available_int"] = int(request.available_int)
        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
        for cid in sres.ids:
            d = {
                "chunk_id": cid,
                "content_with_weight": rmSpace(sres.highlight[cid])
                if question and cid in sres.highlight
                else sres.field[cid].get("content_with_weight", ""),
                "doc_id": sres.field[cid]["doc_id"],
                "docnm_kwd": sres.field[cid]["docnm_kwd"],
                "important_kwd": sres.field[cid].get("important_kwd", []),
                "question_kwd": sres.field[cid].get("question_kwd", []),
                "image_id": sres.field[cid].get("img_id", ""),
                "available_int": int(sres.field[cid].get("available_int", 1)),
                "positions": sres.field[cid].get("position_int", []),
            }
            assert isinstance(d["positions"], list)
            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
            res["chunks"].append(d)
        return get_json_result(data=res)
    except Exception as e:
        if "not_found" in str(e):
            return get_json_result(data=False, message='No chunk found!',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)
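
# Hypothetical request body for POST /list (field names follow the
# ListChunkRequest model used above; the id value is a placeholder):
#
#     {
#         "doc_id": "<document-id>",
#         "page": 1,
#         "size": 30,
#         "keywords": "optional filter text",
#         "available_int": 1
#     }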


@router.get('/get')
async def get_chunk(
    chunk_id: str = Query(..., description="Chunk ID"),
    current_user=Depends(get_current_user)
):
    """Fetch a single chunk by id, searching every tenant the user belongs to."""
    try:
        chunk = None
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(message="Tenant not found!")
        for tenant in tenants:
            kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
            chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
            if chunk:
                break
        if chunk is None:
            return server_error_response(Exception("Chunk not found"))
        # Strip internal vector/token fields before returning the chunk.
        for n in [k for k in chunk.keys() if re.search(r"(_vec$|_sm_|_tks|_ltks)", k)]:
            del chunk[n]
        return get_json_result(data=chunk)
    except Exception as e:
        if "NotFoundError" in str(e):
            return get_json_result(data=False, message='Chunk not found!',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)
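
# Hypothetical client call, assuming the router is mounted under /v1/chunk:
#
#     GET /v1/chunk/get?chunk_id=<chunk-id>
#
# Internal vector/token fields (e.g. q_1024_vec, content_ltks) are stripped
# from the response by the filter above.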


@router.post('/set')
async def set_chunk(
    request: SetChunkRequest,
    current_user=Depends(get_current_user)
):
    """Update a chunk's content, keywords and flags, then re-embed and re-index it."""
    d = {
        "id": request.chunk_id,
        "content_with_weight": request.content_with_weight}
    d["content_ltks"] = rag_tokenizer.tokenize(request.content_with_weight)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if request.important_kwd is not None:
        if not isinstance(request.important_kwd, list):
            return get_data_error_result(message="`important_kwd` should be a list")
        d["important_kwd"] = request.important_kwd
        d["important_tks"] = rag_tokenizer.tokenize(" ".join(request.important_kwd))
    if request.question_kwd is not None:
        if not isinstance(request.question_kwd, list):
            return get_data_error_result(message="`question_kwd` should be a list")
        d["question_kwd"] = request.question_kwd
        d["question_tks"] = rag_tokenizer.tokenize("\n".join(request.question_kwd))
    if request.tag_kwd is not None:
        d["tag_kwd"] = request.tag_kwd
    if request.tag_feas is not None:
        d["tag_feas"] = request.tag_feas
    if request.available_int is not None:
        d["available_int"] = request.available_int
    try:
        tenant_id = DocumentService.get_tenant_id(request.doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        embd_id = DocumentService.get_embd_id(request.doc_id)
        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        if doc.parser_id == ParserType.QA:
            # Q&A documents: the first segment is the question, the rest the answer.
            arr = [t for t in re.split(r"[\n\t]", request.content_with_weight) if len(t) > 1]
            q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
            d = beAdoc(d, q, a, not any([rag_tokenizer.is_chinese(t) for t in q + a]))
        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
        # Blend title and content embeddings 1:9; Q&A chunks use the content
        # embedding alone.
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        settings.docStoreConn.update({"id": request.chunk_id}, d, search.index_name(tenant_id), doc.kb_id)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
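
# A minimal sketch of the 1:9 title/content blend used above, with made-up
# 3-dimensional vectors (real embeddings come from the tenant's embedding
# model and are much wider):
#
#     import numpy as np
#     title_vec = np.array([1.0, 0.0, 0.0])
#     content_vec = np.array([0.0, 1.0, 0.0])
#     blended = 0.1 * title_vec + 0.9 * content_vec  # -> [0.1, 0.9, 0.0]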


@router.post('/switch')
async def switch(
    request: SwitchChunkRequest,
    current_user=Depends(get_current_user)
):
    """Toggle the availability flag on one or more chunks."""
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        # Resolve the index name once rather than per chunk.
        index_name = search.index_name(DocumentService.get_tenant_id(request.doc_id))
        for cid in request.chunk_ids:
            if not settings.docStoreConn.update({"id": cid},
                                                {"available_int": int(request.available_int)},
                                                index_name, doc.kb_id):
                return get_data_error_result(message="Index update failed")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
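
# Hypothetical request body for POST /switch, assuming the usual convention
# that available_int=1 makes chunks retrievable and 0 hides them:
#
#     {"doc_id": "<document-id>", "chunk_ids": ["<chunk-id>"], "available_int": 0}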


@router.post('/rm')
async def rm(
    request: RemoveChunkRequest,
    current_user=Depends(get_current_user)
):
    """Delete chunks from the index and remove their stored objects (e.g. images)."""
    # The storage backend is imported lazily inside the handler.
    from rag.utils.storage_factory import STORAGE_IMPL
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        if not settings.docStoreConn.delete({"id": request.chunk_ids},
                                            search.index_name(DocumentService.get_tenant_id(request.doc_id)),
                                            doc.kb_id):
            return get_data_error_result(message="Chunk deletion failed")
        deleted_chunk_ids = request.chunk_ids
        chunk_number = len(deleted_chunk_ids)
        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
        for cid in deleted_chunk_ids:
            if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                STORAGE_IMPL.rm(doc.kb_id, cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
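
# Hypothetical request body for POST /rm (shape per the RemoveChunkRequest
# model used above):
#
#     {"doc_id": "<document-id>", "chunk_ids": ["<chunk-id-1>", "<chunk-id-2>"]}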


@router.post('/create')
async def create(
    request: CreateChunkRequest,
    current_user=Depends(get_current_user)
):
    """Create a chunk for a document, embed it, and insert it into the index."""
    # The chunk id is a content hash, so identical content within the same
    # document maps to the same id.
    chunk_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
    d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
         "content_with_weight": request.content_with_weight}
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = request.important_kwd
    if not isinstance(d["important_kwd"], list):
        return get_data_error_result(message="`important_kwd` should be a list")
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
    d["question_kwd"] = request.question_kwd
    if not isinstance(d["question_kwd"], list):
        return get_data_error_result(message="`question_kwd` should be a list")
    d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
    d["create_time"] = str(datetime.datetime.now())[:19]
    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
    if request.tag_feas is not None:
        d["tag_feas"] = request.tag_feas
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
        d["doc_id"] = doc.id
        tenant_id = DocumentService.get_tenant_id(request.doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")
        if kb.pagerank:
            d[PAGERANK_FLD] = kb.pagerank
        embd_id = DocumentService.get_embd_id(request.doc_id)
        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d["question_kwd"] else "\n".join(d["question_kwd"])])
        # Blend title and content embeddings 1:9, matching the /set handler.
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
        DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
        return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)
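
# A minimal sketch of the id derivation above: the digest is deterministic,
# so re-posting identical content for the same doc_id yields the same chunk id:
#
#     import xxhash
#     h1 = xxhash.xxh64(("some content" + "doc-1").encode("utf-8")).hexdigest()
#     h2 = xxhash.xxh64(("some content" + "doc-1").encode("utf-8")).hexdigest()
#     assert h1 == h2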


@router.post('/retrieval_test')
async def retrieval_test(
    request: RetrievalTestRequest,
    current_user=Depends(get_current_user)
):
    """Run a retrieval dry-run against one or more knowledge bases."""
    page = request.page
    size = request.size
    question = request.question
    kb_ids = request.kb_id
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    if not kb_ids:
        return get_json_result(data=False, message='Please specify a dataset first.',
                               code=settings.RetCode.DATA_ERROR)
    doc_ids = request.doc_ids
    use_kg = request.use_kg
    top = request.top_k
    langs = request.cross_languages
    tenant_ids = []
    if request.search_id:
        # Apply the saved search app's metadata filter (auto via LLM, or manual).
        search_config = SearchService.get_detail(request.search_id).get("search_config", {})
        meta_data_filter = search_config.get("meta_data_filter", {})
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        for kb_id in kb_ids:
            for tenant in tenants:
                if KnowledgebaseService.query(
                        tenant_id=tenant.tenant_id, id=kb_id):
                    tenant_ids.append(tenant.tenant_id)
                    break
            else:
                return get_json_result(
                    data=False, message='Only the owner of the knowledgebase is authorized for this operation.',
                    code=settings.RetCode.OPERATING_ERROR)
        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")
        if langs:
            question = cross_languages(kb.tenant_id, None, question, langs)
        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
        rerank_mdl = None
        if request.rerank_id:
            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=request.rerank_id)
        if request.keyword:
            # Augment the query with LLM-extracted keywords.
            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
        labels = label_question(question, [kb])
        ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                               float(request.similarity_threshold),
                                               float(request.vector_similarity_weight),
                                               top,
                                               doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
                                               rank_feature=labels)
        if use_kg:
            ck = settings.kg_retrievaler.retrieval(question,
                                                   tenant_ids,
                                                   kb_ids,
                                                   embd_mdl,
                                                   LLMBundle(kb.tenant_id, LLMType.CHAT))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)
        for c in ranks["chunks"]:
            c.pop("vector", None)
        ranks["labels"] = labels
        return get_json_result(data=ranks)
    except Exception as e:
        if "not_found" in str(e):
            return get_json_result(data=False, message='No chunk found! Please check the chunk status.',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)
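
# Hypothetical request body for POST /retrieval_test (field names follow the
# RetrievalTestRequest model used above; values are placeholders):
#
#     {
#         "kb_id": ["<kb-id>"],
#         "question": "What is the refund policy?",
#         "page": 1,
#         "size": 10,
#         "similarity_threshold": 0.2,
#         "vector_similarity_weight": 0.3,
#         "top_k": 1024,
#         "use_kg": false,
#         "highlight": true
#     }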


@router.get('/knowledge_graph')
async def knowledge_graph(
    doc_id: str = Query(..., description="Document ID"),
    current_user=Depends(get_current_user)
):
    """Return the knowledge graph and mind map chunks of one document."""
    tenant_id = DocumentService.get_tenant_id(doc_id)
    kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
    req = {
        "doc_ids": [doc_id],
        "knowledge_graph_kwd": ["graph", "mind_map"]
    }
    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
    obj = {"graph": {}, "mind_map": {}}
    for cid in sres.ids[:2]:
        ty = sres.field[cid]["knowledge_graph_kwd"]
        try:
            content_json = json.loads(sres.field[cid]["content_with_weight"])
        except Exception:
            continue
        if ty == 'mind_map':
            # De-duplicate node ids by suffixing repeats with an occurrence count.
            node_dict = {}

            def repeat_deal(content_json, node_dict):
                if 'id' in content_json:
                    if content_json['id'] in node_dict:
                        node_name = content_json['id']
                        content_json['id'] += f"({node_dict[content_json['id']]})"
                        node_dict[node_name] += 1
                    else:
                        node_dict[content_json['id']] = 1
                if 'children' in content_json and content_json['children']:
                    for item in content_json['children']:
                        repeat_deal(item, node_dict)

            repeat_deal(content_json, node_dict)
        obj[ty] = content_json
    return get_json_result(data=obj)
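
# A minimal sketch of what the de-duplication above does to a mind map with
# repeated node ids (made-up data):
#
#     tree = {"id": "A", "children": [{"id": "A"}, {"id": "A"}]}
#     repeat_deal(tree, {})
#     # tree is now {"id": "A", "children": [{"id": "A(1)"}, {"id": "A(2)"}]}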