From 7d0d65a0acdb0fb49784ebf67db6de3d3b4e8e6d Mon Sep 17 00:00:00 2001 From: dangzerong <429714019@qq.com> Date: Wed, 29 Oct 2025 11:26:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E9=80=A0=20chunk=5Fapp.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/apps/chunk_app.py | 229 ++++++++++++++++++++----------------- api/models/chunk_models.py | 88 ++++++++++++++ 2 files changed, 209 insertions(+), 108 deletions(-) create mode 100644 api/models/chunk_models.py diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index bfd80ea..21d24dc 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -16,10 +16,10 @@ import datetime import json import re +from typing import Optional, List import xxhash -from flask import request -from flask_login import current_user, login_required +from fastapi import APIRouter, Depends, Query, HTTPException from api import settings from api.db import LLMType, ParserType @@ -29,7 +29,17 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.models.chunk_models import ( + ListChunkRequest, + GetChunkRequest, + SetChunkRequest, + SwitchChunkRequest, + RemoveChunkRequest, + CreateChunkRequest, + RetrievalTestRequest, + KnowledgeGraphRequest +) +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search @@ -37,18 +47,21 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr from rag.settings import PAGERANK_FLD from rag.utils import rmSpace +# 创建 FastAPI 路由器 +router = APIRouter() -@manager.route('/list', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id") -def list_chunk(): - req = request.json - doc_id = req["doc_id"] - page = int(req.get("page", 1)) - size = int(req.get("size", 30)) - question = req.get("keywords", "") + +@router.post('/list') +async def list_chunk( + request: ListChunkRequest, + current_user = Depends(get_current_user) +): + doc_id = request.doc_id + page = request.page + size = request.size + question = request.keywords try: - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + tenant_id = DocumentService.get_tenant_id(doc_id) if not tenant_id: return get_data_error_result(message="Tenant not found!") e, doc = DocumentService.get_by_id(doc_id) @@ -58,8 +71,8 @@ def list_chunk(): query = { "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True } - if "available_int" in req: - query["available_int"] = int(req["available_int"]) + if request.available_int is not None: + query["available_int"] = int(request.available_int) sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} for id in sres.ids: @@ -87,10 +100,11 @@ def list_chunk(): return server_error_response(e) -@manager.route('/get', methods=['GET']) # noqa: F821 -@login_required -def get(): - chunk_id = request.args["chunk_id"] +@router.get('/get') +async def get( + chunk_id: str = Query(..., description="块ID"), + current_user = Depends(get_current_user) +): try: chunk = None tenants = UserTenantService.query(user_id=current_user.id) @@ -119,42 +133,42 @@ def get(): return server_error_response(e) -@manager.route('/set', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id", "chunk_id", "content_with_weight") -def set(): - req = request.json +@router.post('/set') +async def set( + request: SetChunkRequest, + current_user = Depends(get_current_user) +): d = { - "id": req["chunk_id"], - "content_with_weight": req["content_with_weight"]} - d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) + "id": request.chunk_id, + "content_with_weight": request.content_with_weight} + d["content_ltks"] = rag_tokenizer.tokenize(request.content_with_weight) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if "important_kwd" in req: - if not isinstance(req["important_kwd"], list): + if request.important_kwd is not None: + if not isinstance(request.important_kwd, list): return get_data_error_result(message="`important_kwd` should be a list") - d["important_kwd"] = req["important_kwd"] - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) - if "question_kwd" in req: - if not isinstance(req["question_kwd"], list): + d["important_kwd"] = request.important_kwd + d["important_tks"] = rag_tokenizer.tokenize(" ".join(request.important_kwd)) + if request.question_kwd is not None: + if not isinstance(request.question_kwd, list): return get_data_error_result(message="`question_kwd` should be a list") - d["question_kwd"] = req["question_kwd"] - d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) - if "tag_kwd" in req: - d["tag_kwd"] = req["tag_kwd"] - if "tag_feas" in req: - d["tag_feas"] = req["tag_feas"] - if "available_int" in req: - d["available_int"] = req["available_int"] + d["question_kwd"] = request.question_kwd + d["question_tks"] = rag_tokenizer.tokenize("\n".join(request.question_kwd)) + if request.tag_kwd is not None: + d["tag_kwd"] = request.tag_kwd + if request.tag_feas is not None: + d["tag_feas"] = request.tag_feas + if request.available_int is not None: + d["available_int"] = request.available_int try: - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + tenant_id = DocumentService.get_tenant_id(request.doc_id) if not tenant_id: return get_data_error_result(message="Tenant not found!") - embd_id = DocumentService.get_embd_id(req["doc_id"]) + embd_id = DocumentService.get_embd_id(request.doc_id) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) - e, doc = DocumentService.get_by_id(req["doc_id"]) + e, doc = DocumentService.get_by_id(request.doc_id) if not e: return get_data_error_result(message="Document not found!") @@ -162,33 +176,33 @@ def set(): arr = [ t for t in re.split( r"[\n\t]", - req["content_with_weight"]) if len(t) > 1] + request.content_with_weight) if len(t) > 1] q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:])) d = beAdoc(d, q, a, not any( [rag_tokenizer.is_chinese(t) for t in q + a])) - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) + v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.update({"id": request.chunk_id}, d, search.index_name(tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: return server_error_response(e) -@manager.route('/switch', methods=['POST']) # noqa: F821 -@login_required -@validate_request("chunk_ids", "available_int", "doc_id") -def switch(): - req = request.json +@router.post('/switch') +async def switch( + request: SwitchChunkRequest, + current_user = Depends(get_current_user) +): try: - e, doc = DocumentService.get_by_id(req["doc_id"]) + e, doc = DocumentService.get_by_id(request.doc_id) if not e: return get_data_error_result(message="Document not found!") - for cid in req["chunk_ids"]: + for cid in request.chunk_ids: if not settings.docStoreConn.update({"id": cid}, - {"available_int": int(req["available_int"])}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + {"available_int": int(request.available_int)}, + search.index_name(DocumentService.get_tenant_id(request.doc_id)), doc.kb_id): return get_data_error_result(message="Index updating failure") return get_json_result(data=True) @@ -196,21 +210,21 @@ def switch(): return server_error_response(e) -@manager.route('/rm', methods=['POST']) # noqa: F821 -@login_required -@validate_request("chunk_ids", "doc_id") -def rm(): +@router.post('/rm') +async def rm( + request: RemoveChunkRequest, + current_user = Depends(get_current_user) +): from rag.utils.storage_factory import STORAGE_IMPL - req = request.json try: - e, doc = DocumentService.get_by_id(req["doc_id"]) + e, doc = DocumentService.get_by_id(request.doc_id) if not e: return get_data_error_result(message="Document not found!") - if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + if not settings.docStoreConn.delete({"id": request.chunk_ids}, + search.index_name(DocumentService.get_tenant_id(request.doc_id)), doc.kb_id): return get_data_error_result(message="Chunk deleting failure") - deleted_chunk_ids = req["chunk_ids"] + deleted_chunk_ids = request.chunk_ids chunk_number = len(deleted_chunk_ids) DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) for cid in deleted_chunk_ids: @@ -221,32 +235,30 @@ def rm(): return server_error_response(e) -@manager.route('/create', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id", "content_with_weight") -def create(): - req = request.json - chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() - d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), - "content_with_weight": req["content_with_weight"]} +@router.post('/create') +async def create( + request: CreateChunkRequest, + current_user = Depends(get_current_user) +): + chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest() + d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight), + "content_with_weight": request.content_with_weight} d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - d["important_kwd"] = req.get("important_kwd", []) + d["important_kwd"] = request.important_kwd if not isinstance(d["important_kwd"], list): return get_data_error_result(message="`important_kwd` is required to be a list") d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) - d["question_kwd"] = req.get("question_kwd", []) + d["question_kwd"] = request.question_kwd if not isinstance(d["question_kwd"], list): return get_data_error_result(message="`question_kwd` is required to be a list") d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.datetime.now().timestamp() - if "tag_feas" in req: - d["tag_feas"] = req["tag_feas"] - if "tag_feas" in req: - d["tag_feas"] = req["tag_feas"] + if request.tag_feas is not None: + d["tag_feas"] = request.tag_feas try: - e, doc = DocumentService.get_by_id(req["doc_id"]) + e, doc = DocumentService.get_by_id(request.doc_id) if not e: return get_data_error_result(message="Document not found!") d["kb_id"] = [doc.kb_id] @@ -254,7 +266,7 @@ def create(): d["title_tks"] = rag_tokenizer.tokenize(doc.name) d["doc_id"] = doc.id - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + tenant_id = DocumentService.get_tenant_id(request.doc_id) if not tenant_id: return get_data_error_result(message="Tenant not found!") @@ -264,10 +276,10 @@ def create(): if kb.pagerank: d[PAGERANK_FLD] = kb.pagerank - embd_id = DocumentService.get_embd_id(req["doc_id"]) + embd_id = DocumentService.get_embd_id(request.doc_id) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) + v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) @@ -279,29 +291,29 @@ def create(): return server_error_response(e) -@manager.route('/retrieval_test', methods=['POST']) # noqa: F821 -@login_required -@validate_request("kb_id", "question") -def retrieval_test(): - req = request.json - page = int(req.get("page", 1)) - size = int(req.get("size", 30)) - question = req["question"] - kb_ids = req["kb_id"] +@router.post('/retrieval_test') +async def retrieval_test( + request: RetrievalTestRequest, + current_user = Depends(get_current_user) +): + page = request.page + size = request.size + question = request.question + kb_ids = request.kb_id if isinstance(kb_ids, str): kb_ids = [kb_ids] if not kb_ids: return get_json_result(data=False, message='Please specify dataset firstly.', code=settings.RetCode.DATA_ERROR) - doc_ids = req.get("doc_ids", []) - use_kg = req.get("use_kg", False) - top = int(req.get("top_k", 1024)) - langs = req.get("cross_languages", []) + doc_ids = request.doc_ids + use_kg = request.use_kg + top = request.top_k + langs = request.cross_languages tenant_ids = [] - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + if request.search_id: + search_config = SearchService.get_detail(request.search_id).get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": @@ -338,19 +350,19 @@ def retrieval_test(): embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) rerank_mdl = None - if req.get("rerank_id"): - rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + if request.rerank_id: + rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=request.rerank_id) - if req.get("keyword", False): + if request.keyword: chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) labels = label_question(question, [kb]) ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, - float(req.get("similarity_threshold", 0.0)), - float(req.get("vector_similarity_weight", 0.3)), + float(request.similarity_threshold), + float(request.vector_similarity_weight), top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), + doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight, rank_feature=labels ) if use_kg: @@ -374,10 +386,11 @@ def retrieval_test(): return server_error_response(e) -@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821 -@login_required -def knowledge_graph(): - doc_id = request.args["doc_id"] +@router.get('/knowledge_graph') +async def knowledge_graph( + doc_id: str = Query(..., description="文档ID"), + current_user = Depends(get_current_user) +): tenant_id = DocumentService.get_tenant_id(doc_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) req = { diff --git a/api/models/chunk_models.py b/api/models/chunk_models.py new file mode 100644 index 0000000..8daf956 --- /dev/null +++ b/api/models/chunk_models.py @@ -0,0 +1,88 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + + +class ListChunkRequest(BaseModel): + """列出块请求模型""" + doc_id: str = Field(..., description="文档ID") + page: Optional[int] = Field(1, description="页码") + size: Optional[int] = Field(30, description="每页大小") + keywords: Optional[str] = Field("", description="关键词") + available_int: Optional[int] = Field(None, description="可用性状态") + + +class GetChunkRequest(BaseModel): + """获取块请求模型""" + chunk_id: str = Field(..., description="块ID") + + +class SetChunkRequest(BaseModel): + """设置块请求模型""" + doc_id: str = Field(..., description="文档ID") + chunk_id: str = Field(..., description="块ID") + content_with_weight: str = Field(..., description="带权重的内容") + important_kwd: Optional[List[str]] = Field(None, description="重要关键词") + question_kwd: Optional[List[str]] = Field(None, description="问题关键词") + tag_kwd: Optional[str] = Field(None, description="标签关键词") + tag_feas: Optional[Any] = Field(None, description="标签特征") + available_int: Optional[int] = Field(None, description="可用性状态") + + +class SwitchChunkRequest(BaseModel): + """切换块状态请求模型""" + chunk_ids: List[str] = Field(..., description="块ID列表") + available_int: int = Field(..., description="可用性状态") + doc_id: str = Field(..., description="文档ID") + + +class RemoveChunkRequest(BaseModel): + """删除块请求模型""" + chunk_ids: List[str] = Field(..., description="块ID列表") + doc_id: str = Field(..., description="文档ID") + + +class CreateChunkRequest(BaseModel): + """创建块请求模型""" + doc_id: str = Field(..., description="文档ID") + content_with_weight: str = Field(..., description="带权重的内容") + important_kwd: Optional[List[str]] = Field([], description="重要关键词") + question_kwd: Optional[List[str]] = Field([], description="问题关键词") + tag_feas: Optional[Any] = Field(None, description="标签特征") + + +class RetrievalTestRequest(BaseModel): + """检索测试请求模型""" + kb_id: List[str] = Field(..., description="知识库ID列表") + question: str = Field(..., description="问题") + page: Optional[int] = Field(1, description="页码") + size: Optional[int] = Field(30, description="每页大小") + doc_ids: Optional[List[str]] = Field([], description="文档ID列表") + use_kg: Optional[bool] = Field(False, description="是否使用知识图谱") + top_k: Optional[int] = Field(1024, description="返回数量") + cross_languages: Optional[List[str]] = Field([], description="跨语言列表") + search_id: Optional[str] = Field("", description="搜索ID") + rerank_id: Optional[str] = Field(None, description="重排序ID") + keyword: Optional[bool] = Field(False, description="是否使用关键词") + similarity_threshold: Optional[float] = Field(0.0, description="相似度阈值") + vector_similarity_weight: Optional[float] = Field(0.3, description="向量相似度权重") + highlight: Optional[bool] = Field(None, description="是否高亮") + + +class KnowledgeGraphRequest(BaseModel): + """知识图谱请求模型""" + doc_id: str = Field(..., description="文档ID")