v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -52,6 +52,40 @@ def create_app() -> FastAPI:
openapi_url="/apispec.json"
)
# 自定义 OpenAPI schema 以支持 Bearer Token 认证
def custom_openapi():
    """Build and cache the app's OpenAPI schema with a Bearer-token scheme.

    Installed as ``app.openapi`` so Swagger UI exposes HTTPBearer auth.
    Returns the cached schema on every call after the first.
    """
    # Fast path: the schema was generated on an earlier call.
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Ensure the components / securitySchemes containers exist.
    components = schema.setdefault("components", {})
    schemes = components.setdefault("securitySchemes", {})
    # Drop the legacy "CustomHTTPBearer" entry (if present); keep only
    # "HTTPBearer", the name FastAPI registers by default.
    schemes.pop("CustomHTTPBearer", None)
    schemes["HTTPBearer"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
        "description": "输入 token可以不带 'Bearer ' 前缀,系统会自动添加"
    }
    app.openapi_schema = schema
    return app.openapi_schema
app.openapi = custom_openapi
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
@@ -124,22 +158,20 @@ def setup_routes(app: FastAPI):
from api.apps.user_app_fastapi import router as user_router
from api.apps.kb_app import router as kb_router
from api.apps.document_app import router as document_router
from api.apps.file_app import router as file_router
from api.apps.file2document_app import router as file2document_router
from api.apps.mcp_server_app import router as mcp_router
from api.apps.tenant_app import router as tenant_router
from api.apps.llm_app import router as llm_router
from api.apps.chunk_app import router as chunk_router
from api.apps.mcp_server_app import router as mcp_router
from api.apps.canvas_app import router as canvas_router
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
def get_current_user_from_token(authorization: str):
"""从token获取当前用户"""

View File

@@ -490,7 +490,7 @@ def upload():
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
@validate_request("conversation_id")
async def upload_parse():
def upload_parse():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
@@ -507,7 +507,7 @@ async def upload_parse():
return get_json_result(
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = await doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@@ -536,7 +536,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@@ -884,7 +884,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))

View File

@@ -18,11 +18,12 @@ import logging
import re
import sys
from functools import partial
from typing import Optional
import flask
import trio
from flask import request, Response
from flask_login import login_required, current_user
from fastapi import APIRouter, Depends, Query, Header, UploadFile, File, Form
from fastapi.responses import StreamingResponse, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from agent.component import LLM
from api import settings
@@ -36,7 +37,24 @@ from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from api.utils.api_utils import get_json_result, server_error_response, get_data_error_result
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.canvas_models import (
DeleteCanvasRequest,
SaveCanvasRequest,
CompletionRequest,
RerunRequest,
ResetCanvasRequest,
InputFormQuery,
DebugRequest,
TestDBConnectRequest,
ListCanvasQuery,
SettingRequest,
TraceQuery,
ListSessionsQuery,
DownloadQuery,
UploadQuery,
)
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@@ -47,18 +65,25 @@ from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
@manager.route('/templates', methods=['GET']) # noqa: F821
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
# 创建路由器
router = APIRouter()
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
@router.get('/templates')
async def templates(
    current_user = Depends(get_current_user)
):
    """Return every available canvas template as a list of dicts."""
    all_templates = CanvasTemplateService.get_all()
    return get_json_result(data=[tpl.to_dict() for tpl in all_templates])
@router.post('/rm')
async def rm(
request: DeleteCanvasRequest,
current_user = Depends(get_current_user)
):
"""删除画布"""
for i in request.canvas_ids:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@@ -67,16 +92,18 @@ def rm():
return get_json_result(data=True)
@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
@router.post('/set')
async def save(
request: SaveCanvasRequest,
current_user = Depends(get_current_user)
):
"""保存画布"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
cate = req.get("canvas_category", CanvasCategory.Agent)
if "id" not in req:
if "id" not in req or not req.get("id"):
req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
@@ -95,21 +122,28 @@ def save():
return get_json_result(data=req)
@manager.route('/get/<canvas_id>', methods=['GET']) # noqa: F821
@login_required
def get(canvas_id):
@router.get('/get/{canvas_id}')
async def get(
    canvas_id: str,
    current_user = Depends(get_current_user)
):
    """Fetch one canvas by id; the caller must be allowed to access it."""
    # Reject callers that cannot access this canvas.
    if not UserCanvasService.accessible(canvas_id, current_user.id):
        return get_data_error_result(message="canvas not found.")
    _exists, canvas = UserCanvasService.get_by_canvas_id(canvas_id)
    return get_json_result(data=canvas)
@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
def getsse(canvas_id):
token = request.headers.get('Authorization').split()
if len(token) != 2:
@router.get('/getsse/{canvas_id}')
async def getsse(
canvas_id: str,
authorization: str = Header(..., description="Authorization header")
):
"""获取画布详情SSE通过API token认证"""
token_parts = authorization.split()
if len(token_parts) != 2:
return get_data_error_result(message='Authorization is not valid!"')
token = token[1]
token = token_parts[1]
objs = APIToken.query(beta=token)
if not objs:
return get_data_error_result(message='Authentication error: API key is invalid!"')
@@ -126,21 +160,23 @@ def getsse(canvas_id):
return get_json_result(data=c.to_dict())
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(req["id"], current_user.id):
@router.post('/completion')
async def run(
request: CompletionRequest,
current_user = Depends(get_current_user)
):
"""运行画布(完成/执行)"""
query = request.query or ""
files = request.files or []
inputs = request.inputs or {}
user_id = request.user_id or current_user.id
if not UserCanvasService.accessible(request.id, current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"])
e, cvs = UserCanvasService.get_by_id(request.id)
if not e:
return get_data_error_result(message="canvas not found.")
@@ -149,43 +185,48 @@ def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=request.id)
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=request.id, task_id=task_id, file=files[0] if files else None, priority=0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
canvas = Canvas(cvs.dsl, current_user.id, request.id)
except Exception as e:
return server_error_response(e)
def sse():
async def sse():
nonlocal canvas, user_id
try:
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
UserCanvasService.update_by_id(request.id, cvs.to_dict())
except Exception as e:
logging.exception(e)
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
return StreamingResponse(
sse(),
media_type="text/event-stream",
headers={
"Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
def rerun():
req = request.json
doc = PipelineOperationLogService.get_documents_info(req["id"])
@router.post('/rerun')
async def rerun(
request: RerunRequest,
current_user = Depends(get_current_user)
):
"""重新运行流水线"""
doc = PipelineOperationLogService.get_documents_info(request.id)
if not doc:
return get_data_error_result(message="Document not found.")
doc = doc[0]
@@ -198,19 +239,22 @@ def rerun():
doc["chunk_num"] = 0
doc["token_num"] = 0
DocumentService.clear_chunk_num_when_rerun(doc["id"])
DocumentService.update_by_id(id, doc)
TaskService.filter_delete([Task.doc_id == id])
DocumentService.update_by_id(request.id, doc)
TaskService.filter_delete([Task.doc_id == request.id])
dsl = req["dsl"]
dsl["path"] = [req["component_id"]]
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
dsl = request.dsl
dsl["path"] = [request.component_id]
PipelineOperationLogService.update_by_id(request.id, {"dsl": dsl})
queue_dataflow(tenant_id=current_user.id, flow_id=request.id, task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
return get_json_result(data=True)
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
@login_required
def cancel(task_id):
@router.put('/cancel/{task_id}')
async def cancel(
task_id: str,
current_user = Depends(get_current_user)
):
"""取消任务"""
try:
REDIS_CONN.set(f"{task_id}-cancel", "x")
except Exception as e:
@@ -218,36 +262,43 @@ def cancel(task_id):
return get_json_result(data=True)
@manager.route('/reset', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
if not UserCanvasService.accessible(req["id"], current_user.id):
@router.post('/reset')
async def reset(
    request: ResetCanvasRequest,
    current_user = Depends(get_current_user)
):
    """Reset a canvas to its initial state and persist the cleared DSL.

    Returns the reset DSL on success; an error payload when the caller is
    not the owner or the canvas does not exist.
    """
    if not UserCanvasService.accessible(request.id, current_user.id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)
    try:
        e, user_canvas = UserCanvasService.get_by_id(request.id)
        if not e:
            return get_data_error_result(message="canvas not found.")
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        # Persist the freshly reset DSL and echo it back to the caller.
        dsl = json.loads(str(canvas))
        UserCanvasService.update_by_id(request.id, {"dsl": dsl})
        return get_json_result(data=dsl)
    except Exception as e:
        return server_error_response(e)
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
@router.post("/upload/{canvas_id}")
async def upload(
canvas_id: str,
url: Optional[str] = Query(None, description="URL可选用于从URL下载"),
file: Optional[UploadFile] = File(None, description="上传的文件")
):
"""上传文件到画布"""
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
@@ -267,7 +318,7 @@ def upload(canvas_id):
"preview_url": None
}
if request.args.get("url"):
if url:
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
@@ -277,7 +328,6 @@ def upload(canvas_id):
CrawlResult
)
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
@@ -301,61 +351,67 @@ def upload(canvas_id):
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers.get("content-type", "application/pdf")))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers.get("content-type", "text/html"), user_id))
except Exception as e:
return server_error_response(e)
return server_error_response(e)
file = request.files['file']
if not file:
return get_data_error_result(message="No file provided.")
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
blob = await file.read()
return get_json_result(data=structured(file.filename, filename_type(file.filename), blob, file.content_type or "application/octet-stream"))
except Exception as e:
return server_error_response(e)
return server_error_response(e)
@manager.route('/input_form', methods=['GET']) # noqa: F821
@login_required
def input_form():
cvs_id = request.args.get("id")
cpn_id = request.args.get("component_id")
@router.get('/input_form')
async def input_form(
    id: str = Query(..., description="画布ID"),
    component_id: str = Query(..., description="组件ID"),
    current_user = Depends(get_current_user)
):
    """Return the input-form definition of one component of a canvas.

    Only the canvas owner may query it; otherwise an OPERATING_ERROR
    payload is returned.
    """
    try:
        e, user_canvas = UserCanvasService.get_by_id(id)
        if not e:
            return get_data_error_result(message="canvas not found.")
        # Ownership check: the canvas must belong to the current user.
        if not UserCanvasService.query(user_id=current_user.id, id=id):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        return get_json_result(data=canvas.get_component_input_form(component_id))
    except Exception as e:
        return server_error_response(e)
@manager.route('/debug', methods=['POST']) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
if not UserCanvasService.accessible(req["id"], current_user.id):
@router.post('/debug')
async def debug(
request: DebugRequest,
current_user = Depends(get_current_user)
):
"""调试组件"""
if not UserCanvasService.accessible(request.id, current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
e, user_canvas = UserCanvasService.get_by_id(request.id)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
component = canvas.get_component(request.component_id)["obj"]
component.reset()
if isinstance(component, LLM):
component.set_debug_inputs(req["params"])
component.invoke(**{k: o["value"] for k,o in req["params"].items()})
component.set_debug_inputs(request.params)
component.invoke(**{k: o["value"] for k,o in request.params.items()})
outputs = component.output()
for k in outputs.keys():
if isinstance(outputs[k], partial):
@@ -368,11 +424,13 @@ def debug():
return server_error_response(e)
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
req = request.json
@router.post('/test_db_connect')
async def test_db_connect(
request: TestDBConnectRequest,
current_user = Depends(get_current_user)
):
"""测试数据库连接"""
req = request.model_dump()
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@@ -409,6 +467,49 @@ def test_db_connect():
ibm_db.fetch_assoc(stmt)
ibm_db.close(conn)
return get_json_result(data="Database Connection Successful!")
elif req["db_type"] == 'trino':
def _parse_catalog_schema(db: str):
if not db:
return None, None
if "." in db:
c, s = db.split(".", 1)
elif "/" in db:
c, s = db.split("/", 1)
else:
c, s = db, "default"
return c, s
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception:
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
catalog, schema = _parse_catalog_schema(req["database"])
if not catalog:
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
port=int(req["port"] or 8080),
user=req["username"] or "ragflow",
catalog=catalog,
schema=schema or "default",
http_scheme=http_scheme,
auth=auth
)
cur = conn.cursor()
cur.execute("SELECT 1")
cur.fetchall()
cur.close()
conn.close()
return get_json_result(data="Database Connection Successful!")
else:
return server_error_response("Unsupported database type.")
if req["db_type"] != 'mssql':
@@ -421,42 +522,48 @@ def test_db_connect():
#api get list version dsl of canvas
@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
@login_required
def getlistversion(canvas_id):
@router.get('/getlistversion/{canvas_id}')
async def getlistversion(
    canvas_id: str,
    current_user = Depends(get_current_user)
):
    """List all saved DSL versions of a canvas, newest first."""
    try:
        # Negated key sorts by update_time descending (newest first);
        # avoid shadowing the builtin name `list`.
        versions = sorted(
            [c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)],
            key=lambda x: x["update_time"] * -1,
        )
        return get_json_result(data=versions)
    except Exception as e:
        return get_data_error_result(message=f"Error getting history files: {e}")
#api get version dsl of canvas
@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
@login_required
def getversion( version_id):
@router.get('/getversion/{version_id}')
async def getversion(
    version_id: str,
    current_user = Depends(get_current_user)
):
    """Fetch a single canvas DSL version by its id."""
    try:
        e, version = UserCanvasVersionService.get_by_id(version_id)
        if version:
            return get_json_result(data=version.to_dict())
        return get_data_error_result(message="Version not found.")
    except Exception as e:
        # Return a proper error payload rather than a success body
        # carrying an error string.
        return get_data_error_result(message=f"Error getting history file: {e}")
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_canvas():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
canvas_category = request.args.get("canvas_category")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
@router.get('/list')
async def list_canvas(
query: ListCanvasQuery = Depends(),
current_user = Depends(get_current_user)
):
"""列出画布"""
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
canvas_category = query.canvas_category
desc = query.desc.lower() == "false" if query.desc else True
owner_ids = [id for id in (query.owner_ids or "").strip().split(",") if id]
if not owner_ids:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
tenants = [m["tenant_id"] for m in tenants]
@@ -472,68 +579,73 @@ def list_canvas():
return get_json_result(data={"canvas": canvas, "total": total})
@manager.route('/setting', methods=['POST']) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
@router.post('/setting')
async def setting(
    request: SettingRequest,
    current_user = Depends(get_current_user)
):
    """Update a canvas' title plus optional description/permission/avatar.

    Only the owner may update; returns the number of rows updated.
    """
    if not UserCanvasService.accessible(request.id, current_user.id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)
    e, flow = UserCanvasService.get_by_id(request.id)
    if not e:
        return get_data_error_result(message="canvas not found.")
    flow = flow.to_dict()
    flow["title"] = request.title
    # Optional fields: only overwrite when the request carries a truthy value.
    for key in ["description", "permission", "avatar"]:
        value = getattr(request, key, None)
        if value:
            flow[key] = value
    num = UserCanvasService.update_by_id(request.id, flow)
    return get_json_result(data=num)
@manager.route('/trace', methods=['GET']) # noqa: F821
def trace():
cvs_id = request.args.get("canvas_id")
msg_id = request.args.get("message_id")
@router.get('/trace')
async def trace(
    canvas_id: str = Query(..., description="画布ID"),
    message_id: str = Query(..., description="消息ID")
):
    """Return the execution-trace logs for one canvas message from Redis."""
    try:
        # Avoid shadowing the builtin `bin`.
        raw = REDIS_CONN.get(f"{canvas_id}-{message_id}-logs")
        if not raw:
            return get_json_result(data={})
        # NOTE(review): assumes REDIS_CONN.get returns a str here (it is
        # re-encoded before json.loads) — confirm against the Redis wrapper.
        return get_json_result(data=json.loads(raw.encode("utf-8")))
    except Exception as e:
        logging.exception(e)
        return server_error_response(e)
@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
@login_required
def sessions(canvas_id):
@router.get('/{canvas_id}/sessions')
async def sessions(
canvas_id: str,
query: ListSessionsQuery = Depends(),
current_user = Depends(get_current_user)
):
"""列出画布会话"""
tenant_id = current_user.id
if not UserCanvasService.accessible(canvas_id, tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
user_id = request.args.get("user_id")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
keywords = request.args.get("keywords")
from_date = request.args.get("from_date")
to_date = request.args.get("to_date")
orderby = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
user_id = query.user_id
page_number = int(query.page or 1)
items_per_page = int(query.page_size or 30)
keywords = query.keywords
from_date = query.from_date
to_date = query.to_date
orderby = query.orderby or "update_time"
desc = query.desc.lower() in ["false", "False"] if query.desc else True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
include_dsl = query.dsl.lower() not in ["false", "False"] if query.dsl else True
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
None, user_id, include_dsl, keywords, from_date, to_date)
try:
@@ -542,9 +654,11 @@ def sessions(canvas_id):
return server_error_response(e)
@manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required
def prompts():
@router.get('/prompts')
async def prompts(
current_user = Depends(get_current_user)
):
"""获取提示词模板"""
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
@@ -556,9 +670,11 @@ def prompts():
})
@manager.route('/download', methods=['GET']) # noqa: F821
def download():
id = request.args.get("id")
created_by = request.args.get("created_by")
@router.get('/download')
async def download(
    id: str = Query(..., description="文件ID"),
    created_by: str = Query(..., description="创建者ID")
):
    """Return the raw blob of a stored file owned by *created_by*."""
    blob = FileService.get_blob(created_by, id)
    # FastAPI Response replaces the old flask.make_response path.
    return Response(content=blob)

View File

@@ -16,10 +16,19 @@
import datetime
import json
import re
from typing import Optional, List
import xxhash
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.chunk_models import (
ListChunksRequest,
SetChunkRequest,
SwitchChunksRequest,
DeleteChunksRequest,
CreateChunkRequest,
RetrievalTestRequest,
)
from api import settings
from api.db import LLMType, ParserType
@@ -29,17 +38,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.models.chunk_models import (
ListChunkRequest,
GetChunkRequest,
SetChunkRequest,
SwitchChunkRequest,
RemoveChunkRequest,
CreateChunkRequest,
RetrievalTestRequest,
KnowledgeGraphRequest
)
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -47,19 +46,20 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
# 创建 FastAPI 路由器
# 创建路由器
router = APIRouter()
@router.post('/list')
async def list_chunk(
request: ListChunkRequest,
request: ListChunksRequest,
current_user = Depends(get_current_user)
):
"""列出文档块"""
doc_id = request.doc_id
page = request.page
size = request.size
question = request.keywords
page = request.page or 1
size = request.size or 30
question = request.keywords or ""
try:
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
@@ -73,7 +73,7 @@ async def list_chunk(
}
if request.available_int is not None:
query["available_int"] = int(request.available_int)
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@@ -105,6 +105,7 @@ async def get(
chunk_id: str = Query(..., description="块ID"),
current_user = Depends(get_current_user)
):
"""获取文档块"""
try:
chunk = None
tenants = UserTenantService.query(user_id=current_user.id)
@@ -138,6 +139,7 @@ async def set(
request: SetChunkRequest,
current_user = Depends(get_current_user)
):
"""设置文档块"""
d = {
"id": request.chunk_id,
"content_with_weight": request.content_with_weight}
@@ -192,9 +194,10 @@ async def set(
@router.post('/switch')
async def switch(
request: SwitchChunkRequest,
request: SwitchChunksRequest,
current_user = Depends(get_current_user)
):
"""切换文档块状态"""
try:
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
@@ -212,9 +215,10 @@ async def switch(
@router.post('/rm')
async def rm(
request: RemoveChunkRequest,
request: DeleteChunksRequest,
current_user = Depends(get_current_user)
):
"""删除文档块"""
from rag.utils.storage_factory import STORAGE_IMPL
try:
e, doc = DocumentService.get_by_id(request.doc_id)
@@ -240,15 +244,16 @@ async def create(
request: CreateChunkRequest,
current_user = Depends(get_current_user)
):
"""创建文档块"""
chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
"content_with_weight": request.content_with_weight}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = request.important_kwd
d["important_kwd"] = request.important_kwd or []
if not isinstance(d["important_kwd"], list):
return get_data_error_result(message="`important_kwd` is required to be a list")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
d["question_kwd"] = request.question_kwd
d["question_kwd"] = request.question_kwd or []
if not isinstance(d["question_kwd"], list):
return get_data_error_result(message="`question_kwd` is required to be a list")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
@@ -296,8 +301,9 @@ async def retrieval_test(
request: RetrievalTestRequest,
current_user = Depends(get_current_user)
):
page = request.page
size = request.size
"""检索测试"""
page = request.page or 1
size = request.size or 30
question = request.question
kb_ids = request.kb_id
if isinstance(kb_ids, str):
@@ -306,10 +312,10 @@ async def retrieval_test(
return get_json_result(data=False, message='Please specify dataset firstly.',
code=settings.RetCode.DATA_ERROR)
doc_ids = request.doc_ids
use_kg = request.use_kg
top = request.top_k
langs = request.cross_languages
doc_ids = request.doc_ids or []
use_kg = request.use_kg or False
top = request.top_k or 1024
langs = request.cross_languages or []
tenant_ids = []
if request.search_id:
@@ -358,15 +364,16 @@ async def retrieval_test(
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(request.similarity_threshold),
float(request.vector_similarity_weight),
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(request.similarity_threshold or 0.0),
float(request.vector_similarity_weight or 0.3),
top,
doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
doc_ids, rerank_mdl=rerank_mdl,
highlight=request.highlight or False,
rank_feature=labels
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
ck = settings.kg_retriever.retrieval(question,
tenant_ids,
kb_ids,
embd_mdl,
@@ -391,13 +398,14 @@ async def knowledge_graph(
doc_id: str = Query(..., description="文档ID"),
current_user = Depends(get_current_user)
):
"""获取知识图谱"""
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]

View File

@@ -17,15 +17,30 @@ import json
import os.path
import pathlib
import re
import traceback
from pathlib import Path
from typing import List, Optional
from typing import Optional, List
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
from fastapi.responses import Response
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.document_models import (
CreateDocumentRequest,
WebCrawlRequest,
ListDocumentsQuery,
ListDocumentsBody,
FilterDocumentsRequest,
GetDocumentInfosRequest,
ChangeStatusRequest,
DeleteDocumentRequest,
RunDocumentRequest,
RenameDocumentRequest,
ChangeParserRequest,
ChangeParserSimpleRequest,
UploadAndParseRequest,
ParseRequest,
SetMetaRequest,
)
from api import settings
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
@@ -43,159 +58,29 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
)
from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
from rag.nlp import search, rag_tokenizer
from rag.utils.storage_factory import STORAGE_IMPL
from pydantic import BaseModel
from api.db.db_models import User
# Security
# Pydantic models for request/response
class WebCrawlRequest(BaseModel):
kb_id: str
name: str
url: str
class CreateDocumentRequest(BaseModel):
name: str
kb_id: str
class DocumentListRequest(BaseModel):
run_status: List[str] = []
types: List[str] = []
suffix: List[str] = []
class DocumentFilterRequest(BaseModel):
kb_id: str
keywords: str = ""
run_status: List[str] = []
types: List[str] = []
suffix: List[str] = []
class DocumentInfosRequest(BaseModel):
doc_ids: List[str]
class ChangeStatusRequest(BaseModel):
doc_ids: List[str]
status: str
class RemoveDocumentRequest(BaseModel):
doc_id: List[str]
class RunDocumentRequest(BaseModel):
doc_ids: List[str]
run: int
delete: bool = False
class RenameDocumentRequest(BaseModel):
doc_id: str
name: str
class ChangeParserRequest(BaseModel):
doc_id: str
parser_id: str
pipeline_id: Optional[str] = None
parser_config: Optional[dict] = None
class UploadAndParseRequest(BaseModel):
conversation_id: str
class ParseRequest(BaseModel):
url: Optional[str] = None
class SetMetaRequest(BaseModel):
doc_id: str
meta: str
# Dependency injection
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""获取当前用户"""
from api.db import StatusEnum
from api.db.services.user_service import UserService
from fastapi import HTTPException, status
import logging
try:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
except ImportError:
# 如果没有itsdangerous使用jwt作为替代
import jwt
Serializer = jwt
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = credentials.credentials
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication attempt with empty access token"
)
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
)
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User {user[0].email} has empty access_token in database"
)
return user[0]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
except Exception as e:
logging.warning(f"load_user got exception {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required"
)
# Create router
# 创建路由器
router = APIRouter()
@router.post("/upload")
async def upload(
kb_id: str = Form(...),
file: List[UploadFile] = File(...),
files: List[UploadFile] = File(...),
current_user = Depends(get_current_user)
):
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if not file:
"""上传文档"""
if not files:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
# Use UploadFile directly - file is already a list from multiple file fields
file_objs = file
for file_obj in file_objs:
if file_obj.filename == "":
for file_obj in files:
if not file_obj.filename or file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
@@ -206,29 +91,30 @@ async def upload(
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
err, files = await FileService.upload_document(kb, file_objs, current_user.id)
err, uploaded_files = FileService.upload_document(kb, files, current_user.id)
if err:
return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
if not files:
return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
files = [f[0] for f in files] # remove the blob
if not uploaded_files:
return get_json_result(data=uploaded_files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
files_result = [f[0] for f in uploaded_files] # remove the blob
return get_json_result(data=files)
return get_json_result(data=files_result)
@router.post("/web_crawl")
async def web_crawl(
req: WebCrawlRequest,
request: WebCrawlRequest,
current_user = Depends(get_current_user)
):
kb_id = req.kb_id
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
name = req.name
url = req.url
"""网页爬取"""
kb_id = request.kb_id
name = request.name
url = request.url
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
@@ -259,6 +145,7 @@ async def web_crawl(
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": filetype,
@@ -285,25 +172,26 @@ async def web_crawl(
@router.post("/create")
async def create(
req: CreateDocumentRequest,
request: CreateDocumentRequest,
current_user = Depends(get_current_user)
):
kb_id = req.kb_id
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
"""创建文档"""
req = request.model_dump(exclude_unset=True)
kb_id = req["kb_id"]
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
if req.name.strip() == "":
if req["name"].strip() == "":
return get_json_result(data=False, message="File name can't be empty.", code=settings.RetCode.ARGUMENT_ERROR)
req.name = req.name.strip()
req["name"] = req["name"].strip()
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(message="Can't find this knowledgebase!")
if DocumentService.query(name=req.name, kb_id=kb_id):
if DocumentService.query(name=req["name"], kb_id=kb_id):
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
@@ -326,8 +214,8 @@ async def create(
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": FileType.VIRTUAL,
"name": req.name,
"suffix": Path(req.name).suffix.lstrip("."),
"name": req["name"],
"suffix": Path(req["name"]).suffix.lstrip("."),
"location": "",
"size": 0,
}
@@ -342,47 +230,46 @@ async def create(
@router.post("/list")
async def list_docs(
kb_id: str = Query(...),
keywords: str = Query(""),
page: int = Query(0),
page_size: int = Query(0),
orderby: str = Query("create_time"),
desc: str = Query("true"),
create_time_from: int = Query(0),
create_time_to: int = Query(0),
req: DocumentListRequest = None,
query: ListDocumentsQuery = Depends(),
body: Optional[ListDocumentsBody] = None,
current_user = Depends(get_current_user)
):
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
"""列出文档"""
if body is None:
body = ListDocumentsBody()
kb_id = query.kb_id
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
create_time_from = int(query.create_time_from or 0)
create_time_to = int(query.create_time_to or 0)
if desc.lower() == "false":
desc_bool = False
else:
desc_bool = True
run_status = req.run_status if req else []
run_status = body.run_status or []
if run_status:
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
types = req.types if req else []
types = body.types or []
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
suffix = req.suffix if req else []
suffix = body.suffix or []
try:
docs, tol = DocumentService.get_by_kb_id(kb_id, page, page_size, orderby, desc_bool, keywords, run_status, types, suffix)
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix)
if create_time_from or create_time_to:
filtered_docs = []
@@ -403,12 +290,11 @@ async def list_docs(
@router.post("/filter")
async def get_filter(
req: DocumentFilterRequest,
request: FilterDocumentsRequest,
current_user = Depends(get_current_user)
):
kb_id = req.kb_id
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
"""过滤文档"""
kb_id = request.kb_id
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
@@ -416,15 +302,16 @@ async def get_filter(
else:
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
keywords = req.keywords
suffix = req.suffix
run_status = req.run_status
keywords = request.keywords or ""
suffix = request.suffix or []
run_status = request.run_status or []
if run_status:
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
types = req.types
types = request.types or []
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
@@ -439,10 +326,11 @@ async def get_filter(
@router.post("/infos")
async def docinfos(
req: DocumentInfosRequest,
request: GetDocumentInfosRequest,
current_user = Depends(get_current_user)
):
doc_ids = req.doc_ids
"""获取文档信息"""
doc_ids = request.doc_ids
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
@@ -452,8 +340,9 @@ async def docinfos(
@router.get("/thumbnails")
async def thumbnails(
doc_ids: List[str] = Query(...)
doc_ids: List[str] = Query(..., description="文档ID列表"),
):
"""获取文档缩略图"""
if not doc_ids:
return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
@@ -471,14 +360,12 @@ async def thumbnails(
@router.post("/change_status")
async def change_status(
req: ChangeStatusRequest,
request: ChangeStatusRequest,
current_user = Depends(get_current_user)
):
doc_ids = req.doc_ids
status = str(req.status)
if status not in ["0", "1"]:
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR)
"""修改文档状态"""
doc_ids = request.doc_ids
status = request.status
result = {}
for doc_id in doc_ids:
@@ -511,10 +398,11 @@ async def change_status(
@router.post("/rm")
async def rm(
req: RemoveDocumentRequest,
request: DeleteDocumentRequest,
current_user = Depends(get_current_user)
):
doc_ids = req.doc_id
"""删除文档"""
doc_ids = request.doc_id
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
@@ -570,17 +458,19 @@ async def rm(
@router.post("/run")
async def run(
req: RunDocumentRequest,
request: RunDocumentRequest,
current_user = Depends(get_current_user)
):
for doc_id in req.doc_ids:
"""运行文档解析"""
for doc_id in request.doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
kb_table_num_map = {}
for id in req.doc_ids:
info = {"run": str(req.run), "progress": 0}
if req.run == int(TaskStatus.RUNNING.value) and req.delete:
for id in request.doc_ids:
info = {"run": str(request.run), "progress": 0}
if str(request.run) == TaskStatus.RUNNING.value and request.delete:
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
@@ -592,21 +482,21 @@ async def run(
if not e:
return get_data_error_result(message="Document not found!")
if req.run == int(TaskStatus.CANCEL.value):
if str(request.run) == TaskStatus.CANCEL.value:
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([req.delete, req.run == int(TaskStatus.RUNNING.value), str(doc.run) == TaskStatus.DONE.value]):
if all([not request.delete, str(request.run) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
DocumentService.update_by_id(id, info)
if req.delete:
if request.delete:
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if req.run == int(TaskStatus.RUNNING.value):
if str(request.run) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
@@ -628,37 +518,53 @@ async def run(
return get_json_result(data=True)
except Exception as e:
traceback.print_exc()
return server_error_response(e)
@router.post("/rename")
async def rename(
req: RenameDocumentRequest,
request: RenameDocumentRequest,
current_user = Depends(get_current_user)
):
if not DocumentService.accessible(req.doc_id, current_user.id):
"""重命名文档"""
if not DocumentService.accessible(request.doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
e, doc = DocumentService.get_by_id(req.doc_id)
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
if pathlib.Path(req.name.lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
if pathlib.Path(request.name.lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR)
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
if len(request.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req.name, kb_id=doc.kb_id):
if d.name == req.name:
for d in DocumentService.query(name=request.name, kb_id=doc.kb_id):
if d.name == request.name:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
if not DocumentService.update_by_id(req.doc_id, {"name": req.name}):
if not DocumentService.update_by_id(request.doc_id, {"name": request.name}):
return get_data_error_result(message="Database error (Document rename)!")
informs = File2DocumentService.get_by_document_id(req.doc_id)
informs = File2DocumentService.get_by_document_id(request.doc_id)
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req.name})
FileService.update_by_id(file.id, {"name": request.name})
tenant_id = DocumentService.get_tenant_id(request.doc_id)
title_tks = rag_tokenizer.tokenize(request.name)
es_body = {
"docnm_kwd": request.name,
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": request.doc_id},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
return get_json_result(data=True)
except Exception as e:
@@ -667,6 +573,7 @@ async def rename(
@router.get("/get/{doc_id}")
async def get(doc_id: str):
"""获取文档文件"""
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
@@ -677,70 +584,79 @@ async def get(doc_id: str):
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
content_type = "application/octet-stream"
if ext:
if doc.type == FileType.VISUAL.value:
media_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
media_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
else:
media_type = "application/octet-stream"
return StreamingResponse(
iter([content]),
media_type=media_type,
headers={"Content-Disposition": f"attachment; filename={doc.name}"}
)
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
return Response(content=content, media_type=content_type)
except Exception as e:
return server_error_response(e)
@router.post("/change_parser")
async def change_parser(
req: ChangeParserRequest,
request: ChangeParserSimpleRequest,
current_user = Depends(get_current_user)
):
if not DocumentService.accessible(req.doc_id, current_user.id):
"""修改文档解析器"""
if not DocumentService.accessible(request.doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req.doc_id)
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
def reset_doc():
def reset_doc(update_data_override=None):
nonlocal doc
e = DocumentService.update_by_id(doc.id, {"parser_id": req.parser_id, "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
update_data = update_data_override or {}
if request.pipeline_id is not None:
update_data["pipeline_id"] = request.pipeline_id
if request.parser_id is not None:
update_data["parser_id"] = request.parser_id
update_data.update({
"progress": 0,
"progress_msg": "",
"run": TaskStatus.UNSTART.value
})
e = DocumentService.update_by_id(doc.id, update_data)
if not e:
return get_data_error_result(message="Document not found!")
if doc.token_num > 0:
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
if not e:
return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(req.doc_id)
tenant_id = DocumentService.get_tenant_id(request.doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try:
if req.pipeline_id:
if doc.pipeline_id == req.pipeline_id:
if request.pipeline_id is not None and request.pipeline_id != "":
if doc.pipeline_id == request.pipeline_id:
return get_json_result(data=True)
DocumentService.update_by_id(doc.id, {"pipeline_id": req.pipeline_id})
reset_doc()
reset_doc({"pipeline_id": request.pipeline_id})
return get_json_result(data=True)
if doc.parser_id.lower() == req.parser_id.lower():
if req.parser_config:
if req.parser_config == doc.parser_config:
if request.parser_id is None:
return get_json_result(data=False, message="缺少 parser_id 或 pipeline_id", code=settings.RetCode.ARGUMENT_ERROR)
if doc.parser_id.lower() == request.parser_id.lower():
if request.parser_config is not None:
if request.parser_config == doc.parser_config:
return get_json_result(data=True)
else:
return get_json_result(data=True)
if (doc.type == FileType.VISUAL and req.parser_id != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req.parser_id != "presentation"):
if (doc.type == FileType.VISUAL and request.parser_id != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and request.parser_id != "presentation"):
return get_data_error_result(message="Not supported yet!")
if req.parser_config:
DocumentService.update_parser_config(doc.id, req.parser_config)
if request.parser_config is not None:
DocumentService.update_parser_config(doc.id, request.parser_config)
reset_doc()
return get_json_result(data=True)
except Exception as e:
@@ -749,16 +665,14 @@ async def change_parser(
@router.get("/image/{image_id}")
async def get_image(image_id: str):
"""获取图片"""
try:
arr = image_id.split("-")
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
content = STORAGE_IMPL.get(bkt, nm)
return StreamingResponse(
iter([content]),
media_type="image/JPEG"
)
return Response(content=content, media_type="image/JPEG")
except Exception as e:
return server_error_response(e)
@@ -769,28 +683,28 @@ async def upload_and_parse(
files: List[UploadFile] = File(...),
current_user = Depends(get_current_user)
):
"""上传并解析"""
if not files:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
# Use UploadFile directly
file_objs = files
for file_obj in file_objs:
if file_obj.filename == "":
for file_obj in files:
if not file_obj.filename or file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = await doc_upload_and_parse(conversation_id, file_objs, current_user.id)
doc_ids = doc_upload_and_parse(conversation_id, files, current_user.id)
return get_json_result(data=doc_ids)
@router.post("/parse")
async def parse(
req: ParseRequest = None,
files: List[UploadFile] = File(None),
request: Optional[ParseRequest] = None,
files: Optional[List[UploadFile]] = File(None),
current_user = Depends(get_current_user)
):
url = req.url if req else ""
"""解析文档"""
url = request.url if request else None
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
@@ -828,46 +742,37 @@ async def parse(
if not r or not r.group(1):
return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
f = File(r.group(1), os.path.join(download_path, r.group(1)))
txt = await FileService.parse_docs([f], current_user.id)
txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt)
if not files:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
# Use UploadFile directly
file_objs = files
txt = await FileService.parse_docs(file_objs, current_user.id)
txt = FileService.parse_docs(files, current_user.id)
return get_json_result(data=txt)
@router.post("/set_meta")
async def set_meta(
req: SetMetaRequest,
request: SetMetaRequest,
current_user = Depends(get_current_user)
):
if not DocumentService.accessible(req.doc_id, current_user.id):
"""设置元数据"""
if not DocumentService.accessible(request.doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
meta = json.loads(req.meta)
if not isinstance(meta, dict):
return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
for k, v in meta.items():
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
except Exception as e:
return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
if not isinstance(meta, dict):
return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
try:
e, doc = DocumentService.get_by_id(req.doc_id)
meta = json.loads(request.meta)
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
if not DocumentService.update_by_id(req.doc_id, {"meta_fields": meta}):
if not DocumentService.update_by_id(request.doc_id, {"meta_fields": meta}):
return get_data_error_result(message="Database error (meta updates)!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@@ -15,15 +15,12 @@
#
from pathlib import Path
from typing import List
from fastapi import APIRouter, Depends
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from flask import request
from flask_login import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
@@ -31,91 +28,15 @@ from api.db import FileType
from api.db.services.document_service import DocumentService
from api import settings
from api.utils.api_utils import get_json_result
from pydantic import BaseModel
# Security
# Pydantic models for request/response
class ConvertRequest(BaseModel):
file_ids: List[str]
kb_ids: List[str]
class RemoveFile2DocumentRequest(BaseModel):
file_ids: List[str]
# Dependency injection
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""获取当前用户"""
from api.db import StatusEnum
from api.db.services.user_service import UserService
from fastapi import HTTPException, status
import logging
try:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
except ImportError:
# 如果没有itsdangerous使用jwt作为替代
import jwt
Serializer = jwt
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = credentials.credentials
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication attempt with empty access token"
)
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
)
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User {user[0].email} has empty access_token in database"
)
return user[0]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
except Exception as e:
logging.warning(f"load_user got exception {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required"
)
# Create router
router = APIRouter()
@router.post('/convert')
async def convert(
req: ConvertRequest,
current_user = Depends(get_current_user)
):
kb_ids = req.kb_ids
file_ids = req.file_ids
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
def convert():
req = request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
try:
@@ -179,12 +100,12 @@ async def convert(
return server_error_response(e)
@router.post('/rm')
async def rm(
req: RemoveFile2DocumentRequest,
current_user = Depends(get_current_user)
):
file_ids = req.file_ids
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)

View File

@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import logging
import os
import pathlib
import re
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
import flask
from flask import request
from flask_login import login_required, current_user
from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
@@ -36,110 +35,22 @@ from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from rag.utils.storage_factory import STORAGE_IMPL
from pydantic import BaseModel
# Security
# Pydantic models for request/response
class CreateFileRequest(BaseModel):
name: str
parent_id: Optional[str] = None
type: Optional[str] = None
class RemoveFileRequest(BaseModel):
file_ids: List[str]
class RenameFileRequest(BaseModel):
file_id: str
name: str
class MoveFileRequest(BaseModel):
src_file_ids: List[str]
dest_file_id: str
# Dependency injection
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""获取当前用户"""
from api.db import StatusEnum
from api.db.services.user_service import UserService
from fastapi import HTTPException, status
import logging
try:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
except ImportError:
# 如果没有itsdangerous使用jwt作为替代
import jwt
Serializer = jwt
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = credentials.credentials
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication attempt with empty access token"
)
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
)
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User {user[0].email} has empty access_token in database"
)
return user[0]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
except Exception as e:
logging.warning(f"load_user got exception {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required"
)
# Create router
router = APIRouter()
@router.post('/upload')
async def upload(
parent_id: Optional[str] = Form(None),
files: List[UploadFile] = File(...),
current_user = Depends(get_current_user)
):
pf_id = parent_id
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
def upload():
pf_id = request.form.get("parent_id")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
if not files:
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = files
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
@@ -186,7 +97,7 @@ async def upload(
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = await file_obj.read()
blob = file_obj.read()
filename = duplicate_name(
FileService.query,
name=file_obj_names[file_len - 1],
@@ -209,13 +120,13 @@ async def upload(
return server_error_response(e)
@router.post('/create')
async def create(
req: CreateFileRequest,
current_user = Depends(get_current_user)
):
pf_id = req.parent_id
input_file_type = req.type
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@@ -224,7 +135,7 @@ async def create(
if not FileService.is_parent_folder_exist(pf_id):
return get_json_result(
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
if FileService.query(name=req.name, parent_id=pf_id):
if FileService.query(name=req["name"], parent_id=pf_id):
return get_data_error_result(
message="Duplicated folder name in the same folder.")
@@ -238,7 +149,7 @@ async def create(
"parent_id": pf_id,
"tenant_id": current_user.id,
"created_by": current_user.id,
"name": req.name,
"name": req["name"],
"location": "",
"size": 0,
"type": file_type
@@ -249,18 +160,17 @@ async def create(
return server_error_response(e)
@router.get('/list')
async def list_files(
parent_id: Optional[str] = Query(None),
keywords: str = Query(""),
page: int = Query(1),
page_size: int = Query(15),
orderby: str = Query("create_time"),
desc: bool = Query(True),
current_user = Depends(get_current_user)
):
pf_id = parent_id
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_files():
pf_id = request.args.get("parent_id")
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 15))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@@ -271,7 +181,7 @@ async def list_files(
return get_data_error_result(message="Folder not found!")
files, total = FileService.get_by_pf_id(
current_user.id, pf_id, page, page_size, orderby, desc, keywords)
current_user.id, pf_id, page_number, items_per_page, orderby, desc, keywords)
parent_folder = FileService.get_parent_folder(pf_id)
if not parent_folder:
@@ -282,8 +192,9 @@ async def list_files(
return server_error_response(e)
@router.get('/root_folder')
async def get_root_folder(current_user = Depends(get_current_user)):
@manager.route('/root_folder', methods=['GET']) # noqa: F821
@login_required
def get_root_folder():
try:
root_folder = FileService.get_root_folder(current_user.id)
return get_json_result(data={"root_folder": root_folder})
@@ -291,11 +202,10 @@ async def get_root_folder(current_user = Depends(get_current_user)):
return server_error_response(e)
@router.get('/parent_folder')
async def get_parent_folder(
file_id: str = Query(...),
current_user = Depends(get_current_user)
):
@manager.route('/parent_folder', methods=['GET']) # noqa: F821
@login_required
def get_parent_folder():
file_id = request.args.get("file_id")
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -307,11 +217,10 @@ async def get_parent_folder(
return server_error_response(e)
@router.get('/all_parent_folder')
async def get_all_parent_folders(
file_id: str = Query(...),
current_user = Depends(get_current_user)
):
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
@login_required
def get_all_parent_folders():
file_id = request.args.get("file_id")
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -326,90 +235,99 @@ async def get_all_parent_folders(
return server_error_response(e)
@router.post('/rm')
async def rm(
req: RemoveFileRequest,
current_user = Depends(get_current_user)
):
file_ids = req.file_ids
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
file_ids = req["file_ids"]
def _delete_single_file(file):
try:
if file.location:
STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if e and doc:
tenant_id = DocumentService.get_tenant_id(doc_id)
if tenant_id:
DocumentService.remove_document(doc, tenant_id)
File2DocumentService.delete_by_file_id(file.id)
FileService.delete(file)
def _delete_folder_recursive(folder, tenant_id):
sub_files = FileService.list_all_files_by_parent_id(folder.id)
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
_delete_folder_recursive(sub_file, tenant_id)
else:
_delete_single_file(sub_file)
FileService.delete(folder)
try:
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e:
if not e or not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
for inner_file_id in file_id_list:
e, file = FileService.get_by_id(inner_file_id)
if not e:
return get_data_error_result(message="File not found!")
STORAGE_IMPL.rm(file.parent_id, file.location)
FileService.delete_folder_by_pf_id(current_user.id, file_id)
else:
STORAGE_IMPL.rm(file.parent_id, file.location)
if not FileService.delete(file):
return get_data_error_result(
message="Database error (File removal)!")
_delete_folder_recursive(file, current_user.id)
continue
# delete file2document
informs = File2DocumentService.get_by_file_id(file_id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
File2DocumentService.delete_by_file_id(file_id)
_delete_single_file(file)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@router.post('/rename')
async def rename(
req: RenameFileRequest,
current_user = Depends(get_current_user)
):
@manager.route('/rename', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_id", "name")
def rename():
req = request.json
try:
e, file = FileService.get_by_id(req.file_id)
e, file = FileService.get_by_id(req["file_id"])
if not e:
return get_data_error_result(message="File not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
if file.type != FileType.FOLDER.value \
and pathlib.Path(req.name.lower()).suffix != pathlib.Path(
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
return get_json_result(
data=False,
message="The extension of file can't be changed",
code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req.name, pf_id=file.parent_id):
if file.name == req.name:
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]:
return get_data_error_result(
message="Duplicated file name in the same folder.")
if not FileService.update_by_id(
req.file_id, {"name": req.name}):
req["file_id"], {"name": req["name"]}):
return get_data_error_result(
message="Database error (File rename)!")
informs = File2DocumentService.get_by_file_id(req.file_id)
informs = File2DocumentService.get_by_file_id(req["file_id"])
if informs:
if not DocumentService.update_by_id(
informs[0].document_id, {"name": req.name}):
informs[0].document_id, {"name": req["name"]}):
return get_data_error_result(
message="Database error (Document rename)!")
@@ -418,8 +336,9 @@ async def rename(
return server_error_response(e)
@router.get('/get/{file_id}')
async def get(file_id: str, current_user = Depends(get_current_user)):
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
@login_required
def get(file_id):
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -432,6 +351,7 @@ async def get(file_id: str, current_user = Depends(get_current_user)):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None
if ext:
@@ -439,43 +359,95 @@ async def get(file_id: str, current_user = Depends(get_current_user)):
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
else:
content_type = "application/octet-stream"
return StreamingResponse(
iter([blob]),
media_type=content_type,
headers={"Content-Disposition": f"attachment; filename={file.name}"}
)
response.headers.set("Content-Type", content_type)
return response
except Exception as e:
return server_error_response(e)
@router.post('/mv')
async def move(
req: MoveFileRequest,
current_user = Depends(get_current_user)
):
@manager.route("/mv", methods=["POST"]) # noqa: F821
@login_required
@validate_request("src_file_ids", "dest_file_id")
def move():
req = request.json
try:
file_ids = req.src_file_ids
parent_id = req.dest_file_id
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]
ok, dest_folder = FileService.get_by_id(dest_parent_id)
if not ok or not dest_folder:
return get_data_error_result(message="Parent Folder not found!")
files = FileService.get_by_ids(file_ids)
files_dict = {}
for file in files:
files_dict[file.id] = file
if not files:
return get_data_error_result(message="Source files not found!")
files_dict = {f.id: f for f in files}
for file_id in file_ids:
file = files_dict[file_id]
file = files_dict.get(file_id)
if not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
fe, _ = FileService.get_by_id(parent_id)
if not fe:
return get_data_error_result(message="Parent Folder not found!")
FileService.move_file(file_ids, parent_id)
return get_json_result(
data=False,
message="No authorization.",
code=settings.RetCode.AUTHENTICATION_ERROR,
)
def _move_entry_recursive(source_file_entry, dest_folder):
if source_file_entry.type == FileType.FOLDER.value:
existing_folder = FileService.query(name=source_file_entry.name, parent_id=dest_folder.id)
if existing_folder:
new_folder = existing_folder[0]
else:
new_folder = FileService.insert(
{
"id": get_uuid(),
"parent_id": dest_folder.id,
"tenant_id": source_file_entry.tenant_id,
"created_by": current_user.id,
"name": source_file_entry.name,
"location": "",
"size": 0,
"type": FileType.FOLDER.value,
}
)
sub_files = FileService.list_all_files_by_parent_id(source_file_entry.id)
for sub_file in sub_files:
_move_entry_recursive(sub_file, new_folder)
FileService.delete_by_id(source_file_entry.id)
return
old_parent_id = source_file_entry.parent_id
old_location = source_file_entry.location
filename = source_file_entry.name
new_location = filename
while STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
new_location += "_"
try:
STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
except Exception as storage_err:
raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}")
FileService.update_by_id(
source_file_entry.id,
{
"parent_id": dest_folder.id,
"location": new_location,
},
)
for file in files:
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@@ -3,7 +3,7 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
@@ -15,27 +15,29 @@
#
import json
import logging
from typing import Optional, List
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, Query
from api.models.kb_models import (
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.kb_models import (
CreateKnowledgeBaseRequest,
UpdateKnowledgeBaseRequest,
DeleteKnowledgeBaseRequest,
ListKnowledgeBasesRequest,
ListKnowledgeBasesQuery,
ListKnowledgeBasesBody,
RemoveTagsRequest,
RenameTagRequest,
RunGraphRAGRequest,
ListPipelineLogsQuery,
ListPipelineLogsBody,
ListPipelineDatasetLogsQuery,
ListPipelineDatasetLogsBody,
DeletePipelineLogsQuery,
DeletePipelineLogsBody,
RunGraphragRequest,
RunRaptorRequest,
RunMindmapRequest,
ListPipelineLogsRequest,
ListPipelineDatasetLogsRequest,
DeletePipelineLogsRequest,
UnbindTaskRequest
)
from api.utils.api_utils import get_current_user
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
@@ -44,7 +46,12 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, get_json_result
from api.utils.api_utils import (
get_error_data_result,
server_error_response,
get_data_error_result,
get_json_result,
)
from api.utils import get_uuid
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -53,9 +60,10 @@ from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
# 创建 FastAPI 路由器
# 创建路由器
router = APIRouter()
@@ -64,7 +72,14 @@ async def create(
request: CreateKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
dataset_name = request.name
"""创建知识库
支持两种解析类型:
- parse_type=1: 使用内置解析器,需要 parser_id
- parse_type=2: 使用自定义 pipeline需要 pipeline_id
"""
req = request.model_dump(exclude_unset=True)
dataset_name = req["name"]
if not isinstance(dataset_name, str):
return get_data_error_result(message="Dataset name must be string.")
if dataset_name.strip() == "":
@@ -80,55 +95,66 @@ async def create(
tenant_id=current_user.id,
status=StatusEnum.VALID.value)
try:
req = {
"id": get_uuid(),
"name": dataset_name,
"tenant_id": current_user.id,
"created_by": current_user.id,
"parser_id": request.parser_id or "naive",
"description": request.description
}
# 根据 parse_type 处理 parser_id 和 pipeline_id
parse_type = req.pop("parse_type", 1) # 移除 parse_type不需要存储到数据库
if parse_type == 1:
# 使用内置解析器,需要 parser_id
# 验证器已经确保 parser_id 不为空,但保留默认值逻辑以防万一
if not req.get("parser_id") or req["parser_id"].strip() == "":
req["parser_id"] = "naive"
# 清空 pipeline_id设置为 None数据库字段允许为 null
req["pipeline_id"] = None
elif parse_type == 2:
# 使用自定义 pipeline需要 pipeline_id
# 验证器已经确保 pipeline_id 不为空
# parser_id 应该为空字符串,但数据库字段不允许 null所以不设置 parser_id
# 让数据库使用默认值 "naive"(虽然用户传入的是空字符串,但数据库会处理)
# 如果用户明确传入了空字符串,我们也不设置它,让数据库使用默认值
if "parser_id" in req and (not req["parser_id"] or req["parser_id"].strip() == ""):
# 移除空字符串的 parser_id让数据库使用默认值
req.pop("parser_id")
# pipeline_id 保留在 req 中,会被保存到数据库
req["id"] = get_uuid()
req["name"] = dataset_name
req["tenant_id"] = current_user.id
req["created_by"] = current_user.id
# embd_id 已经在模型中定义为必需字段,直接使用
e, t = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Tenant not found.")
# 设置 embd_id 默认值
if not request.embd_id:
req["embd_id"] = t.embd_id
else:
req["embd_id"] = request.embd_id
if request.parser_config:
req["parser_config"] = request.parser_config
else:
req["parser_config"] = {
"layout_recognize": "DeepDOC",
"chunk_token_num": 512,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"topn_tags": 3,
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0
},
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category"
],
"method": "light"
}
req["parser_config"] = {
"layout_recognize": "DeepDOC",
"chunk_token_num": 512,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"topn_tags": 3,
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0
},
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category"
],
"method": "light"
}
}
if not KnowledgebaseService.save(**req):
return get_data_error_result()
return get_json_result(data={"kb_id": req["id"]})
@@ -141,16 +167,24 @@ async def update(
request: UpdateKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
if not isinstance(request.name, str):
"""更新知识库"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if request.name.strip() == "":
if req["name"].strip() == "":
return get_data_error_result(message="Dataset name can't be empty.")
if len(request.name.encode("utf-8")) > DATASET_NAME_LIMIT:
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(
message=f"Dataset name length is {len(request.name)} which is large than {DATASET_NAME_LIMIT}")
name = request.name.strip()
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
# 验证不允许的参数
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
for key in not_allowed:
if key in req:
del req[key]
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
message='No authorization.',
@@ -158,48 +192,29 @@ async def update(
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=request.kb_id):
created_by=current_user.id, id=req["kb_id"]):
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(request.kb_id)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
if name.lower() != kb.name.lower() \
if req["name"].lower() != kb.name.lower() \
and len(
KnowledgebaseService.query(name=name, tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result(
message="Duplicated knowledgebase name.")
# 构建更新数据,包含所有可更新的字段
update_data = {
"name": name,
"pagerank": request.pagerank
}
# 添加可选字段(如果提供了的话)
if request.description is not None:
update_data["description"] = request.description
if request.permission is not None:
update_data["permission"] = request.permission
if request.avatar is not None:
update_data["avatar"] = request.avatar
if request.parser_id is not None:
update_data["parser_id"] = request.parser_id
if request.embd_id is not None:
update_data["embd_id"] = request.embd_id
if request.parser_config is not None:
update_data["parser_config"] = request.parser_config
if not KnowledgebaseService.update_by_id(kb.id, update_data):
kb_id = req.pop("kb_id")
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result()
if kb.pagerank != request.pagerank:
if request.pagerank > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: request.pagerank},
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
@@ -211,26 +226,7 @@ async def update(
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
# 使用完整的请求数据更新返回结果,保持与原来代码的一致性
request_data = {
"name": name,
"pagerank": request.pagerank
}
if request.description is not None:
request_data["description"] = request.description
if request.permission is not None:
request_data["permission"] = request.permission
if request.avatar is not None:
request_data["avatar"] = request.avatar
if request.parser_id is not None:
request_data["parser_id"] = request.parser_id
if request.embd_id is not None:
request_data["embd_id"] = request.embd_id
if request.parser_config is not None:
request_data["parser_config"] = request.parser_config
kb.update(request_data)
kb.update(req)
return get_json_result(data=kb)
except Exception as e:
@@ -242,6 +238,7 @@ async def detail(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""获取知识库详情"""
try:
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
@@ -257,6 +254,9 @@ async def detail(
return get_data_error_result(
message="Can't find this knowledgebase!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
if finish_at := kb.get(key):
kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S")
return get_json_result(data=kb)
except Exception as e:
return server_error_response(e)
@@ -264,18 +264,22 @@ async def detail(
@router.post('/list')
async def list_kbs(
request: ListKnowledgeBasesRequest,
keywords: str = Query("", description="关键词"),
page: int = Query(0, description="页码"),
page_size: int = Query(0, description="每页大小"),
parser_id: Optional[str] = Query(None, description="解析器ID"),
orderby: str = Query("create_time", description="排序字段"),
desc: bool = Query(True, description="是否降序"),
query: ListKnowledgeBasesQuery = Depends(),
body: Optional[ListKnowledgeBasesBody] = None,
current_user = Depends(get_current_user)
):
page_number = page
items_per_page = page_size
owner_ids = request.owner_ids
"""列出知识库"""
if body is None:
body = ListKnowledgeBasesBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
parser_id = query.parser_id
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
owner_ids = body.owner_ids or [] if body else []
try:
if not owner_ids:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
@@ -296,11 +300,13 @@ async def list_kbs(
except Exception as e:
return server_error_response(e)
@router.post('/rm')
async def rm(
request: DeleteKnowledgeBaseRequest,
current_user = Depends(get_current_user)
):
"""删除知识库"""
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
return get_json_result(
data=False,
@@ -343,6 +349,7 @@ async def list_tags(
kb_id: str,
current_user = Depends(get_current_user)
):
"""列出知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -353,17 +360,18 @@ async def list_tags(
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id])
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)
@router.get('/tags')
async def list_tags_from_kbs(
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
current_user = Depends(get_current_user)
):
kb_ids = kb_ids.split(",")
for kb_id in kb_ids:
"""从多个知识库列出标签"""
kb_id_list = kb_ids.split(",")
for kb_id in kb_id_list:
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -374,7 +382,7 @@ async def list_tags_from_kbs(
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids)
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_id_list)
return get_json_result(data=tags)
@@ -384,6 +392,7 @@ async def rm_tags(
request: RemoveTagsRequest,
current_user = Depends(get_current_user)
):
"""删除知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -406,6 +415,7 @@ async def rename_tags(
request: RenameTagRequest,
current_user = Depends(get_current_user)
):
"""重命名知识库标签"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -426,6 +436,7 @@ async def knowledge_graph(
kb_id: str,
current_user = Depends(get_current_user)
):
"""获取知识图谱"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -441,7 +452,7 @@ async def knowledge_graph(
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@@ -468,6 +479,7 @@ async def delete_knowledge_graph(
kb_id: str,
current_user = Depends(get_current_user)
):
"""删除知识图谱"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -482,18 +494,19 @@ async def delete_knowledge_graph(
@router.get("/get_meta")
async def get_meta(
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
kb_ids: str = Query(..., description="知识库ID列表逗号分隔"),
current_user = Depends(get_current_user)
):
kb_ids = kb_ids.split(",")
for kb_id in kb_ids:
"""获取知识库元数据"""
kb_id_list = kb_ids.split(",")
for kb_id in kb_id_list:
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_id_list))
@router.get("/basic_info")
@@ -501,6 +514,7 @@ async def get_basic_info(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""获取知识库基本信息"""
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -515,42 +529,45 @@ async def get_basic_info(
@router.post("/list_pipeline_logs")
async def list_pipeline_logs(
request: ListPipelineLogsRequest,
kb_id: str = Query(..., description="知识库ID"),
keywords: str = Query("", description="关键词"),
page: int = Query(0, description="页码"),
page_size: int = Query(0, description="每页大小"),
orderby: str = Query("create_time", description="排序字段"),
desc: bool = Query(True, description="是否降序"),
create_date_from: str = Query("", description="创建日期开始"),
create_date_to: str = Query("", description="创建日期结束"),
query: ListPipelineLogsQuery = Depends(),
body: Optional[ListPipelineLogsBody] = None,
current_user = Depends(get_current_user)
):
if not kb_id:
"""列出流水线日志"""
if not query.kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
page_number = page
items_per_page = page_size
if body is None:
body = ListPipelineLogsBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
create_date_from = query.create_date_from or ""
create_date_to = query.create_date_to or ""
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
operation_status = request.operation_status
operation_status = body.operation_status or []
if operation_status:
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
types = request.types
types = body.types or []
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
suffix = request.suffix
suffix = body.suffix or []
try:
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(
query.kb_id, page_number, items_per_page, orderby, desc, keywords,
operation_status, types, suffix, create_date_from, create_date_to)
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@@ -558,33 +575,36 @@ async def list_pipeline_logs(
@router.post("/list_pipeline_dataset_logs")
async def list_pipeline_dataset_logs(
request: ListPipelineDatasetLogsRequest,
kb_id: str = Query(..., description="知识库ID"),
page: int = Query(0, description="页码"),
page_size: int = Query(0, description="每页大小"),
orderby: str = Query("create_time", description="排序字段"),
desc: bool = Query(True, description="是否降序"),
create_date_from: str = Query("", description="创建日期开始"),
create_date_to: str = Query("", description="创建日期结束"),
query: ListPipelineDatasetLogsQuery = Depends(),
body: Optional[ListPipelineDatasetLogsBody] = None,
current_user = Depends(get_current_user)
):
if not kb_id:
"""列出流水线数据集日志"""
if not query.kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
page_number = page
items_per_page = page_size
if body is None:
body = ListPipelineDatasetLogsBody()
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
create_date_from = query.create_date_from or ""
create_date_to = query.create_date_to or ""
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
operation_status = request.operation_status
operation_status = body.operation_status or []
if operation_status:
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
try:
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(
query.kb_id, page_number, items_per_page, orderby, desc,
operation_status, create_date_from, create_date_to)
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@@ -592,14 +612,18 @@ async def list_pipeline_dataset_logs(
@router.post("/delete_pipeline_logs")
async def delete_pipeline_logs(
request: DeletePipelineLogsRequest,
kb_id: str = Query(..., description="知识库ID"),
query: DeletePipelineLogsQuery = Depends(),
body: Optional[DeletePipelineLogsBody] = None,
current_user = Depends(get_current_user)
):
if not kb_id:
"""删除流水线日志"""
if not query.kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
log_ids = request.log_ids
if body is None:
body = DeletePipelineLogsBody(log_ids=[])
log_ids = body.log_ids or []
PipelineOperationLogService.delete_by_ids(log_ids)
@@ -608,9 +632,10 @@ async def delete_pipeline_logs(
@router.get("/pipeline_log_detail")
async def pipeline_log_detail(
log_id: str = Query(..., description="日志ID"),
log_id: str = Query(..., description="流水线日志ID"),
current_user = Depends(get_current_user)
):
"""获取流水线日志详情"""
if not log_id:
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
@@ -623,9 +648,10 @@ async def pipeline_log_detail(
@router.post("/run_graphrag")
async def run_graphrag(
request: RunGraphRAGRequest,
request: RunGraphragRequest,
current_user = Depends(get_current_user)
):
"""运行 GraphRAG"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -660,7 +686,7 @@ async def run_graphrag(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
@@ -673,6 +699,7 @@ async def trace_graphrag(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 GraphRAG 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -696,6 +723,7 @@ async def run_raptor(
request: RunRaptorRequest,
current_user = Depends(get_current_user)
):
"""运行 RAPTOR"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -730,7 +758,7 @@ async def run_raptor(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
@@ -743,6 +771,7 @@ async def trace_raptor(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 RAPTOR 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -766,6 +795,7 @@ async def run_mindmap(
request: RunMindmapRequest,
current_user = Depends(get_current_user)
):
"""运行 Mindmap"""
kb_id = request.kb_id
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -800,7 +830,7 @@ async def run_mindmap(
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
@@ -813,6 +843,7 @@ async def trace_mindmap(
kb_id: str = Query(..., description="知识库ID"),
current_user = Depends(get_current_user)
):
"""追踪 Mindmap 任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
@@ -834,33 +865,43 @@ async def trace_mindmap(
@router.delete("/unbind_task")
async def delete_kb_task(
kb_id: str = Query(..., description="知识库ID"),
pipeline_task_type: str = Query(..., description="管道任务类型"),
pipeline_task_type: str = Query(..., description="流水线任务类型"),
current_user = Depends(get_current_user)
):
"""解绑任务"""
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_json_result(data=True)
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
return get_error_data_result(message="Invalid task type")
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
kb_task_id = "graphrag_task_id"
kb_task_id_field = "graphrag_task_id"
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
case PipelineTaskType.RAPTOR:
kb_task_id = "raptor_task_id"
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
case PipelineTaskType.MINDMAP:
kb_task_id = "mindmap_task_id"
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
kb_task_finish_at = "mindmap_task_finish_at"
case _:
return get_error_data_result(message="Internal Error: Invalid task type")
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
cancel_task(task_id)
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None})
if not ok:
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
return get_json_result(data=True)

View File

@@ -15,26 +15,36 @@
#
import logging
import json
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.llm_models import (
SetApiKeyRequest,
AddLLMRequest,
DeleteLLMRequest,
DeleteFactoryRequest,
MyLLMsQuery,
ListLLMsQuery,
)
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from api.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
from api.utils.api_utils import get_current_user
# 创建 FastAPI 路由器
# 创建路由器
router = APIRouter()
@router.get('/factories')
async def factories(current_user = Depends(get_current_user)):
async def factories(
current_user = Depends(get_current_user)
):
"""获取 LLM 工厂列表"""
try:
fac = LLMFactoriesService.get_all()
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
@@ -55,17 +65,22 @@ async def factories(current_user = Depends(get_current_user)):
@router.post('/set_api_key')
async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
async def set_api_key(
request: SetApiKeyRequest,
current_user = Depends(get_current_user)
):
"""设置 API Key"""
req = request.model_dump(exclude_unset=True)
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = request.llm_factory
factory = req["llm_factory"]
extra = {"provider": factory}
msg = ""
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
@@ -76,7 +91,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url, **extra)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""), **extra)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9, 'max_tokens': 50})
@@ -89,7 +104,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
elif not rerank_passed and llm.model_type == LLMType.RERANK:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory](
request.api_key, llm.llm_name, base_url=request.base_url)
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
try:
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
@@ -107,9 +122,12 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
return get_data_error_result(message=msg)
llm_config = {
"api_key": request.api_key,
"api_base": request.base_url or ""
"api_key": req["api_key"],
"api_base": req.get("base_url", "")
}
for n in ["model_type", "llm_name"]:
if n in req:
llm_config[n] = req[n]
for llm in LLMService.query(fid=factory):
llm_config["max_tokens"]=llm.max_tokens
@@ -132,14 +150,19 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
@router.post('/add_llm')
async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
factory = request.llm_factory
api_key = request.api_key or "x"
llm_name = request.llm_name
async def add_llm(
request: AddLLMRequest,
current_user = Depends(get_current_user)
):
"""添加 LLM"""
req = request.model_dump(exclude_unset=True)
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
def apikey_json(keys):
nonlocal request
return json.dumps({k: getattr(request, k, "") for k in keys})
nonlocal req
return json.dumps({k: req.get(k, "") for k in keys})
if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method
@@ -147,21 +170,28 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
api_key = apikey_json(["ark_api_key", "endpoint_id"])
elif factory == "Tencent Hunyuan":
# Create a temporary request object for set_api_key
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
base_url=request.api_base
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
# 创建 SetApiKeyRequest 并调用 set_api_key 逻辑
set_api_key_req = SetApiKeyRequest(
llm_factory=req["llm_factory"],
api_key=req["api_key"],
base_url=req.get("api_base", req.get("base_url", "")),
model_type=req.get("model_type"),
llm_name=req.get("llm_name")
)
return await set_api_key(temp_request, current_user)
return await set_api_key(set_api_key_req, current_user)
elif factory == "Tencent Cloud":
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
base_url=request.api_base
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
# 创建 SetApiKeyRequest 并调用 set_api_key 逻辑
set_api_key_req = SetApiKeyRequest(
llm_factory=req["llm_factory"],
api_key=req["api_key"],
base_url=req.get("api_base", req.get("base_url", "")),
model_type=req.get("model_type"),
llm_name=req.get("llm_name")
)
return await set_api_key(temp_request, current_user)
return await set_api_key(set_api_key_req, current_user)
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
@@ -181,9 +211,9 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
llm_name += "___VLLM"
elif factory == "XunFei Spark":
if request.model_type == "chat":
api_key = request.spark_api_password or ""
elif request.model_type == "tts":
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "")
elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
elif factory == "BaiduYiyan":
@@ -198,14 +228,17 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
elif factory == "Azure-OpenAI":
api_key = apikey_json(["api_key", "api_version"])
elif factory == "OpenRouter":
api_key = apikey_json(["api_key", "provider_order"])
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
"model_type": request.model_type,
"model_type": req["model_type"],
"llm_name": llm_name,
"api_base": request.api_base or "",
"api_base": req.get("api_base", ""),
"api_key": api_key,
"max_tokens": request.max_tokens
"max_tokens": req.get("max_tokens")
}
msg = ""
@@ -295,7 +328,11 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
@router.post('/delete_llm')
async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
async def delete_llm(
request: DeleteLLMRequest,
current_user = Depends(get_current_user)
):
"""删除 LLM"""
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
TenantLLM.llm_name == request.llm_name])
@@ -303,7 +340,11 @@ async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_curre
@router.post('/delete_factory')
async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
async def delete_factory(
request: DeleteFactoryRequest,
current_user = Depends(get_current_user)
):
"""删除工厂"""
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
return get_json_result(data=True)
@@ -311,10 +352,13 @@ async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(g
@router.get('/my_llms')
async def my_llms(
include_details: bool = Query(False, description="是否包含详细信息"),
query: MyLLMsQuery = Depends(),
current_user = Depends(get_current_user)
):
"""获取我的 LLMs"""
try:
include_details = query.include_details.lower() == 'true'
if include_details:
res = {}
objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -362,11 +406,13 @@ async def my_llms(
@router.get('/list')
async def list_app(
model_type: Optional[str] = Query(None, description="模型类型"),
query: ListLLMsQuery = Depends(),
current_user = Depends(get_current_user)
):
"""列出 LLMs"""
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = query.model_type
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])

View File

@@ -13,164 +13,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.mcp_models import (
ListMCPServersQuery,
ListMCPServersBody,
CreateMCPServerRequest,
UpdateMCPServerRequest,
DeleteMCPServersRequest,
ImportMCPServersRequest,
ExportMCPServersRequest,
ListMCPToolsRequest,
TestMCPToolRequest,
CacheMCPToolsRequest,
TestMCPRequest,
)
from api import settings
from api.db import VALID_MCP_SERVER_TYPES
from api.db.db_models import MCPServer, User
from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from api.utils.web_utils import safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from pydantic import BaseModel
# Security
# Pydantic models for request/response
class ListMCPRequest(BaseModel):
mcp_ids: List[str] = []
class CreateMCPRequest(BaseModel):
name: str
url: str
server_type: str
headers: Optional[Dict[str, Any]] = {}
variables: Optional[Dict[str, Any]] = {}
timeout: Optional[float] = 10
class UpdateMCPRequest(BaseModel):
mcp_id: str
name: Optional[str] = None
url: Optional[str] = None
server_type: Optional[str] = None
headers: Optional[Dict[str, Any]] = None
variables: Optional[Dict[str, Any]] = None
timeout: Optional[float] = 10
class RemoveMCPRequest(BaseModel):
mcp_ids: List[str]
class ImportMCPRequest(BaseModel):
mcpServers: Dict[str, Dict[str, Any]]
timeout: Optional[float] = 10
class ExportMCPRequest(BaseModel):
mcp_ids: List[str]
class ListToolsRequest(BaseModel):
mcp_ids: List[str]
timeout: Optional[float] = 10
class TestToolRequest(BaseModel):
mcp_id: str
tool_name: str
arguments: Dict[str, Any]
timeout: Optional[float] = 10
class CacheToolsRequest(BaseModel):
mcp_id: str
tools: List[Dict[str, Any]]
class TestMCPRequest(BaseModel):
url: str
server_type: str
timeout: Optional[float] = 10
headers: Optional[Dict[str, Any]] = {}
variables: Optional[Dict[str, Any]] = {}
# Dependency injection
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""获取当前用户"""
from api.db import StatusEnum
from api.db.services.user_service import UserService
from fastapi import HTTPException, status
import logging
try:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
except ImportError:
# 如果没有itsdangerous使用jwt作为替代
import jwt
Serializer = jwt
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = credentials.credentials
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication attempt with empty access token"
)
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
)
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User {user[0].email} has empty access_token in database"
)
return user[0]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
except Exception as e:
logging.warning(f"load_user got exception {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token"
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header required"
)
# Create router
# 创建路由器
router = APIRouter()
@router.post("/list")
async def list_mcp(
request: ListMCPRequest,
keywords: str = Query("", description="Search keywords"),
page: int = Query(0, description="Page number"),
page_size: int = Query(0, description="Items per page"),
orderby: str = Query("create_time", description="Order by field"),
desc: bool = Query(True, description="Sort descending"),
current_user: User = Depends(get_current_user)
query: ListMCPServersQuery = Depends(),
body: ListMCPServersBody = None,
current_user = Depends(get_current_user)
):
"""列出MCP服务器"""
if body is None:
body = ListMCPServersBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
mcp_ids = body.mcp_ids or []
try:
servers = MCPServerService.get_servers(current_user.id, request.mcp_ids, 0, 0, orderby, desc, keywords) or []
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
total = len(servers)
if page and page_size:
servers = servers[(page - 1) * page_size : page * page_size]
if page_number and items_per_page:
servers = servers[(page_number - 1) * items_per_page : page_number * items_per_page]
return get_json_result(data={"mcp_servers": servers, "total": total})
except Exception as e:
@@ -179,9 +76,10 @@ async def list_mcp(
@router.get("/detail")
async def detail(
mcp_id: str = Query(..., description="MCP server ID"),
current_user: User = Depends(get_current_user)
mcp_id: str = Query(..., description="MCP服务器ID"),
current_user = Depends(get_current_user)
):
"""获取MCP服务器详情"""
try:
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
@@ -195,9 +93,10 @@ async def detail(
@router.post("/create")
async def create(
request: CreateMCPRequest,
current_user: User = Depends(get_current_user)
request: CreateMCPServerRequest,
current_user = Depends(get_current_user)
):
"""创建MCP服务器"""
server_type = request.server_type
if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
@@ -218,10 +117,10 @@ async def create(
variables = safe_json_parse(request.variables or {})
variables.pop("tools", None)
timeout = request.timeout or 10
timeout = request.timeout or 10.0
try:
req_data = {
req = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": server_name,
@@ -229,7 +128,6 @@ async def create(
"server_type": server_type,
"headers": headers,
"variables": variables,
"timeout": timeout
}
e, _ = TenantService.get_by_id(current_user.id)
@@ -244,46 +142,45 @@ async def create(
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools
req_data["variables"] = variables
req["variables"] = variables
if not MCPServerService.insert(**req_data):
if not MCPServerService.insert(**req):
return get_data_error_result("Failed to create MCP server.")
return get_json_result(data=req_data)
return get_json_result(data=req)
except Exception as e:
return server_error_response(e)
@router.post("/update")
async def update(
request: UpdateMCPRequest,
current_user: User = Depends(get_current_user)
request: UpdateMCPServerRequest,
current_user = Depends(get_current_user)
):
"""更新MCP服务器"""
mcp_id = request.mcp_id
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
server_type = request.server_type or mcp_server.server_type
server_type = request.server_type if request.server_type is not None else mcp_server.server_type
if server_type and server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
server_name = request.name or mcp_server.name
server_name = request.name if request.name is not None else mcp_server.name
if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
url = request.url or mcp_server.url
url = request.url if request.url is not None else mcp_server.url
if not url:
return get_data_error_result(message="Invalid url.")
headers = safe_json_parse(request.headers or mcp_server.headers)
variables = safe_json_parse(request.variables or mcp_server.variables)
headers = safe_json_parse(request.headers if request.headers is not None else mcp_server.headers)
variables = safe_json_parse(request.variables if request.variables is not None else mcp_server.variables)
variables.pop("tools", None)
timeout = request.timeout or 10
timeout = request.timeout or 10.0
try:
req_data = {
req = {
"tenant_id": current_user.id,
"id": mcp_id,
"name": server_name,
@@ -291,7 +188,6 @@ async def update(
"server_type": server_type,
"headers": headers,
"variables": variables,
"timeout": timeout
}
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
@@ -302,12 +198,12 @@ async def update(
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools
req_data["variables"] = variables
req["variables"] = variables
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_data):
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req):
return get_data_error_result(message="Failed to updated MCP server.")
e, updated_mcp = MCPServerService.get_by_id(req_data["id"])
e, updated_mcp = MCPServerService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="Failed to fetch updated MCP server.")
@@ -318,9 +214,10 @@ async def update(
@router.post("/rm")
async def rm(
request: RemoveMCPRequest,
current_user: User = Depends(get_current_user)
request: DeleteMCPServersRequest,
current_user = Depends(get_current_user)
):
"""删除MCP服务器"""
mcp_ids = request.mcp_ids
try:
@@ -334,14 +231,15 @@ async def rm(
@router.post("/import")
async def import_multiple(
request: ImportMCPRequest,
current_user: User = Depends(get_current_user)
request: ImportMCPServersRequest,
current_user = Depends(get_current_user)
):
"""批量导入MCP服务器"""
servers = request.mcpServers
if not servers:
return get_data_error_result(message="No MCP servers provided.")
timeout = request.timeout or 10
timeout = request.timeout or 10.0
results = []
try:
@@ -401,9 +299,10 @@ async def import_multiple(
@router.post("/export")
async def export_multiple(
request: ExportMCPRequest,
current_user: User = Depends(get_current_user)
request: ExportMCPServersRequest,
current_user = Depends(get_current_user)
):
"""批量导出MCP服务器"""
mcp_ids = request.mcp_ids
if not mcp_ids:
@@ -432,12 +331,16 @@ async def export_multiple(
@router.post("/list_tools")
async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_current_user)):
mcp_ids = req.mcp_ids
async def list_tools(
request: ListMCPToolsRequest,
current_user = Depends(get_current_user)
):
"""列出MCP工具"""
mcp_ids = request.mcp_ids
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
timeout = req.timeout
timeout = request.timeout or 10.0
results = {}
tool_call_sessions = []
@@ -476,15 +379,19 @@ async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_cur
@router.post("/test_tool")
async def test_tool(req: TestToolRequest, current_user: User = Depends(get_current_user)):
mcp_id = req.mcp_id
async def test_tool(
request: TestMCPToolRequest,
current_user = Depends(get_current_user)
):
"""测试MCP工具"""
mcp_id = request.mcp_id
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
timeout = req.timeout
timeout = request.timeout or 10.0
tool_name = req.tool_name
arguments = req.arguments
tool_name = request.tool_name
arguments = request.arguments
if not all([tool_name, arguments]):
return get_data_error_result(message="Require provide tool name and arguments.")
@@ -506,11 +413,15 @@ async def test_tool(req: TestToolRequest, current_user: User = Depends(get_curre
@router.post("/cache_tools")
async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_current_user)):
mcp_id = req.mcp_id
async def cache_tool(
request: CacheMCPToolsRequest,
current_user = Depends(get_current_user)
):
"""缓存MCP工具"""
mcp_id = request.mcp_id
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
tools = req.tools
tools = request.tools
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
@@ -527,18 +438,21 @@ async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_cu
@router.post("/test_mcp")
async def test_mcp(req: TestMCPRequest):
url = req.url
async def test_mcp(
request: TestMCPRequest
):
"""测试MCP服务器不需要登录"""
url = request.url
if not url:
return get_data_error_result(message="Invalid MCP url.")
server_type = req.server_type
server_type = request.server_type
if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
timeout = req.timeout
headers = req.headers
variables = req.variables
timeout = request.timeout or 10.0
headers = safe_json_parse(request.headers or {})
variables = safe_json_parse(request.variables or {})
mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)

View File

@@ -0,0 +1,53 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from fastapi import Depends, Header, Security, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from api import settings
from api.utils.api_utils import get_json_result
# HTTPBearer security scheme; auto_error=False lets us raise our own 401s
# instead of FastAPI's default error response.
http_bearer = HTTPBearer(auto_error=False)


def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Security(http_bearer)):
    """FastAPI dependency: resolve the current user from a Bearer token.

    Replaces Flask's ``login_required`` / ``current_user``.  Declaring the
    scheme via ``Security(http_bearer)`` makes FastAPI register the security
    requirement in the OpenAPI schema, so Swagger UI shows the authorize box
    and attaches the ``Authorization`` header automatically.

    Raises:
        HTTPException: 401 when the header is missing or the token is invalid
            or expired.  A ``WWW-Authenticate: Bearer`` header is included on
            both failures, as RFC 6750 requires for bearer-token challenges.
    """
    # Deferred import to avoid a circular import at module load time.
    from api.apps.__init___fastapi import get_current_user_from_token
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header is required",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # HTTPBearer has already stripped the "Bearer " prefix;
    # credentials.credentials is the raw token.
    authorization = credentials.credentials
    user = get_current_user_from_token(authorization)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user

View File

@@ -0,0 +1,129 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field
class DeleteCanvasRequest(BaseModel):
    """Request body for deleting one or more canvases."""
    canvas_ids: List[str] = Field(..., description="画布ID列表")


class SaveCanvasRequest(BaseModel):
    """Request body for saving a canvas (create when ``id`` is absent, update otherwise)."""
    dsl: Union[str, Dict[str, Any]] = Field(..., description="DSL配置")
    title: str = Field(..., description="画布标题")
    id: Optional[str] = Field(default=None, description="画布ID更新时提供")
    canvas_category: Optional[str] = Field(default=None, description="画布类别")
    description: Optional[str] = Field(default=None, description="描述")
    permission: Optional[str] = Field(default=None, description="权限")
    avatar: Optional[str] = Field(default=None, description="头像")


class CompletionRequest(BaseModel):
    """Request body for running a canvas / requesting a completion."""
    id: str = Field(..., description="画布ID")
    query: Optional[str] = Field(default="", description="查询内容")
    # default_factory avoids sharing one mutable default object across
    # model instances (idiomatic Pydantic; behavior is unchanged).
    files: Optional[List[str]] = Field(default_factory=list, description="文件列表")
    inputs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="输入参数")
    user_id: Optional[str] = Field(default=None, description="用户ID")


class RerunRequest(BaseModel):
    """Request body for re-running a pipeline from a given component."""
    id: str = Field(..., description="流水线ID")
    dsl: Dict[str, Any] = Field(..., description="DSL配置")
    component_id: str = Field(..., description="组件ID")


class ResetCanvasRequest(BaseModel):
    """Request body for resetting a canvas to its initial state."""
    id: str = Field(..., description="画布ID")


class InputFormQuery(BaseModel):
    """Query parameters for fetching a component's input form."""
    id: str = Field(..., description="画布ID")
    component_id: str = Field(..., description="组件ID")


class DebugRequest(BaseModel):
    """Request body for debugging a single canvas component."""
    id: str = Field(..., description="画布ID")
    component_id: str = Field(..., description="组件ID")
    params: Dict[str, Any] = Field(..., description="参数")


class TestDBConnectRequest(BaseModel):
    """Request body for testing a database connection."""
    db_type: str = Field(..., description="数据库类型")
    database: str = Field(..., description="数据库名")
    username: str = Field(..., description="用户名")
    host: str = Field(..., description="主机")
    # NOTE(review): port is a string here, presumably to match the form
    # payload — confirm downstream code converts it before connecting.
    port: str = Field(..., description="端口")
    password: str = Field(..., description="密码")


class ListCanvasQuery(BaseModel):
    """Query parameters for listing canvases."""
    keywords: Optional[str] = Field(default="", description="关键词")
    page: Optional[int] = Field(default=0, description="页码")
    page_size: Optional[int] = Field(default=0, description="每页大小")
    orderby: Optional[str] = Field(default="create_time", description="排序字段")
    desc: Optional[str] = Field(default="true", description="是否降序")
    canvas_category: Optional[str] = Field(default=None, description="画布类别")
    owner_ids: Optional[str] = Field(default="", description="所有者ID列表逗号分隔")


class SettingRequest(BaseModel):
    """Request body for updating canvas settings."""
    id: str = Field(..., description="画布ID")
    title: str = Field(..., description="标题")
    permission: str = Field(..., description="权限")
    description: Optional[str] = Field(default=None, description="描述")
    avatar: Optional[str] = Field(default=None, description="头像")


class TraceQuery(BaseModel):
    """Query parameters for tracing a message through a canvas run."""
    canvas_id: str = Field(..., description="画布ID")
    message_id: str = Field(..., description="消息ID")


class ListSessionsQuery(BaseModel):
    """Query parameters for listing canvas sessions."""
    user_id: Optional[str] = Field(default=None, description="用户ID")
    page: Optional[int] = Field(default=1, description="页码")
    page_size: Optional[int] = Field(default=30, description="每页大小")
    keywords: Optional[str] = Field(default=None, description="关键词")
    from_date: Optional[str] = Field(default=None, description="开始日期")
    to_date: Optional[str] = Field(default=None, description="结束日期")
    orderby: Optional[str] = Field(default="update_time", description="排序字段")
    desc: Optional[str] = Field(default="true", description="是否降序")
    dsl: Optional[str] = Field(default="true", description="是否包含DSL")


class DownloadQuery(BaseModel):
    """Query parameters for downloading a file."""
    id: str = Field(..., description="文件ID")
    created_by: str = Field(..., description="创建者ID")


class UploadQuery(BaseModel):
    """Query parameters for uploading; ``url`` triggers a server-side fetch."""
    url: Optional[str] = Field(default=None, description="URL可选用于从URL下载")

View File

@@ -0,0 +1,80 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Any, Union
from pydantic import BaseModel, Field, field_validator
class ListChunksRequest(BaseModel):
    """Request body for listing the chunks of a document."""
    doc_id: str = Field(..., description="文档ID")
    page: Optional[int] = Field(default=1, description="页码")
    size: Optional[int] = Field(default=30, description="每页大小")
    keywords: Optional[str] = Field(default="", description="关键词")
    available_int: Optional[int] = Field(default=None, description="可用状态")


class SetChunkRequest(BaseModel):
    """Request body for updating a single chunk."""
    doc_id: str = Field(..., description="文档ID")
    chunk_id: str = Field(..., description="块ID")
    content_with_weight: str = Field(..., description="内容")
    important_kwd: Optional[List[str]] = Field(default=None, description="重要关键词列表")
    question_kwd: Optional[List[str]] = Field(default=None, description="问题关键词列表")
    tag_kwd: Optional[str] = Field(default=None, description="标签关键词")
    tag_feas: Optional[Any] = Field(default=None, description="标签特征")
    available_int: Optional[int] = Field(default=None, description="可用状态")


class SwitchChunksRequest(BaseModel):
    """Request body for toggling the availability of several chunks."""
    doc_id: str = Field(..., description="文档ID")
    chunk_ids: List[str] = Field(..., description="块ID列表")
    available_int: int = Field(..., description="可用状态")


class DeleteChunksRequest(BaseModel):
    """Request body for deleting chunks from a document."""
    doc_id: str = Field(..., description="文档ID")
    chunk_ids: List[str] = Field(..., description="块ID列表")


class CreateChunkRequest(BaseModel):
    """Request body for creating a chunk."""
    doc_id: str = Field(..., description="文档ID")
    content_with_weight: str = Field(..., description="内容")
    # default_factory avoids sharing one mutable default object across
    # model instances (idiomatic Pydantic; behavior is unchanged).
    important_kwd: Optional[List[str]] = Field(default_factory=list, description="重要关键词列表")
    question_kwd: Optional[List[str]] = Field(default_factory=list, description="问题关键词列表")
    tag_feas: Optional[Any] = Field(default=None, description="标签特征")


class RetrievalTestRequest(BaseModel):
    """Request body for a retrieval test against one or more knowledge bases."""
    kb_id: Union[str, List[str]] = Field(..., description="知识库ID可以是字符串或列表")
    question: str = Field(..., description="问题")
    page: Optional[int] = Field(default=1, description="页码")
    size: Optional[int] = Field(default=30, description="每页大小")
    doc_ids: Optional[List[str]] = Field(default_factory=list, description="文档ID列表")
    use_kg: Optional[bool] = Field(default=False, description="是否使用知识图谱")
    top_k: Optional[int] = Field(default=1024, description="Top K")
    cross_languages: Optional[List[str]] = Field(default_factory=list, description="跨语言列表")
    search_id: Optional[str] = Field(default="", description="搜索ID")
    rerank_id: Optional[str] = Field(default=None, description="重排序模型ID")
    keyword: Optional[bool] = Field(default=False, description="是否使用关键词")
    similarity_threshold: Optional[float] = Field(default=0.0, description="相似度阈值")
    vector_similarity_weight: Optional[float] = Field(default=0.3, description="向量相似度权重")
    highlight: Optional[bool] = Field(default=False, description="是否高亮")

View File

@@ -0,0 +1,204 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Literal, List
from pydantic import BaseModel, Field, model_validator
class CreateDocumentRequest(BaseModel):
    """Create-document request.

    Two parse modes are supported:
    - parse_type=1: built-in parser; ``parser_id`` is required and
      ``pipeline_id`` must be empty.
    - parse_type=2: custom pipeline; ``pipeline_id`` is required and
      ``parser_id`` must be empty.

    When ``parse_type`` is omitted, the parsing configuration is inherited
    from the knowledge base.
    """
    name: str
    kb_id: str
    parse_type: Optional[Literal[1, 2]] = Field(default=None, description="解析类型1=内置解析器2=自定义pipelineNone=从知识库继承")
    parser_id: Optional[str] = Field(default="", description="解析器IDparse_type=1时必需")
    pipeline_id: Optional[str] = Field(default="", description="流水线IDparse_type=2时必需")
    parser_config: Optional[dict] = None

    @model_validator(mode='after')
    def validate_parse_type_fields(self):
        """Cross-check parser_id/pipeline_id against the chosen parse_type."""
        if self.parse_type is None:
            # No explicit parse type: configuration is inherited, nothing to check.
            return self
        # Normalize once instead of repeating the or-""/strip dance per branch.
        parser_id_val = (self.parser_id or "").strip()
        pipeline_id_val = (self.pipeline_id or "").strip()
        if self.parse_type == 1:
            if not parser_id_val:
                raise ValueError("parse_type=1时parser_id不能为空")
            if pipeline_id_val:
                raise ValueError("parse_type=1时pipeline_id必须为空")
        else:  # parse_type == 2 (Literal[1, 2] admits no other value)
            if not pipeline_id_val:
                raise ValueError("parse_type=2时pipeline_id不能为空")
            if parser_id_val:
                raise ValueError("parse_type=2时parser_id必须为空")
        return self
class ChangeParserRequest(BaseModel):
    """Change-document-parser request.

    Two parse modes are supported:
    - parse_type=1: built-in parser; ``parser_id`` is required and
      ``pipeline_id`` must be empty.
    - parse_type=2: custom pipeline; ``pipeline_id`` is required and
      ``parser_id`` must be empty.
    """
    doc_id: str
    parse_type: Literal[1, 2] = Field(..., description="解析类型1=内置解析器2=自定义pipeline")
    parser_id: Optional[str] = Field(default="", description="解析器IDparse_type=1时必需")
    pipeline_id: Optional[str] = Field(default="", description="流水线IDparse_type=2时必需")
    parser_config: Optional[dict] = None

    @model_validator(mode='after')
    def validate_parse_type_fields(self):
        """Cross-check parser_id/pipeline_id against the chosen parse_type."""
        # Normalize once instead of repeating the or-""/strip dance per branch.
        parser_id_val = (self.parser_id or "").strip()
        pipeline_id_val = (self.pipeline_id or "").strip()
        if self.parse_type == 1:
            if not parser_id_val:
                raise ValueError("parse_type=1时parser_id不能为空")
            if pipeline_id_val:
                raise ValueError("parse_type=1时pipeline_id必须为空")
        else:  # parse_type == 2 (Literal[1, 2] admits no other value)
            if not pipeline_id_val:
                raise ValueError("parse_type=2时pipeline_id不能为空")
            if parser_id_val:
                raise ValueError("parse_type=2时parser_id必须为空")
        return self
class WebCrawlRequest(BaseModel):
    """Web-crawl request: import the page at ``url`` into knowledge base ``kb_id``."""
    kb_id: str
    name: str
    url: str


class ListDocumentsQuery(BaseModel):
    """Query parameters for listing documents."""
    kb_id: str
    keywords: Optional[str] = ""
    page: Optional[int] = 0  # 0 appears to mean "use server default" — confirm in handler
    page_size: Optional[int] = 0
    orderby: Optional[str] = "create_time"
    desc: Optional[str] = "true"  # string flag ("true"/"false"), parsed downstream
    create_time_from: Optional[int] = 0
    create_time_to: Optional[int] = 0


class ListDocumentsBody(BaseModel):
    """Filter body accompanying the list-documents query."""
    # NOTE: mutable defaults are safe here — Pydantic deep-copies field
    # defaults per instance, unlike plain Python function defaults.
    run_status: Optional[List[str]] = []
    types: Optional[List[str]] = []
    suffix: Optional[List[str]] = []


class FilterDocumentsRequest(BaseModel):
    """Request body for filtering documents within a knowledge base."""
    kb_id: str
    keywords: Optional[str] = ""
    suffix: Optional[List[str]] = []
    run_status: Optional[List[str]] = []
    types: Optional[List[str]] = []


class GetDocumentInfosRequest(BaseModel):
    """Request body for fetching metadata of several documents."""
    doc_ids: List[str]


class ChangeStatusRequest(BaseModel):
    """Request body for enabling/disabling documents."""
    doc_ids: List[str]
    status: str  # "0" or "1" (string flags, validated below)

    @model_validator(mode='after')
    def validate_status(self):
        # Reject anything other than the two legal string flags.
        if self.status not in ["0", "1"]:
            raise ValueError('Status must be either 0 or 1!')
        return self


class DeleteDocumentRequest(BaseModel):
    """Request body for deleting documents."""
    doc_id: str | List[str]  # a single ID or a list of IDs


class RunDocumentRequest(BaseModel):
    """Request body for starting/stopping document parsing."""
    doc_ids: List[str]
    run: str  # a TaskStatus value
    delete: Optional[bool] = False


class RenameDocumentRequest(BaseModel):
    """Request body for renaming a document."""
    doc_id: str
    name: str


class ChangeParserSimpleRequest(BaseModel):
    """Simplified change-parser request (kept for backward compatibility)."""
    doc_id: str
    parser_id: Optional[str] = None
    pipeline_id: Optional[str] = None
    parser_config: Optional[dict] = None


class UploadAndParseRequest(BaseModel):
    """Upload-and-parse request (only validates ``conversation_id``)."""
    conversation_id: str


class ParseRequest(BaseModel):
    """Parse request; ``url`` is optional."""
    url: Optional[str] = None
class SetMetaRequest(BaseModel):
    """Set-document-metadata request."""
    doc_id: str
    meta: str  # JSON string; must decode to a dict of str/int/float values

    @model_validator(mode='after')
    def validate_meta(self):
        """Ensure ``meta`` is valid JSON encoding a dict of scalar values."""
        import json
        # Keep the try body minimal: only json.loads can raise JSONDecodeError.
        try:
            meta_dict = json.loads(self.meta)
        except json.JSONDecodeError as e:
            # Chain the cause so the original parse error stays in the traceback.
            raise ValueError(f"Json syntax error: {e}") from e
        if not isinstance(meta_dict, dict):
            raise ValueError("Only dictionary type supported.")
        for k, v in meta_dict.items():
            if not isinstance(v, (str, int, float)):
                raise ValueError(f"The type is not supported: {v}")
        return self

View File

@@ -0,0 +1,159 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any, Literal
from pydantic import BaseModel, Field, model_validator
class CreateKnowledgeBaseRequest(BaseModel):
    """Create-knowledge-base request.

    Two parse modes are supported:
    - parse_type=1: built-in parser; ``parser_id`` is required and
      ``pipeline_id`` must be empty.
    - parse_type=2: custom pipeline; ``pipeline_id`` is required and
      ``parser_id`` must be empty.
    """
    name: str
    parse_type: Literal[1, 2] = Field(..., description="解析类型1=内置解析器2=自定义pipeline")
    embd_id: str = Field(..., description="嵌入模型ID")
    parser_id: Optional[str] = Field(default="", description="解析器IDparse_type=1时必需")
    pipeline_id: Optional[str] = Field(default="", description="流水线IDparse_type=2时必需")
    description: Optional[str] = None
    pagerank: Optional[int] = None

    @model_validator(mode='after')
    def validate_parse_type_fields(self):
        """Cross-check parser_id/pipeline_id against the chosen parse_type."""
        # Normalize once instead of repeating the or-""/strip dance per branch.
        parser_id_val = (self.parser_id or "").strip()
        pipeline_id_val = (self.pipeline_id or "").strip()
        if self.parse_type == 1:
            if not parser_id_val:
                raise ValueError("parse_type=1时parser_id不能为空")
            if pipeline_id_val:
                raise ValueError("parse_type=1时pipeline_id必须为空")
        else:  # parse_type == 2 (Literal[1, 2] admits no other value)
            if not pipeline_id_val:
                raise ValueError("parse_type=2时pipeline_id不能为空")
            if parser_id_val:
                raise ValueError("parse_type=2时parser_id必须为空")
        return self
class UpdateKnowledgeBaseRequest(BaseModel):
    """Update-knowledge-base request."""
    kb_id: str
    name: str
    description: str
    parser_id: str
    pagerank: Optional[int] = None
    # Other fields are optional, but exclude id, tenant_id, created_by,
    # create_time, update_time, create_date, update_date.


class DeleteKnowledgeBaseRequest(BaseModel):
    """Delete-knowledge-base request."""
    kb_id: str


class ListKnowledgeBasesQuery(BaseModel):
    """Query parameters for listing knowledge bases."""
    keywords: Optional[str] = ""
    page: Optional[int] = 0  # 0 appears to mean "use server default" — confirm in handler
    page_size: Optional[int] = 0
    parser_id: Optional[str] = None
    orderby: Optional[str] = "create_time"
    desc: Optional[str] = "true"  # string flag ("true"/"false"), parsed downstream


class ListKnowledgeBasesBody(BaseModel):
    """Filter body accompanying the list-knowledge-bases query."""
    # NOTE: mutable defaults are safe here — Pydantic deep-copies field
    # defaults per instance, unlike plain Python function defaults.
    owner_ids: Optional[List[str]] = []


class RemoveTagsRequest(BaseModel):
    """Remove-tags request."""
    tags: List[str]


class RenameTagRequest(BaseModel):
    """Rename-tag request."""
    from_tag: str
    to_tag: str


class ListPipelineLogsQuery(BaseModel):
    """Query parameters for listing pipeline logs."""
    kb_id: str
    keywords: Optional[str] = ""
    page: Optional[int] = 0
    page_size: Optional[int] = 0
    orderby: Optional[str] = "create_time"
    desc: Optional[str] = "true"
    create_date_from: Optional[str] = ""
    create_date_to: Optional[str] = ""


class ListPipelineLogsBody(BaseModel):
    """Filter body accompanying the list-pipeline-logs query."""
    operation_status: Optional[List[str]] = []
    types: Optional[List[str]] = []
    suffix: Optional[List[str]] = []


class ListPipelineDatasetLogsQuery(BaseModel):
    """Query parameters for listing pipeline dataset logs."""
    kb_id: str
    page: Optional[int] = 0
    page_size: Optional[int] = 0
    orderby: Optional[str] = "create_time"
    desc: Optional[str] = "true"
    create_date_from: Optional[str] = ""
    create_date_to: Optional[str] = ""


class ListPipelineDatasetLogsBody(BaseModel):
    """Filter body accompanying the list-pipeline-dataset-logs query."""
    operation_status: Optional[List[str]] = []


class DeletePipelineLogsQuery(BaseModel):
    """Query parameters for deleting pipeline logs."""
    kb_id: str


class DeletePipelineLogsBody(BaseModel):
    """Request body naming the pipeline log entries to delete."""
    log_ids: List[str]


class RunGraphragRequest(BaseModel):
    """Run-GraphRAG request."""
    kb_id: str


class RunRaptorRequest(BaseModel):
    """Run-RAPTOR request."""
    kb_id: str


class RunMindmapRequest(BaseModel):
    """Run-Mindmap request."""
    kb_id: str

View File

@@ -0,0 +1,101 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from pydantic import BaseModel, Field
class SetApiKeyRequest(BaseModel):
    """Set-API-key request for an LLM provider."""
    llm_factory: str = Field(..., description="LLM 工厂名称")
    api_key: str = Field(..., description="API Key")
    base_url: Optional[str] = Field(default="", description="API Base URL")
    model_type: Optional[str] = Field(default=None, description="模型类型")
    llm_name: Optional[str] = Field(default=None, description="LLM 名称")


class AddLLMRequest(BaseModel):
    """Add-LLM request.

    Most fields below are provider-specific credentials; only the ones
    matching ``llm_factory`` are expected to be set for a given request.
    """
    llm_factory: str = Field(..., description="LLM 工厂名称")
    model_type: str = Field(..., description="模型类型")
    llm_name: str = Field(..., description="LLM 名称")
    # NOTE(review): "x" looks like a placeholder for providers that don't
    # need a key — confirm against the handler before relying on it.
    api_key: Optional[str] = Field(default="x", description="API Key")
    api_base: Optional[str] = Field(default="", description="API Base URL")
    max_tokens: Optional[int] = Field(default=None, description="最大 Token 数")
    # VolcEngine-specific fields
    ark_api_key: Optional[str] = None
    endpoint_id: Optional[str] = None
    # Tencent Hunyuan-specific fields
    hunyuan_sid: Optional[str] = None
    hunyuan_sk: Optional[str] = None
    # Tencent Cloud-specific fields
    tencent_cloud_sid: Optional[str] = None
    tencent_cloud_sk: Optional[str] = None
    # Bedrock-specific fields
    bedrock_ak: Optional[str] = None
    bedrock_sk: Optional[str] = None
    bedrock_region: Optional[str] = None
    # XunFei Spark-specific fields
    spark_api_password: Optional[str] = None
    spark_app_id: Optional[str] = None
    spark_api_secret: Optional[str] = None
    spark_api_key: Optional[str] = None
    # BaiduYiyan-specific fields
    yiyan_ak: Optional[str] = None
    yiyan_sk: Optional[str] = None
    # Fish Audio-specific fields
    fish_audio_ak: Optional[str] = None
    fish_audio_refid: Optional[str] = None
    # Google Cloud-specific fields
    google_project_id: Optional[str] = None
    google_region: Optional[str] = None
    google_service_account_key: Optional[str] = None
    # Azure-OpenAI-specific field
    api_version: Optional[str] = None
    # OpenRouter-specific field
    provider_order: Optional[str] = None


class DeleteLLMRequest(BaseModel):
    """Delete-LLM request."""
    llm_factory: str = Field(..., description="LLM 工厂名称")
    llm_name: str = Field(..., description="LLM 名称")


class DeleteFactoryRequest(BaseModel):
    """Delete-factory request (removes all models of a provider)."""
    llm_factory: str = Field(..., description="LLM 工厂名称")


class MyLLMsQuery(BaseModel):
    """Query parameters for listing the caller's configured LLMs."""
    include_details: Optional[str] = Field(default="false", description="是否包含详细信息")


class ListLLMsQuery(BaseModel):
    """Query parameters for listing available LLMs."""
    model_type: Optional[str] = Field(default=None, description="模型类型过滤")

View File

@@ -0,0 +1,99 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class ListMCPServersQuery(BaseModel):
    """Query parameters for listing MCP servers."""
    keywords: Optional[str] = Field(default="", description="关键词")
    page: Optional[int] = Field(default=0, description="页码")
    page_size: Optional[int] = Field(default=0, description="每页大小")
    orderby: Optional[str] = Field(default="create_time", description="排序字段")
    desc: Optional[str] = Field(default="true", description="是否降序")


class ListMCPServersBody(BaseModel):
    """Filter body accompanying the list-MCP-servers query."""
    # default_factory avoids sharing one mutable default object across
    # model instances (idiomatic Pydantic; behavior is unchanged).
    mcp_ids: Optional[List[str]] = Field(default_factory=list, description="MCP服务器ID列表")


class CreateMCPServerRequest(BaseModel):
    """Create-MCP-server request."""
    name: str = Field(..., description="服务器名称")
    url: str = Field(..., description="服务器URL")
    server_type: str = Field(..., description="服务器类型")
    headers: Optional[Dict[str, Any]] = Field(default_factory=dict, description="请求头")
    variables: Optional[Dict[str, Any]] = Field(default_factory=dict, description="变量")
    timeout: Optional[float] = Field(default=10.0, description="超时时间")


class UpdateMCPServerRequest(BaseModel):
    """Update-MCP-server request; None fields are left untouched."""
    mcp_id: str = Field(..., description="MCP服务器ID")
    name: Optional[str] = Field(default=None, description="服务器名称")
    url: Optional[str] = Field(default=None, description="服务器URL")
    server_type: Optional[str] = Field(default=None, description="服务器类型")
    headers: Optional[Dict[str, Any]] = Field(default=None, description="请求头")
    variables: Optional[Dict[str, Any]] = Field(default=None, description="变量")
    # NOTE(review): unlike the other optional fields this defaults to 10.0,
    # not None, so an update always carries a timeout — confirm intended.
    timeout: Optional[float] = Field(default=10.0, description="超时时间")


class DeleteMCPServersRequest(BaseModel):
    """Delete-MCP-servers request."""
    mcp_ids: List[str] = Field(..., description="MCP服务器ID列表")


class ImportMCPServersRequest(BaseModel):
    """Bulk-import MCP servers request."""
    mcpServers: Dict[str, Any] = Field(..., description="MCP服务器配置字典")
    timeout: Optional[float] = Field(default=10.0, description="超时时间")


class ExportMCPServersRequest(BaseModel):
    """Bulk-export MCP servers request."""
    mcp_ids: List[str] = Field(..., description="MCP服务器ID列表")


class ListMCPToolsRequest(BaseModel):
    """List-MCP-tools request."""
    mcp_ids: List[str] = Field(..., description="MCP服务器ID列表")
    timeout: Optional[float] = Field(default=10.0, description="超时时间")


class TestMCPToolRequest(BaseModel):
    """Test-MCP-tool request."""
    mcp_id: str = Field(..., description="MCP服务器ID")
    tool_name: str = Field(..., description="工具名称")
    arguments: Dict[str, Any] = Field(..., description="工具参数")
    timeout: Optional[float] = Field(default=10.0, description="超时时间")


class CacheMCPToolsRequest(BaseModel):
    """Cache-MCP-tools request."""
    mcp_id: str = Field(..., description="MCP服务器ID")
    tools: List[Dict[str, Any]] = Field(..., description="工具列表")


class TestMCPRequest(BaseModel):
    """Test-MCP-server request (no login required)."""
    url: str = Field(..., description="服务器URL")
    server_type: str = Field(..., description="服务器类型")
    headers: Optional[Dict[str, Any]] = Field(default_factory=dict, description="请求头")
    variables: Optional[Dict[str, Any]] = Field(default_factory=dict, description="变量")
    timeout: Optional[float] = Field(default=10.0, description="超时时间")

View File

@@ -1,8 +1,26 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response
from flask_login import login_required
from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager
@manager.route('/llm_tools', methods=['GET']) # noqa: F821
@login_required
def llm_tools() -> Response:

View File

@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_error_data_result, ge
from api.utils.api_utils import get_result
from flask import request
@manager.route('/agents', methods=['GET']) # noqa: F821
@token_required
def list_agents(tenant_id):
@@ -41,7 +42,7 @@ def list_agents(tenant_id):
desc = False
else:
desc = True
canvas = UserCanvasService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,title)
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
return get_result(data=canvas)
@@ -93,7 +94,7 @@ def update_agent(tenant_id: str, agent_id: str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
if req.get("title") is not None:
req["title"] = req["title"].strip()

View File

@@ -215,7 +215,8 @@ def delete(tenant_id):
continue
kb_id_instance_pairs.append((kb_id, kb))
if len(error_kb_ids) > 0:
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
return get_error_permission_result(
message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
errors = []
success_count = 0
@@ -232,7 +233,8 @@ def delete(tenant_id):
]
)
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue
@@ -329,7 +331,8 @@ def update(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
if req.get("parser_config"):
req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
@@ -341,7 +344,8 @@ def update(tenant_id, dataset_id):
del req["parser_config"]
if "name" in req and req["name"].lower() != kb.name.lower():
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)
if exists:
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
@@ -349,7 +353,8 @@ def update(tenant_id, dataset_id):
if not req["embd_id"]:
req["embd_id"] = kb.embd_id
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
return get_error_data_result(
message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
return err
@@ -359,10 +364,12 @@ def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
@@ -454,7 +461,7 @@ def list_datasets(tenant_id):
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_list(
kbs, total = KnowledgebaseService.get_list(
[m["tenant_id"] for m in tenants],
tenant_id,
args["page"],
@@ -468,14 +475,15 @@ def list_datasets(tenant_id):
response_data_list = []
for kb in kbs:
response_data_list.append(remap_dictionary_keys(kb))
return get_result(data=response_data_list)
return get_result(data=response_data_list, total=total)
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
@token_required
def knowledge_graph(tenant_id,dataset_id):
def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
@@ -491,7 +499,7 @@ def knowledge_graph(tenant_id,dataset_id):
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)
@@ -507,14 +515,16 @@ def knowledge_graph(tenant_id,dataset_id):
if "nodes" in obj["graph"]:
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
filtered_edges = [o for o in obj["graph"]["edges"] if
o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
return get_result(data=obj)
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
@token_required
def delete_knowledge_graph(tenant_id,dataset_id):
def delete_knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
@@ -522,6 +532,7 @@ def delete_knowledge_graph(tenant_id,dataset_id):
code=settings.RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)

View File

@@ -1,4 +1,4 @@
#
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,6 +31,89 @@ from api.db.services.dialog_service import meta_filter, convert_conditions
@apikey_required
@validate_request("knowledge_id", "query")
def retrieval(tenant_id):
"""
Dify-compatible retrieval API
---
tags:
- SDK
security:
- ApiKeyAuth: []
parameters:
- in: body
name: body
required: true
schema:
type: object
required:
- knowledge_id
- query
properties:
knowledge_id:
type: string
description: Knowledge base ID
query:
type: string
description: Query text
use_kg:
type: boolean
description: Whether to use knowledge graph
default: false
retrieval_setting:
type: object
description: Retrieval configuration
properties:
score_threshold:
type: number
description: Similarity threshold
default: 0.0
top_k:
type: integer
description: Number of results to return
default: 1024
metadata_condition:
type: object
description: Metadata filter condition
properties:
conditions:
type: array
items:
type: object
properties:
name:
type: string
description: Field name
comparison_operator:
type: string
description: Comparison operator
value:
type: string
description: Field value
responses:
200:
description: Retrieval succeeded
schema:
type: object
properties:
records:
type: array
items:
type: object
properties:
content:
type: string
description: Content text
score:
type: number
description: Similarity score
title:
type: string
description: Document title
metadata:
type: object
description: Metadata info
404:
description: Knowledge base or document not found
"""
req = request.json
question = req["query"]
kb_id = req["knowledge_id"]
@@ -38,9 +121,9 @@ def retrieval(tenant_id):
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
metadata_condition = req.get("metadata_condition",{})
metadata_condition = req.get("metadata_condition", {})
metas = DocumentService.get_meta_by_kbs([kb_id])
doc_ids = []
try:
@@ -50,12 +133,12 @@ def retrieval(tenant_id):
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
print(metadata_condition)
print("after",convert_conditions(metadata_condition))
# print("after", convert_conditions(metadata_condition))
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
print("doc_ids",doc_ids)
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
ranks = settings.retrievaler.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -70,17 +153,17 @@ def retrieval(tenant_id):
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
records = []
for c in ranks["chunks"]:
e, doc = DocumentService.get_by_id( c["doc_id"])
e, doc = DocumentService.get_by_id(c["doc_id"])
c.pop("vector", None)
meta = getattr(doc, 'meta_fields', {})
meta["doc_id"] = c["doc_id"]
@@ -100,5 +183,3 @@ def retrieval(tenant_id):
)
logging.exception(e)
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)

View File

@@ -69,7 +69,7 @@ class Chunk(BaseModel):
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
@token_required
async def upload(dataset_id, tenant_id):
def upload(dataset_id, tenant_id):
"""
Upload documents to a dataset.
---
@@ -151,7 +151,7 @@ async def upload(dataset_id, tenant_id):
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = await FileService.upload_document(kb, file_objs, tenant_id)
err, files = FileService.upload_document(kb, file_objs, tenant_id)
if err:
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
# rename key's name
@@ -470,6 +470,20 @@ def list_docs(dataset_id, tenant_id):
required: false
default: 0
description: Unix timestamp for filtering documents created before this time. 0 means no filter.
- in: query
name: suffix
type: array
items:
type: string
required: false
description: Filter by file suffix (e.g., ["pdf", "txt", "docx"]).
- in: query
name: run
type: array
items:
type: string
required: false
description: Filter by document run status. Supports both numeric ("0", "1", "2", "3", "4") and text formats ("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL").
- in: header
name: Authorization
type: string
@@ -512,63 +526,62 @@ def list_docs(dataset_id, tenant_id):
description: Processing status.
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
id = request.args.get("id")
name = request.args.get("name")
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
if id and not DocumentService.query(id=id, kb_id=dataset_id):
return get_error_data_result(message=f"You don't own the document {id}.")
q = request.args
document_id = q.get("id")
name = q.get("name")
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
return get_error_data_result(message=f"You don't own the document {document_id}.")
if name and not DocumentService.query(name=name, kb_id=dataset_id):
return get_error_data_result(message=f"You don't own the document {name}.")
page = int(request.args.get("page", 1))
keywords = request.args.get("keywords", "")
page_size = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc") == "False":
desc = False
else:
desc = True
docs, tol = DocumentService.get_list(dataset_id, page, page_size, orderby, desc, keywords, id, name)
page = int(q.get("page", 1))
page_size = int(q.get("page_size", 30))
orderby = q.get("orderby", "create_time")
desc = str(q.get("desc", "true")).strip().lower() != "false"
keywords = q.get("keywords", "")
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
# filters - align with OpenAPI parameter names
suffix = q.getlist("suffix")
run_status = q.getlist("run")
create_time_from = int(q.get("create_time_from", 0))
create_time_to = int(q.get("create_time_to", 0))
# map run status (accept text or numeric) - align with API parameter
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
docs, total = DocumentService.get_list(
dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted
)
# time range filter (0 means no bound)
if create_time_from or create_time_to:
filtered_docs = []
for doc in docs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
docs = filtered_docs
docs = [
d for d in docs
if (create_time_from == 0 or d.get("create_time", 0) >= create_time_from)
and (create_time_to == 0 or d.get("create_time", 0) <= create_time_to)
]
# rename key's name
renamed_doc_list = []
# rename keys + map run status back to text for output
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "dataset_id",
"kb_id": "dataset_id",
"token_num": "token_count",
"parser_id": "chunk_method",
}
run_mapping = {
"0": "UNSTART",
"1": "RUNNING",
"2": "CANCEL",
"3": "DONE",
"4": "FAIL",
}
for doc in docs:
renamed_doc = {}
for key, value in doc.items():
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
if key == "run":
renamed_doc["run"] = run_mapping.get(value)
renamed_doc_list.append(renamed_doc)
return get_result(data={"total": tol, "docs": renamed_doc_list})
run_status_numeric_to_text = {"0": "UNSTART", "1": "RUNNING", "2": "CANCEL", "3": "DONE", "4": "FAIL"}
output_docs = []
for d in docs:
renamed_doc = {key_mapping.get(k, k): v for k, v in d.items()}
if "run" in d:
renamed_doc["run"] = run_status_numeric_to_text.get(str(d["run"]), d["run"])
output_docs.append(renamed_doc)
return get_result(data={"total": total, "docs": output_docs})
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
@token_required
@@ -982,7 +995,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@@ -1446,7 +1459,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retrievaler.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
tenant_ids,
@@ -1462,7 +1475,7 @@ def retrieval_test(tenant_id):
rank_feature=label_question(question, kbs),
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

View File

@@ -1,3 +1,20 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
@@ -17,7 +34,8 @@ from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
def upload(tenant_id):
"""
@@ -44,22 +62,22 @@ def upload(tenant_id):
type: object
properties:
data:
type: array
items:
type: object
properties:
id:
type: string
description: File ID
name:
type: string
description: File name
size:
type: integer
description: File size in bytes
type:
type: string
description: File type (e.g., document, folder)
type: array
items:
type: object
properties:
id:
type: string
description: File ID
name:
type: string
description: File name
size:
type: integer
description: File size in bytes
type:
type: string
description: File type (e.g., document, folder)
"""
pf_id = request.form.get("parent_id")
@@ -97,12 +115,14 @@ def upload(tenant_id):
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
if not e:
return get_json_result(data=False, message="Folder not found!", code=404)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
if not e:
return get_json_result(data=False, message="Folder not found!", code=404)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
@@ -129,7 +149,7 @@ def upload(tenant_id):
return server_error_response(e)
@manager.route('/file/create', methods=['POST']) # noqa: F821
@manager.route('/file/create', methods=['POST']) # noqa: F821
@token_required
def create(tenant_id):
"""
@@ -207,7 +227,7 @@ def create(tenant_id):
return server_error_response(e)
@manager.route('/file/list', methods=['GET']) # noqa: F821
@manager.route('/file/list', methods=['GET']) # noqa: F821
@token_required
def list_files(tenant_id):
"""
@@ -299,7 +319,7 @@ def list_files(tenant_id):
return server_error_response(e)
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@token_required
def get_root_folder(tenant_id):
"""
@@ -335,7 +355,7 @@ def get_root_folder(tenant_id):
return server_error_response(e)
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@token_required
def get_parent_folder():
"""
@@ -380,7 +400,7 @@ def get_parent_folder():
return server_error_response(e)
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@token_required
def get_all_parent_folders(tenant_id):
"""
@@ -428,7 +448,7 @@ def get_all_parent_folders(tenant_id):
return server_error_response(e)
@manager.route('/file/rm', methods=['POST']) # noqa: F821
@manager.route('/file/rm', methods=['POST']) # noqa: F821
@token_required
def rm(tenant_id):
"""
@@ -502,7 +522,7 @@ def rm(tenant_id):
return server_error_response(e)
@manager.route('/file/rename', methods=['POST']) # noqa: F821
@manager.route('/file/rename', methods=['POST']) # noqa: F821
@token_required
def rename(tenant_id):
"""
@@ -542,7 +562,8 @@ def rename(tenant_id):
if not e:
return get_json_result(message="File not found!", code=404)
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(file.name.lower()).suffix:
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
@@ -562,9 +583,9 @@ def rename(tenant_id):
return server_error_response(e)
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
@token_required
def get(tenant_id,file_id):
def get(tenant_id, file_id):
"""
Download a file.
---
@@ -610,7 +631,7 @@ def get(tenant_id,file_id):
return server_error_response(e)
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
def move(tenant_id):
"""
@@ -669,6 +690,7 @@ def move(tenant_id):
except Exception as e:
return server_error_response(e)
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
@@ -735,4 +757,4 @@ def convert(tenant_id):
file2documents.append(file2document.to_json())
return get_json_result(data=file2documents)
except Exception as e:
return server_error_response(e)
return server_error_response(e)

View File

@@ -36,7 +36,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, server_error_response, token_required, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@@ -88,7 +89,8 @@ def create_agent_session(tenant_id, agent_id):
canvas.reset()
cvs.dsl = json.loads(str(canvas))
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id,
"message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)
@@ -279,7 +281,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end() :]
content_part = answer[reasoning_match.end():]
else:
reasoning_part = ""
content_part = answer
@@ -324,7 +326,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
"total_tokens": len(prompt) + token_used}
if need_reference:
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
@@ -559,7 +562,8 @@ def list_agent_session(tenant_id, agent_id):
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
user_id, include_dsl)
if not convs:
return get_result(data=[])
for conv in convs:
@@ -581,7 +585,8 @@ def list_agent_session(tenant_id, agent_id):
if message_num != 0 and messages[message_num]["role"] != "user":
chunk_list = []
# Add boundary and type checks to prevent KeyError
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(
conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
chunks = conv["reference"][chunk_num]["chunks"]
for chunk in chunks:
# Ensure chunk is a dictionary before calling get method
@@ -639,13 +644,16 @@ def delete(tenant_id, chat_id):
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
return get_result(data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
return get_result(
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages})
else:
return get_error_data_result(message=";".join(duplicate_messages))
@@ -691,13 +699,16 @@ def delete_agent_session(tenant_id, agent_id):
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
return get_result(data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
return get_result(
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages})
else:
return get_error_data_result(message=";".join(duplicate_messages))
@@ -730,7 +741,9 @@ def ask_about(tenant_id):
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps(
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
@@ -882,7 +895,9 @@ def begin_inputs(agent_id):
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
return get_result(
data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
"prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@@ -911,7 +926,9 @@ def ask_about_embedded():
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps(
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
@@ -978,7 +995,8 @@ def retrieval_test_embedded():
tenant_ids.append(tenant.tenant_id)
break
else:
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.",
code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
@@ -998,11 +1016,13 @@ def retrieval_test_embedded():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retrievaler.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
ranks = settings.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
@@ -1013,7 +1033,8 @@ def retrieval_test_embedded():
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
@@ -1082,7 +1103,8 @@ def detail_share_embedded():
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
break
else:
return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)
return get_json_result(data=False, message="Has no permission for this operation.",
code=settings.RetCode.OPERATING_ERROR)
search = SearchService.get_detail(search_id)
if not search:

View File

@@ -39,6 +39,7 @@ from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks
@manager.route("/version", methods=["GET"]) # noqa: F821
@login_required
def version():
@@ -161,7 +162,7 @@ def status():
task_executors = REDIS_CONN.smembers("TASKEXE")
now = datetime.now().timestamp()
for task_executor_id in task_executors:
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now)
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
task_executor_heartbeats[task_executor_id] = heartbeats
except Exception:
@@ -273,7 +274,8 @@ def token_list():
objs = [o.to_dict() for o in objs]
for o in objs:
if not o["beta"]:
o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace("ragflow-", "")[:32]
o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace(
"ragflow-", "")[:32]
APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
return get_json_result(data=objs)
except Exception as e:

View File

@@ -14,9 +14,8 @@
# limitations under the License.
#
from fastapi import APIRouter, Depends, HTTPException, Query, status
from api.models.tenant_models import InviteUserRequest, UserTenantResponse
from api.utils.api_utils import get_current_user
from flask import request
from flask_login import login_required, current_user
from api import settings
from api.apps import smtp_mail_server
@@ -25,18 +24,13 @@ from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
from api.utils import get_uuid, delta_seconds
from api.utils.api_utils import get_json_result, server_error_response, get_data_error_result
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
# 创建 FastAPI 路由器
router = APIRouter()
@router.get("/{tenant_id}/user/list")
async def user_list(
tenant_id: str,
current_user = Depends(get_current_user)
):
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
@login_required
def user_list(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
@@ -52,19 +46,18 @@ async def user_list(
return server_error_response(e)
@router.post('/{tenant_id}/user')
async def create(
tenant_id: str,
request: InviteUserRequest,
current_user = Depends(get_current_user)
):
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
@login_required
@validate_request("email")
def create(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR)
invite_user_email = request.email
req = request.json
invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email)
if not invite_users:
return get_data_error_result(message="User not found.")
@@ -77,7 +70,8 @@ async def create(
return get_data_error_result(message=f"{invite_user_email} is already in the team.")
if user_tenant_role == UserTenantRole.OWNER:
return get_data_error_result(message=f"{invite_user_email} is the owner of the team.")
return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
return get_data_error_result(
message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
UserTenantService.save(
id=get_uuid(),
@@ -107,12 +101,9 @@ async def create(
return get_json_result(data=usr)
@router.delete('/{tenant_id}/user/{user_id}')
async def rm(
tenant_id: str,
user_id: str,
current_user = Depends(get_current_user)
):
@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
@login_required
def rm(tenant_id, user_id):
if current_user.id != tenant_id and current_user.id != user_id:
return get_json_result(
data=False,
@@ -126,10 +117,9 @@ async def rm(
return server_error_response(e)
@router.get("/list")
async def tenant_list(
current_user = Depends(get_current_user)
):
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def tenant_list():
try:
users = UserTenantService.get_tenants_by_user_id(current_user.id)
for u in users:
@@ -139,13 +129,12 @@ async def tenant_list(
return server_error_response(e)
@router.put("/agree/{tenant_id}")
async def agree(
tenant_id: str,
current_user = Depends(get_current_user)
):
@manager.route("/agree/<tenant_id>", methods=["PUT"]) # noqa: F821
@login_required
def agree(tenant_id):
try:
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id],
{"role": UserTenantRole.NORMAL})
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@@ -15,11 +15,14 @@
#
import json
import logging
import string
import os
import re
import secrets
import time
from datetime import datetime
from flask import redirect, request, session
from flask import redirect, request, session, make_response
from flask_login import current_user, login_required, login_user, logout_user
from werkzeug.security import check_password_hash, generate_password_hash
@@ -46,6 +49,19 @@ from api.utils.api_utils import (
validate_request,
)
from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.utils.web_utils import (
send_email_html,
OTP_LENGTH,
OTP_TTL_SECONDS,
ATTEMPT_LIMIT,
ATTEMPT_LOCK_SECONDS,
RESEND_COOLDOWN_SECONDS,
otp_keys,
hash_code,
captcha_key,
)
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@@ -825,3 +841,172 @@ def set_tenant_info():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route("/forget/captcha", methods=["GET"])  # noqa: F821
def forget_get_captcha():
    """
    GET /forget/captcha?email=<email>

    Generate an image captcha for the password-reset flow and return it as a
    PNG image. The plaintext captcha is cached in Redis under
    ``captcha:{email}`` (see ``captcha_key``) for 60 seconds so that
    ``/forget/otp`` can verify it.

    Query params:
        email: the account email; must belong to an existing user.

    Returns:
        A PNG image response on success, or a JSON error result when the
        email is missing or unknown.
    """
    email = (request.args.get("email") or "")
    if not email:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
    # Generate captcha text from an unambiguous charset (uppercase + digits),
    # using `secrets` because the value is security-sensitive.
    allowed = string.ascii_uppercase + string.digits
    captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
    REDIS_CONN.set(captcha_key(email), captcha_text, 60)  # Valid for 60 seconds
    from captcha.image import ImageCaptcha
    image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
    img_bytes = image.generate(captcha_text).read()
    response = make_response(img_bytes)
    # Fix: ImageCaptcha.generate() produces PNG data by default, and
    # "image/JPEG" is not a valid MIME type anyway — declare image/png.
    response.headers.set("Content-Type", "image/png")
    return response
@manager.route("/forget/otp", methods=["POST"])  # noqa: F821
def forget_send_otp():
    """
    POST /forget/otp

    Second step of the password-reset flow:
      1. Verify the image captcha stored at ``captcha:{email}``
         (case-insensitive), then delete it so it cannot be replayed.
      2. Enforce the resend cooldown recorded under the "last sent" key.
      3. Generate an email OTP (uppercase A-Z, length = OTP_LENGTH), store a
         salted hash in Redis with TTL = OTP_TTL_SECONDS, reset the attempt
         counter and lock, and send the OTP via email.

    Request JSON: { email, captcha }

    Returns a JSON result; failures are reported via RetCode, not HTTP status.
    """
    # NOTE(review): request.get_json() may return None for a non-JSON body,
    # which would raise AttributeError below — confirm callers always send JSON.
    req = request.get_json()
    email = req.get("email") or ""
    captcha = (req.get("captcha") or "").strip()
    if not email or not captcha:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
    stored_captcha = REDIS_CONN.get(captcha_key(email))
    if not stored_captcha:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
    # Case-insensitive comparison; captcha text was generated uppercase.
    if (stored_captcha or "").strip().lower() != captcha.lower():
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")
    # Delete captcha to prevent reuse
    REDIS_CONN.delete(captcha_key(email))
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    now = int(time.time())
    # Resend cooldown: refuse if the previous OTP was sent too recently.
    last_ts = REDIS_CONN.get(k_last)
    if last_ts:
        try:
            elapsed = now - int(last_ts)
        except Exception:
            # Unparseable timestamp: treat the cooldown as already elapsed.
            elapsed = RESEND_COOLDOWN_SECONDS
        remaining = RESEND_COOLDOWN_SECONDS - elapsed
        if remaining > 0:
            return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")
    # Generate OTP (uppercase letters only) and store hashed
    otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
    salt = os.urandom(16)
    code_hash = hash_code(otp, salt)
    # Store "hash:salt_hex" so /forget can recompute and compare; reset the
    # attempt counter, record the send time, and clear any attempt lock.
    REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
    REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
    REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
    REDIS_CONN.delete(k_lock)
    ttl_min = OTP_TTL_SECONDS // 60
    if not smtp_mail_server:
        # Best-effort: without SMTP the OTP is still cached, only the mail is skipped.
        logging.warning("SMTP mail server not initialized; skip sending email.")
    else:
        try:
            send_email_html(
                subject="Your Password Reset Code",
                to_email=email,
                template_key="reset_code",
                code=otp,
                ttl_min=ttl_min,
            )
        except Exception:
            return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")
    return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
@manager.route("/forget", methods=["POST"])  # noqa: F821
def forget():
    """
    POST /forget

    Final step of the password-reset flow: verify email + OTP, reset the
    password, then log the user in.

    Request JSON: { email, otp, new_password, confirm_new_password }

    Returns:
        On success, the standard login response (user JSON + auth header).
        On failure, a JSON result whose RetCode describes the error
        (bad arguments, unknown email, expired/locked/wrong OTP, or a
        password-update failure).
    """
    req = request.get_json()
    email = req.get("email") or ""
    otp = (req.get("otp") or "").strip()
    new_pwd = req.get("new_password")
    new_pwd2 = req.get("confirm_new_password")
    if not all([email, otp, new_pwd, new_pwd2]):
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
    # For reset, passwords are provided as-is (no decrypt needed)
    if new_pwd != new_pwd2:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
    user = users[0]
    # Verify OTP from Redis
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    if REDIS_CONN.get(k_lock):
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")
    stored = REDIS_CONN.get(k_code)
    if not stored:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")
    try:
        # Stored format is "hash:salt_hex" (written by /forget/otp).
        stored_hash, salt_hex = str(stored).split(":", 1)
        salt = bytes.fromhex(salt_hex)
    except Exception:
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")
    # Case-insensitive verification: OTP generated uppercase
    calc = hash_code(otp.upper(), salt)
    if calc != stored_hash:
        # Wrong code: bump the attempt counter and lock after ATTEMPT_LIMIT.
        try:
            attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
        except Exception:
            attempts = 1
        REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
        if attempts >= ATTEMPT_LIMIT:
            REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
        # NOTE(review): the message says "expired otp" but this branch means
        # the code was wrong — left unchanged in case clients match on it.
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")
    # Success: consume OTP and reset password
    REDIS_CONN.delete(k_code)
    REDIS_CONN.delete(k_attempts)
    REDIS_CONN.delete(k_last)
    REDIS_CONN.delete(k_lock)
    try:
        UserService.update_user_password(user.id, new_pwd)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")
    # Auto login (reuse login flow)
    user.access_token = get_uuid()
    login_user(user)
    # Fix: the original assigned 1-tuples here (stray trailing commas),
    # i.e. `user.update_time = (current_timestamp(),)` — assign scalars.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Password reset successful. Logged in."
    return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

View File

@@ -21,8 +21,7 @@ from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api.apps.models.auth_dependencies import get_current_user
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr
try:
@@ -89,63 +88,7 @@ class TenantInfoRequest(BaseModel):
img2txt_id: str
llm_id: str
# Dependency: resolve the current user
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    FastAPI dependency: resolve the current user from a Bearer access token.

    The token is a serialized value produced with the app's SECRET_KEY; it is
    deserialized, sanity-checked (non-empty, at least 32 chars), and looked up
    against UserService with status=VALID.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The matching user record.

    Raises:
        HTTPException: 401 with a specific detail message for each failure
        mode (missing header, empty/short token, unknown token, user with an
        empty stored access_token).
    """
    from api.db import StatusEnum
    try:
        from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
    except ImportError:
        # NOTE(review): this fallback is broken — `jwt` is a module, so the
        # `Serializer(secret_key=...)` call below would raise TypeError.
        # Kept as-is pending a real PyJWT-based implementation; TODO confirm.
        import jwt
        Serializer = jwt
    jwt = Serializer(secret_key=settings.SECRET_KEY)
    authorization = credentials.credentials
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header required"
        )
    try:
        access_token = str(jwt.loads(authorization))
        if not access_token or not access_token.strip():
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication attempt with empty access token"
            )
        # Access tokens should be UUIDs (32 hex characters)
        if len(access_token.strip()) < 32:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
            )
        user = UserService.query(
            access_token=access_token, status=StatusEnum.VALID.value
        )
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid access token"
            )
        if not user[0].access_token or not user[0].access_token.strip():
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"User {user[0].email} has empty access_token in database"
            )
        return user[0]
    except HTTPException:
        # Fix: the original's broad `except Exception` swallowed the specific
        # HTTPExceptions raised above and replaced their detail with a generic
        # "Invalid access token" — re-raise them unchanged.
        raise
    except Exception as e:
        logging.warning(f"load_user got exception {e}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid access token"
        )
# Dependency: current user is imported from auth_dependencies
@router.post("/login")
async def login(request: LoginRequest):