Files
TERES_fastapi_backend/api/apps/user_app_fastapi.py

541 lines
17 KiB
Python
Raw Normal View History

2025-10-13 13:18:03 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re
import secrets
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr
try:
    from werkzeug.security import check_password_hash, generate_password_hash
except ImportError:
    # Fall back to passlib when werkzeug is not installed.
    # NOTE(review): hashes written by werkzeug are presumably not verifiable by
    # this bcrypt context (and vice versa) -- confirm before relying on the
    # fallback against an existing user table.
    from passlib.context import CryptContext
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    def check_password_hash(hashed, password):
        """Return the result of verifying ``password`` against ``hashed``.

        The parameter order (hash first, plaintext second) is the order used
        by the callers in this module.
        """
        return pwd_context.verify(password, hashed)
    def generate_password_hash(password):
        """Return a hash of ``password`` produced by ``pwd_context``."""
        return pwd_context.hash(password)
from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.file_service import FileService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.utils import (
current_timestamp,
datetime_format,
download_img,
get_format_time,
get_uuid,
)
from api.utils.api_utils import (
construct_response,
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
)
from api.utils.crypt import decrypt
# Router on which every endpoint of this module is registered.
router = APIRouter()
# Bearer security scheme; used as the dependency of get_current_user.
security = HTTPBearer()
# Pydantic request models
class LoginRequest(BaseModel):
    """Request body of ``POST /login``."""

    # Address the account is looked up by.
    email: EmailStr
    # Passed through ``decrypt`` by the login handler before verification.
    password: str
class RegisterRequest(BaseModel):
    """Request body of ``POST /register``."""

    # Display name; also used to build the name of the user's tenant.
    nickname: str
    # Address of the new account; the handler rejects addresses already in use.
    email: EmailStr
    # Passed through ``decrypt`` by the register handler before being stored.
    password: str
class UserSettingRequest(BaseModel):
    """Request body of ``POST /setting``; every field is optional."""

    # New display name.
    nickname: Optional[str] = None
    # Current password; when present the handler verifies it against the stored hash.
    password: Optional[str] = None
    # Replacement password; only considered by the handler when ``password`` is present.
    new_password: Optional[str] = None
class TenantInfoRequest(BaseModel):
    """Request body of ``POST /set_tenant_info``.

    ``tenant_id`` selects the tenant row; the remaining fields are written to
    that row as-is by the handler.
    """

    # Identifier of the tenant to update.
    tenant_id: str
    # The four model fields below use the same keys that ``user_register``
    # fills from settings.ASR_MDL, EMBEDDING_MDL, IMAGE2TEXT_MDL and CHAT_MDL.
    asr_id: str
    embd_id: str
    img2txt_id: str
    llm_id: str
# Dependency: resolve the authenticated user of the current request.
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the bearer credentials of the current request to a user record.

    The credential string is decoded with itsdangerous'
    ``URLSafeTimedSerializer`` keyed by ``settings.SECRET_KEY`` (``loads`` is
    called without ``max_age``, so no age limit is applied here). The decoded
    value is treated as the user's ``access_token`` and looked up among users
    whose status is ``StatusEnum.VALID``.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The first user record matching the decoded access token.

    Raises:
        HTTPException: 401 when the credentials are missing, cannot be
            decoded, are malformed, or match no valid user. The specific
            reason is written to the log; the client always receives the
            same generic detail. 500 when ``itsdangerous`` is not installed.
    """
    from api.db import StatusEnum

    try:
        from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
    except ImportError as exc:
        # There is no drop-in replacement: the credential is decoded through
        # the serializer's ``loads`` API, so fail loudly instead of crashing
        # later on an incompatible substitute.
        logging.error("itsdangerous is required to decode access tokens: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Token serializer is not available"
        ) from exc

    serializer = Serializer(secret_key=settings.SECRET_KEY)
    authorization = credentials.credentials
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header required"
        )

    def _reject(reason):
        # Keep the precise reason server-side only; the client gets a generic 401.
        logging.warning("load_user rejected token: %s", reason)
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid access token"
        )

    try:
        access_token = str(serializer.loads(authorization))
        if not access_token or not access_token.strip():
            raise _reject("authentication attempt with empty access token")
        # Access tokens should be UUIDs (32 hex characters)
        if len(access_token.strip()) < 32:
            raise _reject(
                f"authentication attempt with invalid token format: {len(access_token)} chars"
            )
        users = UserService.query(
            access_token=access_token, status=StatusEnum.VALID.value
        )
        if not users:
            raise _reject("no valid user matches the access token")
        user = users[0]
        if not user.access_token or not user.access_token.strip():
            raise _reject(f"user {user.email} has empty access_token in database")
        return user
    except HTTPException:
        # Already logged by _reject; do not log it a second time as a failure.
        raise
    except Exception as e:
        logging.warning("load_user got exception %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid access token"
        ) from e
@router.post("/login")
async def login(request: LoginRequest):
    """Authenticate a user by email and password.

    Args:
        request: Login payload; ``password`` is passed through ``decrypt``
            before verification.

    Returns:
        The result of ``construct_response`` carrying the user's JSON
        representation, with ``auth`` set to ``user.get_id()``.

    Raises:
        HTTPException: 401 when the email is unknown or the password does not
            match, 403 when the account is disabled, 500 when the submitted
            password cannot be decrypted.
    """
    email = request.email
    if not UserService.query(email=email):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Email: {email} is not registered!"
        )

    try:
        password = decrypt(request.password)
    except Exception:
        # Exception rather than BaseException, so task cancellation and
        # interpreter shutdown are not turned into an HTTP error.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Fail to crypt password"
        )

    user = UserService.query_user(email, password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Email and password do not match!"
        )
    if hasattr(user, 'is_active') and user.is_active == "0":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="This account has been disabled, please contact the administrator!"
        )

    # The payload is serialized before the token is rotated; the value of
    # user.get_id() computed after the save is what is passed as ``auth``.
    response_data = user.to_json()
    user.access_token = get_uuid()
    # Plain scalars: a trailing comma here would assign 1-tuples to the fields.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Welcome back!"
    return construct_response(data=response_data, auth=user.get_id(), message=msg)
@router.get("/login/channels")
async def get_login_channels():
    """List every authentication channel configured in ``settings.OAUTH_CONFIG``.

    Returns:
        The result of ``get_json_result`` wrapping a list of dicts with the
        keys ``channel``, ``display_name`` and ``icon``. ``display_name``
        defaults to the title-cased channel key and ``icon`` to ``"sso"``.

    Raises:
        HTTPException: 500 when building the list fails.
    """
    try:
        available = [
            {
                "channel": name,
                "display_name": cfg.get("display_name", name.title()),
                "icon": cfg.get("icon", "sso"),
            }
            for name, cfg in settings.OAUTH_CONFIG.items()
        ]
        return get_json_result(data=available)
    except Exception as err:
        logging.exception(err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Load channels failure, error: {str(err)}"
        )
@router.get("/login/{channel}")
async def oauth_login(channel: str, request: Request):
    """Redirect the browser to the authorization URL of an OAuth channel.

    Args:
        channel: Key into ``settings.OAUTH_CONFIG``.
        request: Incoming request (not used by the current implementation).

    Returns:
        A ``RedirectResponse`` to the URL produced by the channel's auth client.

    Raises:
        HTTPException: 400 when ``channel`` is not configured.
    """
    provider_config = settings.OAUTH_CONFIG.get(channel)
    if provider_config:
        client = get_auth_client(provider_config)
        # NOTE(review): the state value is generated but not persisted
        # anywhere (no session is used here), so the callback cannot compare
        # it against the value it receives.
        oauth_state = get_uuid()
        target = client.get_authorization_url(oauth_state)
        return RedirectResponse(url=target)

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Invalid channel name: {channel}"
    )
@router.get("/oauth/callback/{channel}")
async def oauth_callback(channel: str, request: Request):
    """Handle the OAuth/OIDC callback of a login channel.

    Exchanges the ``code`` query parameter for a token, fetches the user's
    profile, then either registers a new account for the profile's email or
    rotates the access token of the existing one.

    Args:
        channel: Key into ``settings.OAUTH_CONFIG``.
        request: Incoming request; ``state`` and ``code`` are read from its
            query parameters.

    Returns:
        A ``RedirectResponse`` to ``/?auth=<user.get_id()>`` on success, or to
        ``/?error=<reason>`` on any failure. This handler never raises: every
        exception, including the one for an unknown channel, is converted to
        an error redirect.
    """
    from urllib.parse import quote

    def _error_redirect(reason):
        # Percent-encode the text so characters such as '&', '#' or '='
        # inside an exception message cannot split or extend the query string.
        encoded = quote(str(reason), safe="")
        return RedirectResponse(url=f"/?error={encoded}")

    try:
        channel_config = settings.OAUTH_CONFIG.get(channel)
        if not channel_config:
            # Caught by the handler at the bottom and turned into a redirect.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid channel name: {channel}"
            )
        auth_cli = get_auth_client(channel_config)

        # Check the state parameter.
        # NOTE(review): only its presence is checked; it is not compared with
        # a stored value.
        state = request.query_params.get("state")
        if not state:
            return RedirectResponse(url="/?error=invalid_state")

        # Authorization code issued by the provider.
        code = request.query_params.get("code")
        if not code:
            return RedirectResponse(url="/?error=missing_code")

        # Exchange the authorization code for an access token.
        token_info = auth_cli.exchange_code_for_token(code)
        access_token = token_info.get("access_token")
        if not access_token:
            return RedirectResponse(url="/?error=token_failed")
        id_token = token_info.get("id_token")

        # Fetch the user's profile from the provider.
        user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
        if not user_info.email:
            return RedirectResponse(url="/?error=email_missing")

        # Log in or register.
        users = UserService.query(email=user_info.email)
        if not users:
            # The id is only needed on the registration path.
            user_id = get_uuid()
            try:
                try:
                    avatar = download_img(user_info.avatar_url)
                except Exception as e:
                    # A missing avatar must not block registration.
                    logging.exception(e)
                    avatar = ""
                users = user_register(
                    user_id,
                    {
                        "access_token": get_uuid(),
                        "email": user_info.email,
                        "avatar": avatar,
                        "nickname": user_info.nickname,
                        "login_channel": channel,
                        "last_login_time": get_format_time(),
                        "is_superuser": False,
                    },
                )
                if not users:
                    raise Exception(f"Failed to register {user_info.email}")
                if len(users) > 1:
                    raise Exception(f"Same email: {user_info.email} exists!")
                user = users[0]
                return RedirectResponse(url=f"/?auth={user.get_id()}")
            except Exception as e:
                rollback_user_registration(user_id)
                logging.exception(e)
                return _error_redirect(e)

        # The user exists: refuse disabled accounts before touching the token.
        user = users[0]
        if hasattr(user, 'is_active') and user.is_active == "0":
            return RedirectResponse(url="/?error=user_inactive")
        user.access_token = get_uuid()
        user.save()
        return RedirectResponse(url=f"/?auth={user.get_id()}")
    except Exception as e:
        logging.exception(e)
        return _error_redirect(e)
@router.get("/logout")
async def log_out(current_user = Depends(get_current_user)):
    """Log the current user out by overwriting the stored access token.

    The token is replaced with ``INVALID_`` followed by 32 random hex
    characters and the user record is saved.

    Returns:
        The result of ``get_json_result(data=True)``.
    """
    revoked_marker = "INVALID_" + secrets.token_hex(16)
    current_user.access_token = revoked_marker
    current_user.save()
    return get_json_result(data=True)
@router.post("/setting")
async def setting_user(request: UserSettingRequest, current_user = Depends(get_current_user)):
    """Update the current user's settings.

    When ``password`` is supplied it is verified against the stored hash, and
    ``new_password`` (if given) replaces it. Other supplied fields are copied
    to the update, except for a fixed set of protected fields. Fields that
    were not supplied (value ``None``) are left untouched.

    Args:
        request: Settings payload; every field is optional.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        The result of ``get_json_result(data=True)``.

    Raises:
        HTTPException: 401 when the current password is wrong, 500 when the
            update fails.
    """
    # Fields that must never be changed through this endpoint.
    protected_fields = frozenset((
        "password",
        "new_password",
        "email",
        "status",
        "is_superuser",
        "login_channel",
        "is_anonymous",
        "is_active",
        "is_authenticated",
        "last_login_time",
    ))
    request_data = request.dict()
    update_dict = {}

    if request_data.get("password"):
        if not check_password_hash(current_user.password, decrypt(request_data["password"])):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Password error!"
            )
        new_password = request_data.get("new_password")
        if new_password:
            update_dict["password"] = generate_password_hash(decrypt(new_password))

    for key, value in request_data.items():
        # request.dict() lists every model field, with None for those the
        # client omitted; copying them would overwrite stored values (e.g.
        # the nickname on a password-only change) with None.
        if key in protected_fields or value is None:
            continue
        update_dict[key] = value

    if not update_dict:
        # Nothing to change; skip issuing an empty update.
        return get_json_result(data=True)

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Update failure!"
        )
@router.get("/info")
async def user_profile(current_user = Depends(get_current_user)):
    """Return the profile of the authenticated user.

    Returns:
        The result of ``get_json_result`` wrapping ``current_user.to_dict()``.
    """
    profile = current_user.to_dict()
    return get_json_result(data=profile)
def rollback_user_registration(user_id):
    """Best-effort removal of the records created for a failed registration.

    Deletes, in order, the user row, the tenant row, the first user-tenant
    link found for the tenant, and the tenant's LLM rows, all keyed by
    ``user_id``. Each step runs independently: a failing step is logged and
    the remaining steps still run.

    Args:
        user_id: Identifier shared by the user and its tenant.

    Returns:
        None. This function never raises.
    """
    def _delete_user_tenant_link():
        links = UserTenantService.query(tenant_id=user_id)
        if links:
            UserTenantService.delete_by_id(links[0].id)

    cleanup_steps = (
        ("user", lambda: UserService.delete_by_id(user_id)),
        ("tenant", lambda: TenantService.delete_by_id(user_id)),
        ("user-tenant link", _delete_user_tenant_link),
        (
            "tenant LLM rows",
            lambda: TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute(),
        ),
    )
    for label, step in cleanup_steps:
        try:
            step()
        except Exception:
            # Keep going, but leave a trace: a silent failure here leaves
            # orphaned rows that are hard to diagnose later.
            logging.exception("Rollback of %s failed for user_id=%s", label, user_id)
def user_register(user_id, user):
    """Create a user together with its tenant, membership, LLM rows and root folder.

    ``user`` is modified in place: its ``"id"`` key is set to ``user_id``.
    The tenant shares the user's id and takes its default model ids from
    ``settings``.

    Args:
        user_id: Identifier used for both the user and the tenant.
        user: Field values of the user; must contain ``"nickname"`` and
            ``"email"``.

    Returns:
        The result of ``UserService.query(email=...)`` after all inserts, or
        ``None`` when saving the user reports failure (nothing else is
        inserted in that case).
    """
    user["id"] = user_id

    tenant_record = dict(
        id=user_id,
        name=user["nickname"] + "'s Kingdom",
        llm_id=settings.CHAT_MDL,
        embd_id=settings.EMBEDDING_MDL,
        asr_id=settings.ASR_MDL,
        parser_ids=settings.PARSERS,
        img2txt_id=settings.IMAGE2TEXT_MDL,
        rerank_id=settings.RERANK_MDL,
    )
    # The user owns, and is recorded as the inviter of, their own tenant.
    membership_record = dict(
        tenant_id=user_id,
        user_id=user_id,
        invited_by=user_id,
        role=UserTenantRole.OWNER,
    )
    # Root folder "/" whose parent is itself.
    root_folder_id = get_uuid()
    root_folder = dict(
        id=root_folder_id,
        parent_id=root_folder_id,
        tenant_id=user_id,
        created_by=user_id,
        name="/",
        type=FileType.FOLDER.value,
        size=0,
        location="",
    )
    initial_llms = get_init_tenant_llm(user_id)

    saved = UserService.save(**user)
    if not saved:
        return None

    TenantService.insert(**tenant_record)
    UserTenantService.insert(**membership_record)
    TenantLLMService.insert_many(initial_llms)
    FileService.insert(root_folder)
    return UserService.query(email=user["email"])
@router.post("/register")
async def user_add(request: RegisterRequest):
    """Register a new password-based user.

    Args:
        request: Registration payload; ``password`` is passed through
            ``decrypt`` before being handed to ``user_register``.

    Returns:
        The result of ``construct_response`` carrying the new user's JSON
        representation, with ``auth`` set to ``user.get_id()``.

    Raises:
        HTTPException: 403 when registration is disabled, 400 when the email
            is malformed or already in use, 500 when registration fails (the
            partially created records are rolled back first).
    """
    if not settings.REGISTER_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User registration is disabled!"
        )

    email_address = request.email

    # Validate the shape of the email address.
    looks_valid = re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address)
    if not looks_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid email address: {email_address}!"
        )

    # Reject addresses that already belong to an account.
    existing = UserService.query(email=email_address)
    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Email: {email_address} has already registered!"
        )

    # Assemble the field values of the new user.
    nickname = request.nickname
    profile = dict(
        access_token=get_uuid(),
        email=email_address,
        nickname=nickname,
        password=decrypt(request.password),
        login_channel="password",
        last_login_time=get_format_time(),
        is_superuser=False,
    )
    new_user_id = get_uuid()

    try:
        registered = user_register(new_user_id, profile)
        if not registered:
            raise Exception(f"Fail to register {email_address}.")
        if len(registered) > 1:
            raise Exception(f"Same email: {email_address} exists!")
        account = registered[0]
        return construct_response(
            data=account.to_json(),
            auth=account.get_id(),
            message=f"{nickname}, welcome aboard!",
        )
    except Exception as err:
        # Undo whatever was inserted before the failure.
        rollback_user_registration(new_user_id)
        logging.exception(err)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"User registration failure, error: {str(err)}"
        )
@router.get("/tenant_info")
async def tenant_info(current_user = Depends(get_current_user)):
    """Return the tenant information of the authenticated user.

    Args:
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        The result of ``get_json_result`` wrapping the first entry returned
        by ``TenantService.get_info_by(current_user.id)``.

    Raises:
        HTTPException: 404 when no tenant is found, 500 on any other failure.
    """
    try:
        tenants = TenantService.get_info_by(current_user.id)
        if not tenants:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Tenant not found!"
            )
        return get_json_result(data=tenants[0])
    except HTTPException:
        # Let the 404 above through; the generic handler below would
        # otherwise report it to the client as a 500.
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        ) from e
@router.post("/set_tenant_info")
async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get_current_user)):
    """Update the model settings of a tenant.

    Every field of the payload except ``tenant_id`` is written to the tenant
    row identified by ``tenant_id``.

    NOTE(review): ``tenant_id`` comes from the request body and is not checked
    against ``current_user`` in this handler -- confirm that authorization is
    enforced elsewhere.

    Args:
        request: Payload naming the tenant and its new model ids.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        The result of ``get_json_result(data=True)``.

    Raises:
        HTTPException: 500 when the update fails.
    """
    try:
        fields = request.dict()
        target_tenant_id = fields.pop("tenant_id")
        TenantService.update_by_id(target_tenant_id, fields)
        return get_json_result(data=True)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(err)
        )