v0.21.1-fastapi
This commit is contained in:
@@ -16,10 +16,19 @@
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, List
|
||||
|
||||
import xxhash
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.chunk_models import (
|
||||
ListChunksRequest,
|
||||
SetChunkRequest,
|
||||
SwitchChunksRequest,
|
||||
DeleteChunksRequest,
|
||||
CreateChunkRequest,
|
||||
RetrievalTestRequest,
|
||||
)
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType
|
||||
@@ -29,17 +38,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.models.chunk_models import (
|
||||
ListChunkRequest,
|
||||
GetChunkRequest,
|
||||
SetChunkRequest,
|
||||
SwitchChunkRequest,
|
||||
RemoveChunkRequest,
|
||||
CreateChunkRequest,
|
||||
RetrievalTestRequest,
|
||||
KnowledgeGraphRequest
|
||||
)
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@@ -47,19 +46,20 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/list')
|
||||
async def list_chunk(
|
||||
request: ListChunkRequest,
|
||||
request: ListChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出文档块"""
|
||||
doc_id = request.doc_id
|
||||
page = request.page
|
||||
size = request.size
|
||||
question = request.keywords
|
||||
page = request.page or 1
|
||||
size = request.size or 30
|
||||
question = request.keywords or ""
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
@@ -73,7 +73,7 @@ async def list_chunk(
|
||||
}
|
||||
if request.available_int is not None:
|
||||
query["available_int"] = int(request.available_int)
|
||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@@ -105,6 +105,7 @@ async def get(
|
||||
chunk_id: str = Query(..., description="块ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取文档块"""
|
||||
try:
|
||||
chunk = None
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
@@ -138,6 +139,7 @@ async def set(
|
||||
request: SetChunkRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置文档块"""
|
||||
d = {
|
||||
"id": request.chunk_id,
|
||||
"content_with_weight": request.content_with_weight}
|
||||
@@ -192,9 +194,10 @@ async def set(
|
||||
|
||||
@router.post('/switch')
|
||||
async def switch(
|
||||
request: SwitchChunkRequest,
|
||||
request: SwitchChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""切换文档块状态"""
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
if not e:
|
||||
@@ -212,9 +215,10 @@ async def switch(
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: RemoveChunkRequest,
|
||||
request: DeleteChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除文档块"""
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
@@ -240,15 +244,16 @@ async def create(
|
||||
request: CreateChunkRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""创建文档块"""
|
||||
chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
|
||||
"content_with_weight": request.content_with_weight}
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = request.important_kwd
|
||||
d["important_kwd"] = request.important_kwd or []
|
||||
if not isinstance(d["important_kwd"], list):
|
||||
return get_data_error_result(message="`important_kwd` is required to be a list")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
d["question_kwd"] = request.question_kwd
|
||||
d["question_kwd"] = request.question_kwd or []
|
||||
if not isinstance(d["question_kwd"], list):
|
||||
return get_data_error_result(message="`question_kwd` is required to be a list")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
@@ -296,8 +301,9 @@ async def retrieval_test(
|
||||
request: RetrievalTestRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
page = request.page
|
||||
size = request.size
|
||||
"""检索测试"""
|
||||
page = request.page or 1
|
||||
size = request.size or 30
|
||||
question = request.question
|
||||
kb_ids = request.kb_id
|
||||
if isinstance(kb_ids, str):
|
||||
@@ -306,10 +312,10 @@ async def retrieval_test(
|
||||
return get_json_result(data=False, message='Please specify dataset firstly.',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
|
||||
doc_ids = request.doc_ids
|
||||
use_kg = request.use_kg
|
||||
top = request.top_k
|
||||
langs = request.cross_languages
|
||||
doc_ids = request.doc_ids or []
|
||||
use_kg = request.use_kg or False
|
||||
top = request.top_k or 1024
|
||||
langs = request.cross_languages or []
|
||||
tenant_ids = []
|
||||
|
||||
if request.search_id:
|
||||
@@ -358,15 +364,16 @@ async def retrieval_test(
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(request.similarity_threshold),
|
||||
float(request.vector_similarity_weight),
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(request.similarity_threshold or 0.0),
|
||||
float(request.vector_similarity_weight or 0.3),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=request.highlight or False,
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@@ -391,13 +398,14 @@ async def knowledge_graph(
|
||||
doc_id: str = Query(..., description="文档ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取知识图谱"""
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
req = {
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
||||
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
|
||||
Reference in New Issue
Block a user