v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -53,6 +53,7 @@ from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.json import CustomJSONEncoder, json_dumps
# 自定义认证方案支持不传Bearer格式
class CustomHTTPBearer(SecurityBase):
def __init__(self, *, scheme_name: str = None, auto_error: bool = True):
@@ -71,13 +72,13 @@ class CustomHTTPBearer(SecurityBase):
)
else:
return None
# 支持Bearer格式和直接token格式
if authorization.startswith("Bearer "):
token = authorization[7:] # 移除"Bearer "前缀
else:
token = authorization # 直接使用token
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
# FastAPI 安全方案
@@ -95,8 +96,8 @@ def serialize_for_json(obj):
if hasattr(obj, '__dict__'):
# For objects with __dict__, try to serialize their attributes
try:
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
if not key.startswith('_')}
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
if not key.startswith('_')}
except (AttributeError, TypeError):
return str(obj)
elif hasattr(obj, '__name__'):
@@ -112,6 +113,7 @@ def serialize_for_json(obj):
# Fallback: convert to string representation
return str(obj)
def request(**kwargs):
sess = requests.Session()
stream = kwargs.pop("stream", sess.stream)
@@ -132,7 +134,8 @@ def request(**kwargs):
settings.HTTP_APP_KEY.encode("ascii"),
prepped.path_url.encode("ascii"),
prepped.body if kwargs.get("json") else b"",
urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode(
"ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
]
),
"sha1",
@@ -154,7 +157,7 @@ def request(**kwargs):
def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
@@ -185,11 +188,12 @@ def server_error_response(e):
if len(e.args) > 1:
try:
serialized_data = serialize_for_json(e.args[1])
return get_json_result(code= settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
except Exception:
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
@@ -214,12 +218,15 @@ def validate_request(*args, **kwargs):
废弃的装饰器:在 FastAPI 中使用 Pydantic 模型进行验证
这个函数保留是为了向后兼容,但不会执行任何验证
"""
def wrapper(func):
@wraps(func)
def decorated_function(*_args, **_kwargs):
# FastAPI 中不需要手动验证Pydantic 会自动处理
return func(*_args, **_kwargs)
return decorated_function
return wrapper
@@ -228,11 +235,14 @@ def not_allowed_parameters(*params):
废弃的装饰器:在 FastAPI 中使用 Pydantic 模型进行验证
这个函数保留是为了向后兼容,但不会执行任何验证
"""
def decorator(f):
def wrapper(*args, **kwargs):
# FastAPI 中不需要手动验证Pydantic 会自动处理
return f(*args, **kwargs)
return wrapper
return decorator
@@ -241,10 +251,12 @@ def active_required(f):
废弃的装饰器:在 FastAPI 中使用依赖注入进行用户验证
这个函数保留是为了向后兼容,但不会执行任何验证
"""
@wraps(f)
def wrapper(*args, **kwargs):
# FastAPI 中使用依赖注入进行用户验证
return f(*args, **kwargs)
return wrapper
@@ -281,10 +293,12 @@ def apikey_required(func):
废弃的装饰器:在 FastAPI 中使用依赖注入进行 API Key 验证
这个函数保留是为了向后兼容,但不会执行任何验证
"""
@wraps(func)
def decorated_function(*args, **kwargs):
# FastAPI 中使用依赖注入进行 API Key 验证
return func(*args, **kwargs)
return decorated_function
@@ -301,7 +315,7 @@ def construct_response(code=settings.RetCode.SUCCESS, message="success", data=No
continue
else:
response_dict[key] = value
headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Method": "*",
@@ -310,7 +324,7 @@ def construct_response(code=settings.RetCode.SUCCESS, message="success", data=No
}
if auth:
headers["Authorization"] = auth
return JSONResponse(content=response_dict, headers=headers)
@@ -349,10 +363,12 @@ def token_required(func):
废弃的装饰器:在 FastAPI 中使用依赖注入进行 Token 验证
这个函数保留是为了向后兼容,但不会执行任何验证
"""
@wraps(func)
def decorated_function(*args, **kwargs):
# FastAPI 中使用依赖注入进行 Token 验证
return func(*args, **kwargs)
return decorated_function
@@ -368,8 +384,8 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
def get_error_data_result(
message="Sorry! Data missing!",
code=settings.RetCode.DATA_ERROR,
message="Sorry! Data missing!",
code=settings.RetCode.DATA_ERROR,
):
result_dict = {"code": code, "message": message}
response = {}
@@ -392,24 +408,24 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s
try:
jwt = URLSafeTimedSerializer(secret_key=settings.SECRET_KEY)
authorization = credentials.credentials
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication attempt with empty access token"
)
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
)
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
@@ -474,7 +490,7 @@ def create_file_response(data, filename: str, media_type: str = "application/oct
data = json_dumps(data)
if isinstance(data, str):
data = data.encode("utf-8")
return StreamingResponse(
BytesIO(data),
media_type=media_type,
@@ -501,7 +517,8 @@ def get_parser_config(chunk_method, parser_config):
# Define default configurations for each chunking method
key_mapping = {
"naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
"naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
"qa": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
"tag": None,
"resume": None,
@@ -540,16 +557,16 @@ def get_parser_config(chunk_method, parser_config):
def get_data_openai(
id=None,
created=None,
model=None,
prompt_tokens=0,
completion_tokens=0,
content=None,
finish_reason=None,
object="chat.completion",
param=None,
stream=False
id=None,
created=None,
model=None,
prompt_tokens=0,
completion_tokens=0,
content=None,
finish_reason=None,
object="chat.completion",
param=None,
stream=False
):
total_tokens = prompt_tokens + completion_tokens
@@ -661,7 +678,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, J
in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))
tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
is_tenant_model = any(
llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for
llm in tenant_llms)
is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
if not (is_builtin_model or is_tenant_model or in_llm_service):
@@ -804,7 +823,8 @@ TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@@ -892,7 +912,8 @@ async def is_strong_enough(chat_model, embedding_model):
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
if chat_model:
with trio.fail_after(30):
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [
{"role": "user", "content": "Are you strong enough!?"}], {}))
if res.find("**ERROR**") >= 0:
raise Exception(res)

View File

@@ -0,0 +1,25 @@
"""
Reusable HTML email templates and registry.
"""
# Invitation email template
INVITE_EMAIL_TMPL = """
<p>Hi {{email}},</p>
<p>{{inviter}} has invited you to join their team (ID: {{tenant_id}}).</p>
<p>Click the link below to complete your registration:<br>
<a href="{{invite_url}}">{{invite_url}}</a></p>
<p>If you did not request this, please ignore this email.</p>
"""
# Password reset code template
RESET_CODE_EMAIL_TMPL = """
<p>Hello,</p>
<p>Your password reset code is: <b>{{ code }}</b></p>
<p>This code will expire in {{ ttl_min }} minutes.</p>
"""
# Template registry
EMAIL_TEMPLATES = {
"invite": INVITE_EMAIL_TMPL,
"reset_code": RESET_CODE_EMAIL_TMPL,
}

View File

@@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Standard library imports
import base64
import hashlib
import io
import json
import os
import re
@@ -22,13 +27,20 @@ import subprocess
import sys
import tempfile
import threading
import zipfile
from io import BytesIO
# Typing
from typing import List, Union, Tuple
# Third-party imports
import olefile
import pdfplumber
from cachetools import LRUCache, cached
from PIL import Image
from ruamel.yaml import YAML
# Local imports
from api.constants import IMG_BASE64_PREFIX
from api.db import FileType
@@ -161,7 +173,7 @@ def filename_type(filename):
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4|avi|mkv)$", filename):
return FileType.VISUAL.value
return FileType.OTHER.value
@@ -284,3 +296,125 @@ def read_potential_broken_pdf(blob):
return repaired
return blob
def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-")
def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10]
def _guess_ext(b: bytes) -> str:
h = b[:8]
if _is_zip(h):
try:
with zipfile.ZipFile(io.BytesIO(b), "r") as z:
names = [n.lower() for n in z.namelist()]
if any(n.startswith("word/") for n in names):
return ".docx"
if any(n.startswith("ppt/") for n in names):
return ".pptx"
if any(n.startswith("xl/") for n in names):
return ".xlsx"
except Exception:
pass
return ".zip"
if _is_pdf(h):
return ".pdf"
if _is_ole(h):
return ".doc"
return ".bin"
# Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes:
try:
pos = 0
if len(data) < 4:
return data
_ = int.from_bytes(data[pos:pos+4], "little")
pos += 4
# filename/src/tmp (NUL-terminated ANSI)
for _ in range(3):
z = data.index(b"\x00", pos)
pos = z + 1
# skip unknown 4 bytes
pos += 4
if pos + 4 > len(data):
return data
size = int.from_bytes(data[pos:pos+4], "little")
pos += 4
if pos + size <= len(data):
return data[pos:pos+size]
except Exception:
pass
return data
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
"""
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
"""
top = bytes(target)
head = top[:8]
out: List[Tuple[str, bytes]] = []
seen = set()
def push(b: bytes, name_hint: str = ""):
h10 = _sha10(b)
if h10 in seen:
return
seen.add(h10)
ext = _guess_ext(b)
# If name_hint has an extension use its basename; else fallback to guessed ext
if "." in name_hint:
fname = name_hint.split("/")[-1]
else:
fname = f"{h10}{ext}"
out.append((fname, b))
# OOXML/ZIP container (docx/xlsx/pptx)
if _is_zip(head):
try:
with zipfile.ZipFile(io.BytesIO(top), "r") as z:
embed_dirs = (
"word/embeddings/", "word/objects/", "word/activex/",
"xl/embeddings/", "ppt/embeddings/"
)
for name in z.namelist():
low = name.lower()
if any(low.startswith(d) for d in embed_dirs):
try:
b = z.read(name)
push(b, name)
except Exception:
pass
except Exception:
pass
return out
# OLE container (doc/ppt/xls)
if _is_ole(head):
try:
with olefile.OleFileIO(io.BytesIO(top)) as ole:
for entry in ole.listdir():
p = "/".join(entry)
try:
data = ole.openstream(entry).read()
except Exception:
continue
if not data:
continue
if "Ole10Native" in p or "ole10native" in p.lower():
data = _extract_ole10native_payload(data)
push(data, p)
except Exception:
pass
return out
return out

View File

@@ -74,12 +74,12 @@ def get_es_cluster_stats() -> dict:
raise Exception("Elasticsearch is not in use.")
try:
return {
"alive": True,
"status": "alive",
"message": ESConnection().get_cluster_stats()
}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -90,12 +90,12 @@ def get_infinity_status():
raise Exception("Infinity is not in use.")
try:
return {
"alive": True,
"status": "alive",
"message": InfinityConnection().health()
}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -107,12 +107,12 @@ def get_mysql_status():
headers = ['id', 'user', 'host', 'db', 'command', 'time', 'state', 'info']
cursor.close()
return {
"alive": True,
"status": "alive",
"message": [dict(zip(headers, r)) for r in res_rows]
}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -122,12 +122,12 @@ def check_minio_alive():
try:
response = requests.get(f'http://{rag_settings.MINIO["host"]}/minio/health/live')
if response.status_code == 200:
return {'alive': True, "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
else:
return {'alive': False, "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
return {"status": "timeout", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -135,12 +135,12 @@ def check_minio_alive():
def get_redis_info():
try:
return {
"alive": True,
"status": "alive",
"message": REDIS_CONN.info()
}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -150,12 +150,12 @@ def check_ragflow_server_alive():
try:
response = requests.get(f'http://{settings.HOST_IP}:{settings.HOST_PORT}/v1/system/ping')
if response.status_code == 200:
return {'alive': True, "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
else:
return {'alive': False, "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
return {"status": "timeout", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
except Exception as e:
return {
"alive": False,
"status": "timeout",
"message": f"error: {str(e)}",
}
@@ -192,9 +192,7 @@ def run_health_checks() -> tuple[dict, bool]:
except Exception:
result["storage"] = "nok"
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (
result.get("storage") == "ok")
result["status"] = "ok" if all_ok else "nok"
return result, all_ok

View File

@@ -24,6 +24,7 @@ from urllib.parse import urlparse
from api.apps import smtp_mail_server
from flask_mail import Message
from flask import render_template_string
from api.utils.email_templates import EMAIL_TEMPLATES
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.chrome.options import Options
@@ -34,6 +35,12 @@ from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
OTP_LENGTH = 8
OTP_TTL_SECONDS = 5 * 60
ATTEMPT_LIMIT = 5
ATTEMPT_LOCK_SECONDS = 30 * 60
RESEND_COOLDOWN_SECONDS = 60
CONTENT_TYPE_MAP = {
# Office
@@ -178,24 +185,49 @@ def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
return default
INVITE_EMAIL_TMPL = """
<p>Hi {{email}},</p>
<p>{{inviter}} has invited you to join their team (ID: {{tenant_id}}).</p>
<p>Click the link below to complete your registration:<br>
<a href="{{invite_url}}">{{invite_url}}</a></p>
<p>If you did not request this, please ignore this email.</p>
"""
def send_email_html(subject: str, to_email: str, template_key: str, **context):
"""Generic HTML email sender using shared templates.
template_key must exist in EMAIL_TEMPLATES.
"""
from api.apps import app
tmpl = EMAIL_TEMPLATES.get(template_key)
if not tmpl:
raise ValueError(f"Unknown email template: {template_key}")
with app.app_context():
msg = Message(subject=subject, recipients=[to_email])
msg.html = render_template_string(tmpl, **context)
smtp_mail_server.send(msg)
def send_invite_email(to_email, invite_url, tenant_id, inviter):
from api.apps import app
with app.app_context():
msg = Message(subject="RAGFlow Invitation",
recipients=[to_email])
msg.html = render_template_string(
INVITE_EMAIL_TMPL,
email=to_email,
invite_url=invite_url,
tenant_id=tenant_id,
inviter=inviter,
)
smtp_mail_server.send(msg)
# Reuse the generic HTML sender with 'invite' template
send_email_html(
subject="RAGFlow Invitation",
to_email=to_email,
template_key="invite",
email=to_email,
invite_url=invite_url,
tenant_id=tenant_id,
inviter=inviter,
)
def otp_keys(email: str):
email = (email or "").strip().lower()
return (
f"otp:{email}",
f"otp_attempts:{email}",
f"otp_last_sent:{email}",
f"otp_lock:{email}",
)
def hash_code(code: str, salt: bytes) -> str:
import hashlib
import hmac
return hmac.new(salt, (code or "").encode("utf-8"), hashlib.sha256).hexdigest()
def captcha_key(email: str) -> str:
return f"captcha:{email}"