#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import datetime
import json
import re

import xxhash
from fastapi import APIRouter, Depends, Query

from api import settings
from api.db import LLMType, ParserType
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.models.chunk_models import (
    ListChunkRequest,
    SetChunkRequest,
    SwitchChunkRequest,
    RemoveChunkRequest,
    CreateChunkRequest,
    RetrievalTestRequest,
)
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace

# Create the FastAPI router for chunk-level endpoints.
router = APIRouter()


@router.post('/list')
async def list_chunk(
    request: ListChunkRequest,
    current_user=Depends(get_current_user)
):
    doc_id = request.doc_id
    page = request.page
    size = request.size
    question = request.keywords
    try:
        tenant_id = DocumentService.get_tenant_id(doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
        query = {
            "doc_ids": [doc_id],
            "page": page,
            "size": size,
            "question": question,
            "sort": True
        }
        if request.available_int is not None:
            query["available_int"] = int(request.available_int)
        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
        for id in sres.ids:
            d = {
                "chunk_id": id,
                "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight
                else sres.field[id].get("content_with_weight", ""),
                "doc_id": sres.field[id]["doc_id"],
                "docnm_kwd": sres.field[id]["docnm_kwd"],
                "important_kwd": sres.field[id].get("important_kwd", []),
                "question_kwd": sres.field[id].get("question_kwd", []),
                "image_id": sres.field[id].get("img_id", ""),
                "available_int": int(sres.field[id].get("available_int", 1)),
                "positions": sres.field[id].get("position_int", []),
            }
            assert isinstance(d["positions"], list)
            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
            res["chunks"].append(d)
        return get_json_result(data=res)
    except Exception as e:
        if str(e).find("not_found") >= 0:
            return get_json_result(data=False, message='No chunk found!',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)
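
# Usage sketch for POST /list (hypothetical ids and values; field names follow
# ListChunkRequest above). A paginated keyword search over one document's chunks:
#
#   POST /list
#   {
#       "doc_id": "d94f8764a2b311ee",    # hypothetical document id
#       "page": 1,
#       "size": 30,
#       "keywords": "warranty period",
#       "available_int": 1               # optional: restrict to enabled chunks
#   }
#
# The response mirrors the `res` dict built above:
# {"total": ..., "chunks": [...], "doc": {...}}.
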
@router.get('/get')
async def get(
    chunk_id: str = Query(..., description="Chunk ID"),
    current_user=Depends(get_current_user)
):
    try:
        chunk = None
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(message="Tenant not found!")
        # The chunk may live in any tenant the user belongs to; probe each one.
        for tenant in tenants:
            kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
            chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
            if chunk:
                break
        if chunk is None:
            return server_error_response(Exception("Chunk not found"))
        # Strip internal vector/token fields before returning the chunk.
        k = []
        for n in chunk.keys():
            if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
                k.append(n)
        for n in k:
            del chunk[n]
        return get_json_result(data=chunk)
    except Exception as e:
        if str(e).find("NotFoundError") >= 0:
            return get_json_result(data=False, message='Chunk not found!',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)


@router.post('/set')
async def set(
    request: SetChunkRequest,
    current_user=Depends(get_current_user)
):
    d = {
        "id": request.chunk_id,
        "content_with_weight": request.content_with_weight}
    d["content_ltks"] = rag_tokenizer.tokenize(request.content_with_weight)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if request.important_kwd is not None:
        if not isinstance(request.important_kwd, list):
            return get_data_error_result(message="`important_kwd` should be a list")
        d["important_kwd"] = request.important_kwd
        d["important_tks"] = rag_tokenizer.tokenize(" ".join(request.important_kwd))
    if request.question_kwd is not None:
        if not isinstance(request.question_kwd, list):
            return get_data_error_result(message="`question_kwd` should be a list")
        d["question_kwd"] = request.question_kwd
        d["question_tks"] = rag_tokenizer.tokenize("\n".join(request.question_kwd))
    if request.tag_kwd is not None:
        d["tag_kwd"] = request.tag_kwd
    if request.tag_feas is not None:
        d["tag_feas"] = request.tag_feas
    if request.available_int is not None:
        d["available_int"] = request.available_int
    try:
        tenant_id = DocumentService.get_tenant_id(request.doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        embd_id = DocumentService.get_embd_id(request.doc_id)
        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        if doc.parser_id == ParserType.QA:
            # QA documents store a question/answer pair; split and re-assemble.
            arr = [
                t for t in re.split(
                    r"[\n\t]",
                    request.content_with_weight) if len(t) > 1]
            q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
            d = beAdoc(d, q, a, not any(
                [rag_tokenizer.is_chinese(t) for t in q + a]))

        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        settings.docStoreConn.update({"id": request.chunk_id}, d, search.index_name(tenant_id), doc.kb_id)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
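
# Note on the vector blend used by /set above (and /create below): the embedding
# model encodes the document title and the chunk content in one batch, and the
# stored vector is a weighted mix, 0.1 * title_vec + 0.9 * content_vec, so
# retrieval is dominated by chunk content while keeping a trace of the parent
# document. A minimal sketch with made-up 2-d vectors:
#
#   title_vec   = [1.0, 0.0]
#   content_vec = [0.0, 1.0]
#   stored_vec  = [0.1, 0.9]   # 0.1 * title_vec + 0.9 * content_vec
#
# For QA-parsed documents, /set stores the content vector v[1] unblended.
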
@router.post('/switch')
async def switch(
    request: SwitchChunkRequest,
    current_user=Depends(get_current_user)
):
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        for cid in request.chunk_ids:
            if not settings.docStoreConn.update({"id": cid},
                                                {"available_int": int(request.available_int)},
                                                search.index_name(DocumentService.get_tenant_id(request.doc_id)),
                                                doc.kb_id):
                return get_data_error_result(message="Index updating failure")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)


@router.post('/rm')
async def rm(
    request: RemoveChunkRequest,
    current_user=Depends(get_current_user)
):
    from rag.utils.storage_factory import STORAGE_IMPL
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        if not settings.docStoreConn.delete({"id": request.chunk_ids},
                                            search.index_name(DocumentService.get_tenant_id(request.doc_id)),
                                            doc.kb_id):
            return get_data_error_result(message="Chunk deleting failure")
        deleted_chunk_ids = request.chunk_ids
        chunk_number = len(deleted_chunk_ids)
        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
        # Remove any stored objects (e.g. chunk images) that back the deleted chunks.
        for cid in deleted_chunk_ids:
            if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                STORAGE_IMPL.rm(doc.kb_id, cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)


@router.post('/create')
async def create(
    request: CreateChunkRequest,
    current_user=Depends(get_current_user)
):
    # Chunk ids are deterministic: a hash of the content plus the document id.
    chunk_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
    d = {"id": chunk_id,
         "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
         "content_with_weight": request.content_with_weight}
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = request.important_kwd
    if not isinstance(d["important_kwd"], list):
        return get_data_error_result(message="`important_kwd` is required to be a list")
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
    d["question_kwd"] = request.question_kwd
    if not isinstance(d["question_kwd"], list):
        return get_data_error_result(message="`question_kwd` is required to be a list")
    d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
    d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
    if request.tag_feas is not None:
        d["tag_feas"] = request.tag_feas
    try:
        e, doc = DocumentService.get_by_id(request.doc_id)
        if not e:
            return get_data_error_result(message="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
        d["doc_id"] = doc.id
        tenant_id = DocumentService.get_tenant_id(request.doc_id)
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")
        if kb.pagerank:
            d[PAGERANK_FLD] = kb.pagerank
        embd_id = DocumentService.get_embd_id(request.doc_id)
        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

        v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d["question_kwd"] else "\n".join(d["question_kwd"])])
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

        DocumentService.increment_chunk_num(
            doc.id, doc.kb_id, c, 1, 0)
        return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)
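
# Note on chunk ids in /create above: the id is the xxhash64 digest of
# content + doc_id, so re-submitting identical content for the same document
# produces the same id rather than a duplicate chunk. A minimal sketch
# (standalone, illustrative values):
#
#   import xxhash
#   cid = xxhash.xxh64(("chunk text" + "doc123").encode("utf-8")).hexdigest()
#   # -> a stable 16-character hex string, reproducible across calls
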
@router.post('/retrieval_test')
async def retrieval_test(
    request: RetrievalTestRequest,
    current_user=Depends(get_current_user)
):
    page = request.page
    size = request.size
    question = request.question
    kb_ids = request.kb_id
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    if not kb_ids:
        return get_json_result(
            data=False, message='Please specify at least one dataset.',
            code=settings.RetCode.DATA_ERROR)
    doc_ids = request.doc_ids
    use_kg = request.use_kg
    top = request.top_k
    langs = request.cross_languages
    tenant_ids = []

    if request.search_id:
        search_config = SearchService.get_detail(request.search_id).get("search_config", {})
        meta_data_filter = search_config.get("meta_data_filter", {})
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            # Let an LLM derive metadata filters from the question itself.
            chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None

    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        # Every requested kb must belong to one of the caller's tenants.
        for kb_id in kb_ids:
            for tenant in tenants:
                if KnowledgebaseService.query(
                        tenant_id=tenant.tenant_id, id=kb_id):
                    tenant_ids.append(tenant.tenant_id)
                    break
            else:
                return get_json_result(
                    data=False, message='Only owner of knowledgebase authorized for this operation.',
                    code=settings.RetCode.OPERATING_ERROR)

        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")

        if langs:
            question = cross_languages(kb.tenant_id, None, question, langs)

        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        rerank_mdl = None
        if request.rerank_id:
            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=request.rerank_id)

        if request.keyword:
            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)

        labels = label_question(question, [kb])
        ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                               page, size, float(request.similarity_threshold),
                                               float(request.vector_similarity_weight),
                                               top, doc_ids, rerank_mdl=rerank_mdl,
                                               highlight=request.highlight,
                                               rank_feature=labels)
        if use_kg:
            ck = settings.kg_retrievaler.retrieval(question,
                                                   tenant_ids,
                                                   kb_ids,
                                                   embd_mdl,
                                                   LLMBundle(kb.tenant_id, LLMType.CHAT))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)

        for c in ranks["chunks"]:
            c.pop("vector", None)
        ranks["labels"] = labels

        return get_json_result(data=ranks)
    except Exception as e:
        if str(e).find("not_found") >= 0:
            return get_json_result(data=False, message='No chunk found! Please check the chunk status.',
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)
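
# Usage sketch for POST /retrieval_test (hypothetical values; field names follow
# RetrievalTestRequest above):
#
#   POST /retrieval_test
#   {
#       "kb_id": ["kb42"],                  # one or more knowledge base ids
#       "question": "What is the refund policy?",
#       "page": 1,
#       "size": 10,
#       "top_k": 1024,
#       "similarity_threshold": 0.2,
#       "vector_similarity_weight": 0.3,
#       "keyword": false,                   # optional LLM keyword expansion
#       "use_kg": false                     # optional knowledge-graph pass
#   }
#
# The handler verifies the caller owns every kb, optionally rewrites the
# question (cross-language, keyword extraction), then delegates ranking to
# settings.retrievaler.
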
@router.get('/knowledge_graph')
async def knowledge_graph(
    doc_id: str = Query(..., description="Document ID"),
    current_user=Depends(get_current_user)
):
    tenant_id = DocumentService.get_tenant_id(doc_id)
    kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
    req = {
        "doc_ids": [doc_id],
        "knowledge_graph_kwd": ["graph", "mind_map"]
    }
    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
    obj = {"graph": {}, "mind_map": {}}
    for id in sres.ids[:2]:
        ty = sres.field[id]["knowledge_graph_kwd"]
        try:
            content_json = json.loads(sres.field[id]["content_with_weight"])
        except Exception:
            continue

        if ty == 'mind_map':
            node_dict = {}

            def repeat_deal(content_json, node_dict):
                # Disambiguate duplicate node ids by suffixing an occurrence count.
                if 'id' in content_json:
                    if content_json['id'] in node_dict:
                        node_name = content_json['id']
                        content_json['id'] += f"({node_dict[content_json['id']]})"
                        node_dict[node_name] += 1
                    else:
                        node_dict[content_json['id']] = 1
                if 'children' in content_json and content_json['children']:
                    for item in content_json['children']:
                        repeat_deal(item, node_dict)

            repeat_deal(content_json, node_dict)

        obj[ty] = content_json

    return get_json_result(data=obj)
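
# Wiring sketch (assumes a FastAPI app defined elsewhere in the project; the
# mount prefix below is illustrative, not necessarily the project's actual one):
#
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.include_router(router, prefix="/v1/chunk")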