v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -16,10 +16,19 @@
import datetime
import json
import re
from typing import Optional, List
import xxhash
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.chunk_models import (
ListChunksRequest,
SetChunkRequest,
SwitchChunksRequest,
DeleteChunksRequest,
CreateChunkRequest,
RetrievalTestRequest,
)
from api import settings
from api.db import LLMType, ParserType
@@ -29,17 +38,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.models.chunk_models import (
ListChunkRequest,
GetChunkRequest,
SetChunkRequest,
SwitchChunkRequest,
RemoveChunkRequest,
CreateChunkRequest,
RetrievalTestRequest,
KnowledgeGraphRequest
)
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -47,19 +46,20 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
# Create the FastAPI router
# Create the router
router = APIRouter()
@router.post('/list')
async def list_chunk(
request: ListChunkRequest,
request: ListChunksRequest,
current_user = Depends(get_current_user)
):
"""列出文档块"""
doc_id = request.doc_id
page = request.page
size = request.size
question = request.keywords
page = request.page or 1
size = request.size or 30
question = request.keywords or ""
try:
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
@@ -73,7 +73,7 @@ async def list_chunk(
}
if request.available_int is not None:
query["available_int"] = int(request.available_int)
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@@ -105,6 +105,7 @@ async def get(
chunk_id: str = Query(..., description="块ID"),
current_user = Depends(get_current_user)
):
"""获取文档块"""
try:
chunk = None
tenants = UserTenantService.query(user_id=current_user.id)
@@ -138,6 +139,7 @@ async def set(
request: SetChunkRequest,
current_user = Depends(get_current_user)
):
"""设置文档块"""
d = {
"id": request.chunk_id,
"content_with_weight": request.content_with_weight}
@@ -192,9 +194,10 @@ async def set(
@router.post('/switch')
async def switch(
request: SwitchChunkRequest,
request: SwitchChunksRequest,
current_user = Depends(get_current_user)
):
"""切换文档块状态"""
try:
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
@@ -212,9 +215,10 @@ async def switch(
@router.post('/rm')
async def rm(
request: RemoveChunkRequest,
request: DeleteChunksRequest,
current_user = Depends(get_current_user)
):
"""删除文档块"""
from rag.utils.storage_factory import STORAGE_IMPL
try:
e, doc = DocumentService.get_by_id(request.doc_id)
@@ -240,15 +244,16 @@ async def create(
request: CreateChunkRequest,
current_user = Depends(get_current_user)
):
"""创建文档块"""
chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
"content_with_weight": request.content_with_weight}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = request.important_kwd
d["important_kwd"] = request.important_kwd or []
if not isinstance(d["important_kwd"], list):
return get_data_error_result(message="`important_kwd` is required to be a list")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
d["question_kwd"] = request.question_kwd
d["question_kwd"] = request.question_kwd or []
if not isinstance(d["question_kwd"], list):
return get_data_error_result(message="`question_kwd` is required to be a list")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
@@ -296,8 +301,9 @@ async def retrieval_test(
request: RetrievalTestRequest,
current_user = Depends(get_current_user)
):
page = request.page
size = request.size
"""检索测试"""
page = request.page or 1
size = request.size or 30
question = request.question
kb_ids = request.kb_id
if isinstance(kb_ids, str):
@@ -306,10 +312,10 @@ async def retrieval_test(
return get_json_result(data=False, message='Please specify dataset firstly.',
code=settings.RetCode.DATA_ERROR)
doc_ids = request.doc_ids
use_kg = request.use_kg
top = request.top_k
langs = request.cross_languages
doc_ids = request.doc_ids or []
use_kg = request.use_kg or False
top = request.top_k or 1024
langs = request.cross_languages or []
tenant_ids = []
if request.search_id:
@@ -358,15 +364,16 @@ async def retrieval_test(
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(request.similarity_threshold),
float(request.vector_similarity_weight),
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(request.similarity_threshold or 0.0),
float(request.vector_similarity_weight or 0.3),
top,
doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
doc_ids, rerank_mdl=rerank_mdl,
highlight=request.highlight or False,
rank_feature=labels
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
ck = settings.kg_retriever.retrieval(question,
tenant_ids,
kb_ids,
embd_mdl,
@@ -391,13 +398,14 @@ async def knowledge_graph(
doc_id: str = Query(..., description="文档ID"),
current_user = Depends(get_current_user)
):
"""获取知识图谱"""
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]