v0.21.1-fastapi
This commit is contained in:
@@ -53,6 +53,7 @@ from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||
|
||||
|
||||
# 自定义认证方案,支持不传Bearer格式
|
||||
class CustomHTTPBearer(SecurityBase):
|
||||
def __init__(self, *, scheme_name: str = None, auto_error: bool = True):
|
||||
@@ -71,13 +72,13 @@ class CustomHTTPBearer(SecurityBase):
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# 支持Bearer格式和直接token格式
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization[7:] # 移除"Bearer "前缀
|
||||
else:
|
||||
token = authorization # 直接使用token
|
||||
|
||||
|
||||
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
# FastAPI 安全方案
|
||||
@@ -95,8 +96,8 @@ def serialize_for_json(obj):
|
||||
if hasattr(obj, '__dict__'):
|
||||
# For objects with __dict__, try to serialize their attributes
|
||||
try:
|
||||
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
|
||||
if not key.startswith('_')}
|
||||
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
|
||||
if not key.startswith('_')}
|
||||
except (AttributeError, TypeError):
|
||||
return str(obj)
|
||||
elif hasattr(obj, '__name__'):
|
||||
@@ -112,6 +113,7 @@ def serialize_for_json(obj):
|
||||
# Fallback: convert to string representation
|
||||
return str(obj)
|
||||
|
||||
|
||||
def request(**kwargs):
|
||||
sess = requests.Session()
|
||||
stream = kwargs.pop("stream", sess.stream)
|
||||
@@ -132,7 +134,8 @@ def request(**kwargs):
|
||||
settings.HTTP_APP_KEY.encode("ascii"),
|
||||
prepped.path_url.encode("ascii"),
|
||||
prepped.body if kwargs.get("json") else b"",
|
||||
urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
|
||||
urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode(
|
||||
"ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
|
||||
]
|
||||
),
|
||||
"sha1",
|
||||
@@ -154,7 +157,7 @@ def request(**kwargs):
|
||||
def get_exponential_backoff_interval(retries, full_jitter=False):
|
||||
"""Calculate the exponential backoff wait time."""
|
||||
# Will be zero if factor equals 0
|
||||
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
|
||||
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
||||
# Full jitter according to
|
||||
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||
if full_jitter:
|
||||
@@ -185,11 +188,12 @@ def server_error_response(e):
|
||||
if len(e.args) > 1:
|
||||
try:
|
||||
serialized_data = serialize_for_json(e.args[1])
|
||||
return get_json_result(code= settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
|
||||
except Exception:
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
|
||||
message="No chunk found, please upload file and parse it.")
|
||||
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||
|
||||
@@ -214,12 +218,15 @@ def validate_request(*args, **kwargs):
|
||||
废弃的装饰器:在 FastAPI 中使用 Pydantic 模型进行验证
|
||||
这个函数保留是为了向后兼容,但不会执行任何验证
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*_args, **_kwargs):
|
||||
# FastAPI 中不需要手动验证,Pydantic 会自动处理
|
||||
return func(*_args, **_kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -228,11 +235,14 @@ def not_allowed_parameters(*params):
|
||||
废弃的装饰器:在 FastAPI 中使用 Pydantic 模型进行验证
|
||||
这个函数保留是为了向后兼容,但不会执行任何验证
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
# FastAPI 中不需要手动验证,Pydantic 会自动处理
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -241,10 +251,12 @@ def active_required(f):
|
||||
废弃的装饰器:在 FastAPI 中使用依赖注入进行用户验证
|
||||
这个函数保留是为了向后兼容,但不会执行任何验证
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
# FastAPI 中使用依赖注入进行用户验证
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -281,10 +293,12 @@ def apikey_required(func):
|
||||
废弃的装饰器:在 FastAPI 中使用依赖注入进行 API Key 验证
|
||||
这个函数保留是为了向后兼容,但不会执行任何验证
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# FastAPI 中使用依赖注入进行 API Key 验证
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
@@ -301,7 +315,7 @@ def construct_response(code=settings.RetCode.SUCCESS, message="success", data=No
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
|
||||
|
||||
headers = {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Method": "*",
|
||||
@@ -310,7 +324,7 @@ def construct_response(code=settings.RetCode.SUCCESS, message="success", data=No
|
||||
}
|
||||
if auth:
|
||||
headers["Authorization"] = auth
|
||||
|
||||
|
||||
return JSONResponse(content=response_dict, headers=headers)
|
||||
|
||||
|
||||
@@ -349,10 +363,12 @@ def token_required(func):
|
||||
废弃的装饰器:在 FastAPI 中使用依赖注入进行 Token 验证
|
||||
这个函数保留是为了向后兼容,但不会执行任何验证
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# FastAPI 中使用依赖注入进行 Token 验证
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
@@ -368,8 +384,8 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
|
||||
|
||||
|
||||
def get_error_data_result(
|
||||
message="Sorry! Data missing!",
|
||||
code=settings.RetCode.DATA_ERROR,
|
||||
message="Sorry! Data missing!",
|
||||
code=settings.RetCode.DATA_ERROR,
|
||||
):
|
||||
result_dict = {"code": code, "message": message}
|
||||
response = {}
|
||||
@@ -392,24 +408,24 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s
|
||||
try:
|
||||
jwt = URLSafeTimedSerializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
@@ -474,7 +490,7 @@ def create_file_response(data, filename: str, media_type: str = "application/oct
|
||||
data = json_dumps(data)
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
|
||||
|
||||
return StreamingResponse(
|
||||
BytesIO(data),
|
||||
media_type=media_type,
|
||||
@@ -501,7 +517,8 @@ def get_parser_config(chunk_method, parser_config):
|
||||
|
||||
# Define default configurations for each chunking method
|
||||
key_mapping = {
|
||||
"naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
|
||||
"naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC",
|
||||
"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
|
||||
"qa": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
|
||||
"tag": None,
|
||||
"resume": None,
|
||||
@@ -540,16 +557,16 @@ def get_parser_config(chunk_method, parser_config):
|
||||
|
||||
|
||||
def get_data_openai(
|
||||
id=None,
|
||||
created=None,
|
||||
model=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
content=None,
|
||||
finish_reason=None,
|
||||
object="chat.completion",
|
||||
param=None,
|
||||
stream=False
|
||||
id=None,
|
||||
created=None,
|
||||
model=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
content=None,
|
||||
finish_reason=None,
|
||||
object="chat.completion",
|
||||
param=None,
|
||||
stream=False
|
||||
):
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
@@ -661,7 +678,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, J
|
||||
in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))
|
||||
|
||||
tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
|
||||
is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
|
||||
is_tenant_model = any(
|
||||
llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for
|
||||
llm in tenant_llms)
|
||||
|
||||
is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
|
||||
if not (is_builtin_model or is_tenant_model or in_llm_service):
|
||||
@@ -804,7 +823,8 @@ TimeoutException = Union[Type[BaseException], BaseException]
|
||||
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
|
||||
on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
if isinstance(seconds, str):
|
||||
seconds = float(seconds)
|
||||
def decorator(func):
|
||||
@@ -892,7 +912,8 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
|
||||
if chat_model:
|
||||
with trio.fail_after(30):
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
|
||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [
|
||||
{"role": "user", "content": "Are you strong enough!?"}], {}))
|
||||
if res.find("**ERROR**") >= 0:
|
||||
raise Exception(res)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user