v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -3,7 +3,7 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
@@ -15,27 +15,29 @@
#
import json
import logging
from typing import Optional, List
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, Query
from api.models.kb_models import (
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.kb_models import (
CreateKnowledgeBaseRequest,
UpdateKnowledgeBaseRequest,
DeleteKnowledgeBaseRequest,
ListKnowledgeBasesRequest,
ListKnowledgeBasesQuery,
ListKnowledgeBasesBody,
RemoveTagsRequest,
RenameTagRequest,
RunGraphRAGRequest,
ListPipelineLogsQuery,
ListPipelineLogsBody,
ListPipelineDatasetLogsQuery,
ListPipelineDatasetLogsBody,
DeletePipelineLogsQuery,
DeletePipelineLogsBody,
RunGraphragRequest,
RunRaptorRequest,
RunMindmapRequest,
ListPipelineLogsRequest,
ListPipelineDatasetLogsRequest,
DeletePipelineLogsRequest,
UnbindTaskRequest
)
from api.utils.api_utils import get_current_user
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
@@ -44,7 +46,12 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, get_json_result
from api.utils.api_utils import (
get_error_data_result,
server_error_response,
get_data_error_result,
get_json_result,
)
from api.utils import get_uuid
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -53,9 +60,10 @@ from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
# 创建 FastAPI 路由器
# 创建路由器
router = APIRouter()
@@ -64,7 +72,14 @@ async def create(
request: CreateKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
dataset_name = request.name
"""创建知识库
支持两种解析类型:
- parse_type=1: 使用内置解析器,需要 parser_id
- parse_type=2: 使用自定义 pipeline需要 pipeline_id
"""
req = request.model_dump(exclude_unset=True)
dataset_name = req["name"]
if not isinstance(dataset_name, str):
return get_data_error_result(message="Dataset name must be string.")
if dataset_name.strip() == "":
@@ -80,55 +95,66 @@ async def create(
tenant_id=current_user.id,
status=StatusEnum.VALID.value)
try:
req = {
"id": get_uuid(),
"name": dataset_name,
"tenant_id": current_user.id,
"created_by": current_user.id,
"parser_id": request.parser_id or "naive",
"description": request.description
}
# 根据 parse_type 处理 parser_id 和 pipeline_id
parse_type = req.pop("parse_type", 1) # 移除 parse_type不需要存储到数据库
if parse_type == 1:
# 使用内置解析器,需要 parser_id
# 验证器已经确保 parser_id 不为空,但保留默认值逻辑以防万一
if not req.get("parser_id") or req["parser_id"].strip() == "":
req["parser_id"] = "naive"
# 清空 pipeline_id设置为 None数据库字段允许为 null
req["pipeline_id"] = None
elif parse_type == 2:
# 使用自定义 pipeline需要 pipeline_id
# 验证器已经确保 pipeline_id 不为空
# parser_id 应该为空字符串,但数据库字段不允许 null所以不设置 parser_id
# 让数据库使用默认值 "naive"(虽然用户传入的是空字符串,但数据库会处理)
# 如果用户明确传入了空字符串,我们也不设置它,让数据库使用默认值
if "parser_id" in req and (not req["parser_id"] or req["parser_id"].strip() == ""):
# 移除空字符串的 parser_id让数据库使用默认值
req.pop("parser_id")
# pipeline_id 保留在 req 中,会被保存到数据库
req["id"] = get_uuid()
req["name"] = dataset_name
req["tenant_id"] = current_user.id
req["created_by"] = current_user.id
# embd_id 已经在模型中定义为必需字段,直接使用
e, t = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Tenant not found.")
# 设置 embd_id 默认值
if not request.embd_id:
req["embd_id"] = t.embd_id
else:
req["embd_id"] = request.embd_id
if request.parser_config:
req["parser_config"] = request.parser_config
else:
req["parser_config"] = {
"layout_recognize": "DeepDOC",
"chunk_token_num": 512,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"topn_tags": 3,
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0
},
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category"
],
"method": "light"
}
req["parser_config"] = {
"layout_recognize": "DeepDOC",
"chunk_token_num": 512,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"topn_tags": 3,
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0
},
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category"
],
"method": "light"
}
}
if not KnowledgebaseService.save(**req):
return get_data_error_result()
return get_json_result(data={"kb_id": req["id"]})
@@ -141,16 +167,24 @@ async def update(
request: UpdateKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
if not isinstance(request.name, str):
"""更新知识库"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if request.name.strip() == "":
if req["name"].strip() == "":
return get_data_error_result(message="Dataset name can't be empty.")
if len(request.name.encode("utf-8")) > DATASET_NAME_LIMIT:
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(
message=f"Dataset name length is {len(request.name)} which is large than {DATASET_NAME_LIMIT}")
name = request.name.strip()
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
# 验证不允许的参数
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
for key in not_allowed:
if key in req:
del req[key]
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
message='No authorization.',
@@ -158,48 +192,29 @@ async def update(
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=request.kb_id):
created_by=current_user.id, id=req["kb_id"]):
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(request.kb_id)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
if name.lower() != kb.name.lower() \
if req["name"].lower() != kb.name.lower() \
and len(
KnowledgebaseService.query(name=name, tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result(
message="Duplicated knowledgebase name.")
# 构建更新数据,包含所有可更新的字段
update_data = {
"name": name,
"pagerank": request.pagerank
}
# 添加可选字段(如果提供了的话)
if request.description is not None:
update_data["description"] = request.description
if request.permission is not None:
update_data["permission"] = request.permission
if request.avatar is not None:
update_data["avatar"] = request.avatar
if request.parser_id is not None:
update_data["parser_id"] = request.parser_id
if request.embd_id is not None:
update_data["embd_id"] = request.embd_id
if request.parser_config is not None:
update_data["parser_config"] = request.parser_config
if not KnowledgebaseService.update_by_id(kb.id, update_data):
kb_id = req.pop("kb_id")
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result()
if kb.pagerank != request.pagerank:
if request.pagerank > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: request.pagerank},
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
@@ -211,26 +226,7 @@ async def update(
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
# 使用完整的请求数据更新返回结果,保持与原来代码的一致性
request_data = {
"name": name,
"pagerank": request.pagerank
}
if request.description is not None:
request_data["description"] = request.description
if request.permission is not None:
request_data["permission"] = request.permission
if request.avatar is not None:
request_data["avatar"] = request.avatar
if request.parser_id is not None:
request_data["parser_id"] = request.parser_id
if request.embd_id is not None:
request_data["embd_id"] = request.embd_id
if request.parser_config is not None:
request_data["parser_config"] = request.parser_config
kb.update(request_data)
kb.update(req)
return get_json_result(data=kb)
except Exception as e:
@@ -242,6 +238,7 @@ async def detail(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""获取知识库详情"""
try:
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
@@ -257,6 +254,9 @@ async def detail(
return get_data_error_result(
message="Can't find this knowledgebase!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
if finish_at := kb.get(key):
kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S")
return get_json_result(data=kb)
except Exception as e:
return server_error_response(e)
@@ -264,18 +264,22 @@ async def detail(
@router.post('/list')
async def list_kbs(
request: ListKnowledgeBasesRequest,
keywords: str = Query("", description="关键词"),
page: int = Query(0, description="页码"),
page_size: int = Query(0, description="每页大小"),
parser_id: Optional[str] = Query(None, description="解析器ID"),
orderby: str = Query("create_time", description="排序字段"),
desc: bool = Query(True, description="是否降序"),
query: ListKnowledgeBasesQuery = Depends(),
body: Optional[ListKnowledgeBasesBody] = None,
current_user = Depends(get_current_user)
):
page_number = page
items_per_page = page_size
owner_ids = request.owner_ids
"""列出知识库"""
if body is None:
body = ListKnowledgeBasesBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
parser_id = query.parser_id
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
owner_ids = body.owner_ids or [] if body else []
try:
if not owner_ids:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
@@ -296,11 +300,13 @@ async def list_kbs(
except Exception as e:
return server_error_response(e)
@router.post('/rm')
async def rm(
request: DeleteKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
"""删除知识库"""
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
return get_json_result(
data=False,
@@ -343,6 +349,7 @@ async def list_tags(
kb_id: str,
current_user = Depends(get_current_user)
):
"""列出知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -353,17 +360,18 @@ async def list_tags(
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id])
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)
@router.get('/tags')
async def list_tags_from_kbs(
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
current_user = Depends(get_current_user)
):
kb_ids = kb_ids.split(",")
for kb_id in kb_ids:
"""从多个知识库列出标签"""
kb_id_list = kb_ids.split(",")
for kb_id in kb_id_list:
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -374,7 +382,7 @@ async def list_tags_from_kbs(
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids)
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_id_list)
return get_json_result(data=tags)
@@ -384,6 +392,7 @@ async def rm_tags(
request: RemoveTagsRequest,
current_user = Depends(get_current_user)
):
"""删除知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -406,6 +415,7 @@ async def rename_tags(
request: RenameTagRequest,
current_user = Depends(get_current_user)
):
"""重命名知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -426,6 +436,7 @@ async def knowledge_graph(
kb_id: str,
current_user = Depends(get_current_user)
):
"""获取知识图谱"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -441,7 +452,7 @@ async def knowledge_graph(
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@@ -468,6 +479,7 @@ async def delete_knowledge_graph(
kb_id: str,
current_user = Depends(get_current_user)
):
"""删除知识图谱"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -482,18 +494,19 @@ async def delete_knowledge_graph(
@router.get("/get_meta")
async def get_meta(
    kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
    current_user = Depends(get_current_user)
):
    """Return document metadata for a set of knowledge bases.

    Args:
        kb_ids: Comma-separated knowledge base IDs (query parameter).
        current_user: User object supplied by the ``get_current_user``
            dependency; only its ``id`` is read here.

    Returns:
        ``get_json_result`` wrapping ``DocumentService.get_meta_by_kbs`` for
        the requested IDs, or an ``AUTHENTICATION_ERROR`` result as soon as
        one of the IDs is not accessible to the user.
    """
    # Keep the raw query string and the parsed list under separate names so
    # the parameter is never rebound to a different type.
    kb_id_list = kb_ids.split(",")
    for kb_id in kb_id_list:
        # Fail fast: a single inaccessible KB rejects the whole request.
        if not KnowledgebaseService.accessible(kb_id, current_user.id):
            return get_json_result(
                data=False,
                message='No authorization.',
                code=settings.RetCode.AUTHENTICATION_ERROR
            )
    return get_json_result(data=DocumentService.get_meta_by_kbs(kb_id_list))
@router.get("/basic_info")
@@ -501,6 +514,7 @@ async def get_basic_info(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""获取知识库基本信息"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -515,42 +529,45 @@ async def get_basic_info(
@router.post("/list_pipeline_logs")
async def list_pipeline_logs(
    query: ListPipelineLogsQuery = Depends(),
    body: Optional[ListPipelineLogsBody] = None,
    current_user = Depends(get_current_user)
):
    """List file-level pipeline operation logs of one knowledge base.

    Args:
        query: Query-string parameters. This handler reads ``kb_id``,
            ``keywords``, ``page``, ``page_size``, ``orderby``, ``desc``,
            ``create_date_from`` and ``create_date_to`` from it.
        body: Optional JSON body carrying the ``operation_status``, ``types``
            and ``suffix`` filters. When omitted, a default
            ``ListPipelineLogsBody()`` is used.
        current_user: User object supplied by the ``get_current_user``
            dependency.

    Returns:
        ``get_json_result`` with ``{"total": ..., "logs": ...}`` on success,
        an argument/data error result when validation fails, or
        ``server_error_response`` if the service call raises.
    """
    if not query.kb_id:
        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)

    if body is None:
        body = ListPipelineLogsBody()

    # Normalise optional query values to the defaults the service expects.
    keywords = query.keywords or ""
    page_number = int(query.page or 0)
    items_per_page = int(query.page_size or 0)
    orderby = query.orderby or "create_time"
    # ``desc`` is compared as text here; a missing value means descending.
    desc = query.desc.lower() == "true" if query.desc else True
    create_date_from = query.create_date_from or ""
    create_date_to = query.create_date_to or ""
    # NOTE(review): this rejects requests whose "to" value sorts after the
    # "from" value (plain string comparison). Kept exactly as found; confirm
    # against PipelineOperationLogService that this direction is intended.
    if create_date_to > create_date_from:
        return get_data_error_result(message="Create data filter is abnormal.")

    operation_status = body.operation_status or []
    if operation_status:
        invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
        if invalid_status:
            return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")

    types = body.types or []
    if types:
        invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
        if invalid_types:
            return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")

    suffix = body.suffix or []

    try:
        logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(
            query.kb_id, page_number, items_per_page, orderby, desc, keywords,
            operation_status, types, suffix, create_date_from, create_date_to)
        return get_json_result(data={"total": tol, "logs": logs})
    except Exception as e:
        return server_error_response(e)
@@ -558,33 +575,36 @@ async def list_pipeline_logs(
@router.post("/list_pipeline_dataset_logs")
async def list_pipeline_dataset_logs(
    query: ListPipelineDatasetLogsQuery = Depends(),
    body: Optional[ListPipelineDatasetLogsBody] = None,
    current_user = Depends(get_current_user)
):
    """List dataset-level pipeline operation logs of one knowledge base.

    Args:
        query: Query-string parameters. This handler reads ``kb_id``,
            ``page``, ``page_size``, ``orderby``, ``desc``,
            ``create_date_from`` and ``create_date_to`` from it.
        body: Optional JSON body carrying the ``operation_status`` filter.
            When omitted, a default ``ListPipelineDatasetLogsBody()`` is used.
        current_user: User object supplied by the ``get_current_user``
            dependency.

    Returns:
        ``get_json_result`` with ``{"total": ..., "logs": ...}`` on success,
        an argument/data error result when validation fails, or
        ``server_error_response`` if the service call raises.
    """
    if not query.kb_id:
        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)

    if body is None:
        body = ListPipelineDatasetLogsBody()

    # Normalise optional query values to the defaults the service expects.
    page_number = int(query.page or 0)
    items_per_page = int(query.page_size or 0)
    orderby = query.orderby or "create_time"
    # ``desc`` is compared as text here; a missing value means descending.
    desc = query.desc.lower() == "true" if query.desc else True
    create_date_from = query.create_date_from or ""
    create_date_to = query.create_date_to or ""
    # NOTE(review): this rejects requests whose "to" value sorts after the
    # "from" value (plain string comparison). Kept exactly as found; confirm
    # against PipelineOperationLogService that this direction is intended.
    if create_date_to > create_date_from:
        return get_data_error_result(message="Create data filter is abnormal.")

    operation_status = body.operation_status or []
    if operation_status:
        invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
        if invalid_status:
            return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")

    try:
        logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(
            query.kb_id, page_number, items_per_page, orderby, desc,
            operation_status, create_date_from, create_date_to)
        return get_json_result(data={"total": tol, "logs": logs})
    except Exception as e:
        return server_error_response(e)
@@ -592,14 +612,18 @@ async def list_pipeline_dataset_logs(
@router.post("/delete_pipeline_logs")
async def delete_pipeline_logs(
request: DeletePipelineLogsRequest,
kb_id: str = Query(..., description="知识库ID"),
query: DeletePipelineLogsQuery = Depends(),
body: Optional[DeletePipelineLogsBody] = None,
current_user = Depends(get_current_user)
):
if not kb_id:
"""删除流水线日志"""
if not query.kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
log_ids = request.log_ids
if body is None:
body = DeletePipelineLogsBody(log_ids=[])
log_ids = body.log_ids or []
PipelineOperationLogService.delete_by_ids(log_ids)
@@ -608,9 +632,10 @@ async def delete_pipeline_logs(
@router.get("/pipeline_log_detail")
async def pipeline_log_detail(
log_id: str = Query(..., description="日志ID"),
log_id: str = Query(..., description="流水线日志ID"),
current_user = Depends(get_current_user)
):
"""获取流水线日志详情"""
if not log_id:
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
@@ -623,9 +648,10 @@ async def pipeline_log_detail(
@router.post("/run_graphrag")
async def run_graphrag(
request: RunGraphRAGRequest,
request: RunGraphragRequest,
current_user = Depends(get_current_user)
):
"""运行 GraphRAG"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -660,7 +686,7 @@ async def run_graphrag(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
@@ -673,6 +699,7 @@ async def trace_graphrag(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 GraphRAG 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -696,6 +723,7 @@ async def run_raptor(
request: RunRaptorRequest,
current_user = Depends(get_current_user)
):
"""运行 RAPTOR"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -730,7 +758,7 @@ async def run_raptor(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
@@ -743,6 +771,7 @@ async def trace_raptor(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 RAPTOR 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -766,6 +795,7 @@ async def run_mindmap(
request: RunMindmapRequest,
current_user = Depends(get_current_user)
):
"""运行 Mindmap"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -800,7 +830,7 @@ async def run_mindmap(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
@@ -813,6 +843,7 @@ async def trace_mindmap(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 Mindmap 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -834,33 +865,43 @@ async def trace_mindmap(
@router.delete("/unbind_task")
async def delete_kb_task(
kb_id: str = Query(..., description="知识库ID"),
pipeline_task_type: str = Query(..., description="管道任务类型"),
pipeline_task_type: str = Query(..., description="流水线任务类型"),
current_user = Depends(get_current_user)
):
"""解绑任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_json_result(data=True)
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
return get_error_data_result(message="Invalid task type")
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
kb_task_id = "graphrag_task_id"
kb_task_id_field = "graphrag_task_id"
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
case PipelineTaskType.RAPTOR:
kb_task_id = "raptor_task_id"
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
case PipelineTaskType.MINDMAP:
kb_task_id = "mindmap_task_id"
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
kb_task_finish_at = "mindmap_task_finish_at"
case _:
return get_error_data_result(message="Internal Error: Invalid task type")
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
cancel_task(task_id)
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None})
if not ok:
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
return get_json_result(data=True)