Files
TERES_fastapi_backend/api/utils/api_utils.py

903 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import os
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Callable, Coroutine, Optional, Type, Union
from urllib.parse import quote, urlencode
from uuid import uuid1
import requests
import trio
# FastAPI imports
from fastapi import Request, Response as FastAPIResponse, HTTPException, status
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
from fastapi import Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security.base import SecurityBase
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from itsdangerous import URLSafeTimedSerializer
from peewee import OperationalError
from werkzeug.http import HTTP_STATUS_CODES
from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db import ActiveEnum
from api.db.db_models import APIToken
from api.db.services import UserService
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.json import CustomJSONEncoder, json_dumps
# Custom authentication scheme that also accepts a token sent without the "Bearer " prefix.
class CustomHTTPBearer(SecurityBase):
    """Bearer-style security dependency that tolerates a bare token.

    The ``Authorization`` header may be either ``Bearer <token>`` or just
    ``<token>``; in both cases the token is returned wrapped in
    ``HTTPAuthorizationCredentials`` with scheme ``"Bearer"``.

    Args:
        scheme_name: Name of the scheme; defaults to the class name.
        auto_error: When True a missing header raises HTTP 403; when False
            the dependency resolves to ``None`` instead.
    """

    def __init__(self, *, scheme_name: str = None, auto_error: bool = True):
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error
        # ``model`` is the attribute used for OpenAPI document generation.
        # Reuse the OpenAPI model built by the stock HTTPBearer scheme rather
        # than storing the scheme object itself, which is not an OpenAPI model.
        self.model = HTTPBearer().model

    async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]:
        """Extract the token from the request's ``Authorization`` header.

        Raises:
            HTTPException: 403 when the header is absent and ``auto_error`` is set.
        """
        authorization = request.headers.get("Authorization")
        if not authorization:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Not authenticated",
                )
            return None
        # Accept both "Bearer <token>" and a bare "<token>".
        token = authorization.removeprefix("Bearer ")
        return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
# FastAPI security scheme: shared instance used as the default dependency
# (via ``Depends(security)``) by the authentication helpers in this module.
security = CustomHTTPBearer()
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
# Route the JSON serialisation used by ``requests`` through CustomJSONEncoder.
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
def serialize_for_json(obj):
    """Recursively convert ``obj`` into JSON-serializable data.

    Conversion rules, checked in this order:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned unchanged.
    * Dicts (including subclasses) keep their keys; values are converted.
    * Lists and tuples become lists of converted items.
    * Classes and metaclasses become ``"<module.Name>"``.
    * Other objects with a ``__dict__`` become a dict of their public
      (non-underscore) attributes; if that fails, ``str(obj)`` is used.
    * Remaining objects with a ``__name__`` (e.g. builtin functions) become
      ``"<module.name>"``.
    * Anything else falls back to ``str(obj)``.

    Containers and classes are checked before the generic ``__dict__`` branch
    because dict/list subclasses and classes all expose ``__dict__``; checking
    it first would drop container contents and never report class names.
    """

    def _qualified_name(named):
        # Mirrors "<module.name>" when a module is known, "<name>" otherwise.
        if hasattr(named, "__module__"):
            return f"<{named.__module__}.{named.__name__}>"
        return f"<{named.__name__}>"

    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        return {key: serialize_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [serialize_for_json(item) for item in obj]
    if isinstance(obj, type):
        return _qualified_name(obj)
    if hasattr(obj, "__dict__"):
        try:
            return {
                key: serialize_for_json(value)
                for key, value in obj.__dict__.items()
                if not key.startswith("_")
            }
        except (AttributeError, TypeError):
            return str(obj)
    if hasattr(obj, "__name__"):
        return _qualified_name(obj)
    # Fallback: string representation.
    return str(obj)
def request(**kwargs):
    """Send an HTTP request through ``requests``, signing it when configured.

    ``stream`` and ``timeout`` are popped from ``kwargs`` and passed to
    ``Session.send``; everything else is forwarded to ``requests.Request``.
    Header names are normalised by replacing ``_`` with ``-`` and upper-casing.

    When ``settings.CLIENT_AUTHENTICATION``, ``settings.HTTP_APP_KEY`` and
    ``settings.SECRET_KEY`` are all set, an HMAC-SHA1 signature over the
    timestamp, nonce, app key, request path, JSON body and sorted form data is
    added in the ``TIMESTAMP``/``NONCE``/``APP-KEY``/``SIGNATURE`` headers.

    Returns:
        The ``requests.Response`` produced by sending the prepared request.
    """
    sess = requests.Session()
    stream = kwargs.pop("stream", sess.stream)
    timeout = kwargs.pop("timeout", None)
    # ``or {}`` also covers an explicit ``headers=None``.
    raw_headers = kwargs.get("headers") or {}
    kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in raw_headers.items()}
    prepped = requests.Request(**kwargs).prepare()

    if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
        # ``time`` is imported as a module, so the clock is ``time.time()``.
        timestamp = str(round(time.time() * 1000))
        nonce = str(uuid1())

        form_data = kwargs.get("data")
        if form_data and isinstance(form_data, dict):
            signed_form = urlencode(sorted(form_data.items()), quote_via=quote, safe="-._~").encode("ascii")
        else:
            signed_form = b""

        signed_parts = [
            timestamp.encode("ascii"),
            nonce.encode("ascii"),
            settings.HTTP_APP_KEY.encode("ascii"),
            prepped.path_url.encode("ascii"),
            prepped.body if kwargs.get("json") else b"",
            signed_form,
        ]
        digest = HMAC(settings.SECRET_KEY.encode("ascii"), b"\n".join(signed_parts), "sha1").digest()
        signature = b64encode(digest).decode("ascii")

        prepped.headers.update(
            {
                "TIMESTAMP": timestamp,
                "NONCE": nonce,
                "APP-KEY": settings.HTTP_APP_KEY,
                "SIGNATURE": signature,
            }
        )
    return sess.send(prepped, stream=stream, timeout=timeout)
def get_exponential_backoff_interval(retries, full_jitter=False):
    """Return the wait time for the given retry count.

    The wait is ``REQUEST_WAIT_SEC * 2**retries`` capped at
    ``REQUEST_MAX_WAIT_SEC``. With ``full_jitter`` a random integer between 0
    and that value (inclusive) is drawn instead, following
    https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/.
    The result is never negative.
    """
    uncapped = REQUEST_WAIT_SEC * (2**retries)
    wait = min(REQUEST_MAX_WAIT_SEC, uncapped)
    if full_jitter:
        wait = random.randrange(wait + 1)
    # Clamp so a negative configuration never yields a negative wait.
    return max(0, wait)
def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
    """Log ``message`` and return a JSON error body with ``code`` and ``message``.

    ``code`` is always present; ``message`` is omitted when it is ``None``.
    """
    logging.exception(Exception(message))
    payload = {"code": code}
    if message is not None:
        payload["message"] = message
    return JSONResponse(content=payload)
def server_error_response(e):
    """Log exception ``e`` and convert it into a JSON error response.

    * An exception whose ``code`` attribute equals 401 yields code 401.
    * An exception with more than one argument yields ``repr(args[0])`` as the
      message and the serialized ``args[1]`` as data (``None`` if building
      that response fails).
    * An "index_not_found_exception" yields a hint to upload and parse a file.
    * Anything else yields ``repr(e)`` as the message.
    """
    logging.exception(e)

    # ``e`` may not define ``code`` at all; any failure here means "not a 401".
    try:
        if e.code == 401:
            return get_json_result(code=401, message=repr(e))
    except BaseException:
        pass

    details = e.args
    if len(details) > 1:
        try:
            extra = serialize_for_json(details[1])
            return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(details[0]), data=extra)
        except Exception:
            return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(details[0]), data=None)

    if "index_not_found_exception" in repr(e):
        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")

    return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def error_response(response_code, message=None):
    """Return a JSON error whose HTTP status and ``code`` are ``response_code``.

    When ``message`` is ``None`` the standard reason phrase for the status is
    used, or "Unknown Error" if the status is not known.
    """
    text = HTTP_STATUS_CODES.get(response_code, "Unknown Error") if message is None else message
    body = {"message": text, "code": response_code}
    return JSONResponse(content=body, status_code=response_code)
# FastAPI version: validation is done with Pydantic models rather than a decorator.
# This decorator is no longer needed because FastAPI validates Pydantic models itself.
def validate_request(*args, **kwargs):
    """Deprecated no-op decorator factory.

    In FastAPI, request validation is handled by Pydantic models. This
    function is kept for backward compatibility only; the arguments are
    ignored and no validation is performed.
    """

    def passthrough_decorator(view):
        @wraps(view)
        def call_view(*call_args, **call_kwargs):
            # Nothing to validate here; simply delegate.
            return view(*call_args, **call_kwargs)

        return call_view

    return passthrough_decorator
def not_allowed_parameters(*params):
    """Deprecated no-op decorator factory.

    In FastAPI, request validation is handled by Pydantic models. This
    function is kept for backward compatibility only; ``params`` is ignored
    and no validation is performed.
    """

    def decorator(f):
        # ``wraps`` keeps the wrapped function's name, docstring and
        # ``__wrapped__`` link, matching the other no-op decorators here.
        @wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapper

    return decorator
def active_required(f):
    """Deprecated no-op decorator.

    In FastAPI, user checks are done through dependency injection. This
    decorator is kept for backward compatibility only and performs no check.
    """

    @wraps(f)
    def passthrough(*call_args, **call_kwargs):
        # User verification is handled elsewhere; simply delegate.
        return f(*call_args, **call_kwargs)

    return passthrough
def is_localhost(ip):
    """Return True when ``ip`` is one of the recognised loopback spellings."""
    loopback_names = frozenset(("127.0.0.1", "::1", "[::1]", "localhost"))
    return ip in loopback_names
def send_file_in_mem(data, filename):
    """Wrap ``data`` in an in-memory binary stream positioned at the start.

    Values that are neither ``str`` nor ``bytes`` are first serialized with
    ``json_dumps``; strings are encoded as UTF-8. ``filename`` is accepted but
    not used here.

    Note: under FastAPI the caller is expected to turn the returned stream
    into a FileResponse or StreamingResponse; this function only returns the
    file object.
    """
    payload = data if isinstance(data, (str, bytes)) else json_dumps(data)
    if isinstance(payload, str):
        payload = payload.encode("utf-8")
    return BytesIO(payload)
def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    """Return a JSON response containing ``code``, ``message`` and ``data``.

    All three keys are always present, even when ``data`` is ``None``.
    """
    body = dict(code=code, message=message, data=data)
    return JSONResponse(content=body)
def apikey_required(func):
    """Deprecated no-op decorator.

    In FastAPI, API-key checks are done through dependency injection. This
    decorator is kept for backward compatibility only and performs no check.
    """

    @wraps(func)
    def passthrough(*call_args, **call_kwargs):
        # API-key verification is handled elsewhere; simply delegate.
        return func(*call_args, **call_kwargs)

    return passthrough
def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
    """Return a JSON body of ``code`` and ``message`` with HTTP status ``code``."""
    body = dict(code=code, message=message)
    return JSONResponse(status_code=code, content=body)
def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
    """Return a JSON response with permissive CORS headers.

    The body always contains ``code``; ``message`` and ``data`` are included
    only when they are not ``None``. When ``auth`` is truthy it is sent back
    in the ``Authorization`` response header, which is also listed in
    ``Access-Control-Expose-Headers``.
    """
    response_dict = {"code": code}
    for key, value in (("message", message), ("data", data)):
        if value is not None:
            response_dict[key] = value

    headers = {
        "Access-Control-Allow-Origin": "*",
        # The CORS header is plural: "Access-Control-Allow-Methods".
        "Access-Control-Allow-Methods": "*",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Expose-Headers": "Authorization",
    }
    if auth:
        headers["Authorization"] = auth
    return JSONResponse(content=response_dict, headers=headers)
def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
    """Return a JSON body with ``code`` and, unless it is ``None``, ``message``."""
    fields = (("code", code), ("message", message))
    body = {name: value for name, value in fields if name == "code" or value is not None}
    return JSONResponse(content=body)
def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    """Return a JSON body of ``code`` and ``message``, plus ``data`` when it is not ``None``."""
    payload = {"code": code, "message": message}
    if data is not None:
        payload["data"] = data
    return JSONResponse(content=payload)
def construct_error_response(e):
    """Log exception ``e`` and convert it into a JSON error response.

    An exception whose ``code`` attribute equals 401 maps to the UNAUTHORIZED
    code. Otherwise the EXCEPTION_ERROR code is used, with ``repr(args[0])``
    as message and ``args[1]`` as data when the exception carries more than
    one argument, or ``repr(e)`` as message when it does not.
    """
    logging.exception(e)

    # ``e`` may not define ``code`` at all; any failure here means "not a 401".
    try:
        if e.code == 401:
            return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
    except BaseException:
        pass

    details = e.args
    if len(details) > 1:
        return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(details[0]), data=details[1])
    return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func):
    """Deprecated no-op decorator.

    In FastAPI, token checks are done through dependency injection. This
    decorator is kept for backward compatibility only and performs no check.
    """

    @wraps(func)
    def passthrough(*call_args, **call_kwargs):
        # Token verification is handled elsewhere; simply delegate.
        return func(*call_args, **call_kwargs)

    return passthrough
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
    """Return the JSON result body for ``code``.

    When ``code`` equals 0 the body holds ``code`` and, if it is not ``None``,
    ``data``; ``message`` is not included. For any other code the body holds
    ``code`` and ``message``, and ``data`` is not included.
    """
    if code == 0:
        body = {"code": code}
        if data is not None:
            body["data"] = data
    else:
        body = {"code": code, "message": message}
    return JSONResponse(content=body)
def get_error_data_result(
    message="Sorry! Data missing!",
    code=settings.RetCode.DATA_ERROR,
):
    """Return a JSON error body with ``code`` and, unless it is ``None``, ``message``."""
    body = {"code": code}
    if message is not None:
        body["message"] = message
    return JSONResponse(content=body)
def get_error_argument_result(message="Invalid arguments"):
    """Return the result response for an argument error carrying ``message``."""
    error_code = settings.RetCode.ARGUMENT_ERROR
    return get_result(code=error_code, message=message)
# FastAPI dependency-injection helpers
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the current user from the request credentials (FastAPI version).

    The credential string is decoded with ``URLSafeTimedSerializer`` using
    ``settings.SECRET_KEY``; the decoded access token is then looked up via
    ``UserService.query`` among users with a valid status.

    Returns:
        The first matching user record.

    Raises:
        HTTPException: 401 when credentials are missing, cannot be decoded,
            decode to an empty or too-short token, or match no valid user.
            Each failure carries a single, unwrapped detail message.
    """
    from api.db import StatusEnum

    def _unauthorized(detail: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)

    authorization = credentials.credentials if credentials is not None else None
    if not authorization:
        raise _unauthorized("Authentication failed: No authorization header")

    try:
        jwt = URLSafeTimedSerializer(secret_key=settings.SECRET_KEY)
        access_token = str(jwt.loads(authorization))

        if not access_token or not access_token.strip():
            raise _unauthorized("Authentication attempt with empty access token")

        # Access tokens should be UUIDs (32 hex characters).
        if len(access_token.strip()) < 32:
            raise _unauthorized(f"Authentication attempt with invalid token format: {len(access_token)} chars")

        user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
        if not user:
            raise _unauthorized("Authentication failed: Invalid access token")

        if not user[0].access_token or not user[0].access_token.strip():
            raise _unauthorized("Authentication attempt with empty access token")

        return user[0]
    except HTTPException:
        # Already a well-formed 401: pass it through instead of re-wrapping it.
        raise
    except Exception as e:
        raise _unauthorized(f"Authentication failed: {str(e)}") from e
# Variant of the bearer scheme that resolves to None, rather than raising
# HTTP 403, when the Authorization header is absent.
_optional_security = CustomHTTPBearer(auto_error=False)


async def get_current_user_optional(credentials: Optional[HTTPAuthorizationCredentials] = Depends(_optional_security)):
    """Resolve the current user if possible, otherwise return ``None`` (FastAPI version).

    Unlike ``get_current_user``, a request without an ``Authorization`` header
    or with credentials that fail authentication is not rejected; the
    dependency yields ``None`` instead.
    """
    if credentials is None:
        return None
    try:
        return await get_current_user(credentials)
    except HTTPException:
        return None
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the API key carried in the request credentials (FastAPI version).

    Returns:
        The first ``APIToken`` record whose token matches the credentials.

    Raises:
        HTTPException: 403 with detail "API-KEY is invalid!" when no record
            matches, or 403 with an "API Key verification failed: ..." detail
            when the lookup itself raises.
    """
    try:
        token = credentials.credentials
        objs = APIToken.query(token=token)
        if not objs:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="API-KEY is invalid!",
            )
        return objs[0]
    except HTTPException:
        # Keep the precise "invalid key" error instead of re-wrapping it.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"API Key verification failed: {str(e)}",
        ) from e
def create_file_response(data, filename: str, media_type: str = "application/octet-stream"):
    """Create a streaming download response for ``data`` (FastAPI version).

    Values that are neither ``str`` nor ``bytes`` are first serialized with
    ``json_dumps``; strings are encoded as UTF-8.

    Args:
        data: Content to send.
        filename: Download name placed in the ``Content-Disposition`` header.
        media_type: Value for the response media type.

    Returns:
        A ``StreamingResponse`` over the encoded bytes.
    """
    if not isinstance(data, (str, bytes)):
        data = json_dumps(data)
    if isinstance(data, str):
        data = data.encode("utf-8")

    # Names that survive percent-encoding unchanged are sent as before. Any
    # other name (non-ASCII, spaces, quotes, ';', control characters) is sent
    # percent-encoded via the ``filename*`` form so the header stays valid.
    name = str(filename)
    encoded_name = quote(name, safe="")
    if encoded_name == name:
        disposition = f"attachment; filename={name}"
    else:
        disposition = f"attachment; filename*=UTF-8''{encoded_name}"

    return StreamingResponse(
        BytesIO(data),
        media_type=media_type,
        headers={"Content-Disposition": disposition},
    )
def get_error_permission_result(message="Permission error"):
    """Return the result response for a permission error carrying ``message``."""
    error_code = settings.RetCode.PERMISSION_ERROR
    return get_result(code=error_code, message=message)
def get_error_operating_result(message="Operating error"):
    """Return the result response for an operating error carrying ``message``."""
    error_code = settings.RetCode.OPERATING_ERROR
    return get_result(code=error_code, message=message)
def generate_confirmation_token(tenant_id):
    """Return a "ragflow-" prefixed token derived from a fresh UUID.

    The UUID from ``get_uuid()`` is serialized with a ``URLSafeTimedSerializer``
    keyed and salted by ``tenant_id``; characters 2 to 33 of that string form
    the token body.
    """
    signer = URLSafeTimedSerializer(tenant_id)
    signed = signer.dumps(get_uuid(), salt=tenant_id)
    return "ragflow-" + signed[2:34]
def get_parser_config(chunk_method, parser_config):
    """Return the parser configuration for ``chunk_method``.

    A falsy ``chunk_method`` is treated as "naive". Each method has a default
    configuration (``None`` for several of them):

    * no ``parser_config`` given: the default is returned;
    * the method has no default: ``parser_config`` is returned as is;
    * otherwise: the default deep-merged with ``parser_config``, the latter
      taking priority.

    Raises:
        KeyError: If ``chunk_method`` is not a known method.
    """
    method = chunk_method or "naive"

    def rag_flags():
        # Fresh dicts on every call so configurations never share nested state.
        return {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}}

    # Default configuration per chunking method.
    defaults_by_method = {
        "naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", **rag_flags()},
        "qa": rag_flags(),
        "tag": None,
        "resume": None,
        "manual": rag_flags(),
        "table": None,
        "paper": rag_flags(),
        "book": rag_flags(),
        "laws": rag_flags(),
        "presentation": rag_flags(),
        "one": None,
        "knowledge_graph": {
            "chunk_token_num": 8192,
            "delimiter": r"\n",
            "entity_types": ["organization", "person", "location", "event", "time"],
            **rag_flags(),
        },
        "email": None,
        "picture": None,
    }
    defaults = defaults_by_method[method]

    if not parser_config:
        return defaults
    if defaults is None:
        return parser_config
    # Merge so the raptor/graphrag fields fall back to defaults when absent.
    return deep_merge(defaults, parser_config)
def get_data_openai(
    id=None,
    created=None,
    model=None,
    prompt_tokens=0,
    completion_tokens=0,
    content=None,
    finish_reason=None,
    object="chat.completion",
    param=None,
    stream=False
):
    """Build an OpenAI-style chat completion payload.

    With ``stream`` set, a "chat.completion.chunk" dict is returned whose
    single choice carries ``content`` as a delta. Otherwise a full completion
    dict is returned with usage counts and an assistant message. In the full
    form ``created`` is the current Unix time when the ``created`` argument is
    truthy and ``None`` otherwise.
    """
    total_tokens = prompt_tokens + completion_tokens
    response_id = f"{id}"

    if stream:
        chunk_choice = {
            "delta": {"content": content},
            "finish_reason": finish_reason,
            "index": 0,
        }
        return {
            "id": response_id,
            "object": "chat.completion.chunk",
            "model": model,
            "choices": [chunk_choice],
        }

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "completion_tokens_details": {
            "reasoning_tokens": 0,
            "accepted_prediction_tokens": 0,
            "rejected_prediction_tokens": 0,
        },
    }
    choice = {
        "message": {"role": "assistant", "content": content},
        "logprobs": None,
        "finish_reason": finish_reason,
        "index": 0,
    }
    return {
        "id": response_id,
        "object": object,
        "created": int(time.time()) if created else None,
        "model": model,
        "param": param,
        "usage": usage,
        "choices": [choice],
    }
def check_duplicate_ids(ids, id_type="item"):
    """Find duplicate IDs and return the unique IDs with error messages.

    Args:
        ids (list): IDs to check for duplicates.
        id_type (str): Kind of ID used in the messages
            (e.g. 'document', 'dataset', 'chunk').

    Returns:
        tuple: ``(unique_ids, error_messages)``
            - unique_ids (list): Each distinct ID once, in first-seen order.
            - error_messages (list): One "Duplicate <id_type> ids: <id>"
              message per ID that occurs more than once, in first-seen order.
    """
    # Dicts preserve insertion order, so the keys double as the ordered,
    # de-duplicated ID list.
    id_count = {}
    for id_value in ids:
        id_count[id_value] = id_count.get(id_value, 0) + 1

    duplicate_messages = [
        f"Duplicate {id_type} ids: {id_value}"
        for id_value, count in id_count.items()
        if count > 1
    ]
    return list(id_count), duplicate_messages
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, JSONResponse | None]:
    """Check whether an embedding model may be used by a tenant.

    The identifier is split into model name and factory with
    ``TenantLLMService.split_model_name_and_factory``. The model is then
    looked up in three places: the ``LLMService`` registry (as an embedding
    model), the tenant's own models from ``TenantLLMService.get_my_llms``, and
    ``settings.BUILTIN_EMBEDDING_MODELS``.

    Args:
        embd_id (str): Embedding model identifier.
        tenant_id (str): Tenant whose model list is consulted.

    Returns:
        tuple[bool, JSONResponse | None]:
            - ``(True, None)`` when the model is built in or belongs to the tenant.
            - ``(False, response)`` otherwise, where ``response`` is an
              argument-error response reading "Unsupported model: <id>" (found
              nowhere) or "Unauthorized model: <id>" (registered but neither
              built in nor the tenant's), or a data-error response reading
              "Database operation failed" when an ``OperationalError`` occurs.
    """
    try:
        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
        registered = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))

        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
        owned_by_tenant = any(
            llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding"
            for llm in tenant_llms
        )
        built_in = embd_id in settings.BUILTIN_EMBEDDING_MODELS

        usable = built_in or owned_by_tenant
        if not (usable or registered):
            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
        if not usable:
            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
    except OperationalError as e:
        logging.exception(e)
        return False, get_error_data_result(message="Database operation failed")
    return True, None
def deep_merge(default: dict, custom: dict) -> dict:
    """Return a new dict: a deep copy of ``default`` overlaid with ``custom``.

    Where both sides hold a dict under the same key the two are merged
    recursively; in every other case the value from ``custom`` replaces the
    one from ``default`` (including a non-dict replacing a dict). ``default``
    itself is left unmodified.

    Args:
        default (dict): Base dictionary containing default values.
        custom (dict): Dictionary containing overriding values.

    Returns:
        dict: The merged dictionary.

    Example:
        >>> deep_merge({"a": 1, "nested": {"x": 10, "y": 20}},
        ...            {"b": 2, "nested": {"y": 99, "z": 30}})
        {'a': 1, 'nested': {'x': 10, 'y': 99, 'z': 30}, 'b': 2}
        >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
        {'config': 'manual'}
    """

    def overlay(target: dict, overrides: dict) -> None:
        for name, new_value in overrides.items():
            current = target.get(name)
            both_dicts = name in target and isinstance(new_value, dict) and isinstance(current, dict)
            if both_dicts:
                overlay(current, new_value)
            else:
                target[name] = new_value

    result = deepcopy(default)
    overlay(result, custom)
    return result
def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
    """Return a copy of ``source_data`` with its keys renamed.

    Args:
        source_data: Dictionary to process.
        key_aliases: Optional ``{original_key: new_key}`` rules. When omitted
            or empty, the built-in mapping is used (chunk_num -> chunk_count,
            doc_num -> document_count, parser_id -> chunk_method,
            embd_id -> embedding_model).

    Returns:
        dict: New dictionary with renamed keys and the original values; keys
        without a rule are kept unchanged.

    Example:
        >>> remap_dictionary_keys({"old_key": "value", "another_field": 42}, {"old_key": "new_key"})
        {'new_key': 'value', 'another_field': 42}
    """
    builtin_aliases = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }
    aliases = key_aliases or builtin_aliases
    return {aliases.get(name, name): value for name, value in source_data.items()}
def group_by(list_of_dict, key):
    """Group the dicts in ``list_of_dict`` by their value under ``key``.

    Returns a dict mapping each distinct ``item[key]`` to the list of items
    holding that value, in their original order.
    """
    groups = {}
    for entry in list_of_dict:
        groups.setdefault(entry[key], []).append(entry)
    return groups
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
    """Collect the tool lists exposed by the given MCP servers.

    A tool-call session is opened for each server and asked for its tools
    with the given ``timeout``. A server whose listing fails contributes an
    empty list. Each tool is dumped to a dict and gets an ``enabled`` flag
    taken from the server's cached ``tools`` variable (default ``True``).

    Every session opened here is closed before returning, including when an
    error interrupts the collection.

    Returns:
        tuple[dict, str]: ``(results, "")`` on success, where ``results`` maps
        each server id to its list of tool dicts; ``({}, error_message)`` when
        an exception occurs.
    """
    results = {}
    tool_call_sessions = []
    try:
        try:
            for mcp_server in mcp_servers:
                server_key = mcp_server.id
                cached_tools = mcp_server.variables.get("tools", {})

                tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
                tool_call_sessions.append(tool_call_session)

                try:
                    tools = tool_call_session.get_tools(timeout)
                except Exception:
                    # Best effort: a server that cannot list tools yields none.
                    logging.warning("Failed to list tools for MCP server %s", server_key, exc_info=True)
                    tools = []

                results[server_key] = []
                for tool in tools:
                    tool_dict = tool.model_dump()
                    cached_tool = cached_tools.get(tool_dict["name"], {})
                    tool_dict["enabled"] = cached_tool.get("enabled", True)
                    results[server_key].append(tool_dict)
        finally:
            # PERF: blocking call to close sessions — consider moving to background thread or task queue
            close_multiple_mcp_toolcall_sessions(tool_call_sessions)
    except Exception as e:
        return {}, str(e)
    return results, ""
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts - 1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator
async def is_strong_enough(chat_model, embedding_model):
    """Stress-test the chat and embedding models with concurrent probes.

    Nothing happens when either model is missing or when
    ``settings.STRONG_TEST_COUNT`` is an integer that is zero or negative.
    Otherwise ``STRONG_TEST_COUNT`` probes are started in a trio nursery. Each
    probe encodes a short sentence with the embedding model (10 s limit) and
    sends a short chat (30 s limit), raising if the chat reply contains
    "**ERROR**".
    """
    count = settings.STRONG_TEST_COUNT
    if not (chat_model and embedding_model):
        return
    if isinstance(count, int) and count <= 0:
        return

    @timeout(60, 2)
    async def _probe_once():
        nonlocal chat_model, embedding_model
        if embedding_model:
            with trio.fail_after(10):
                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
        if chat_model:
            with trio.fail_after(30):
                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
            if res.find("**ERROR**") >= 0:
                raise Exception(res)

    # Pressure test for GraphRAG task
    async with trio.open_nursery() as nursery:
        for _ in range(count):
            nursery.start_soon(_probe_once)