# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import logging import re import secrets from datetime import datetime from typing import Optional, Dict, Any from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from api.apps.models.auth_dependencies import get_current_user from fastapi.responses import RedirectResponse from pydantic import BaseModel, EmailStr try: from werkzeug.security import check_password_hash, generate_password_hash except ImportError: # 如果没有werkzeug,使用passlib作为替代 from passlib.context import CryptContext pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def check_password_hash(hashed, password): return pwd_context.verify(password, hashed) def generate_password_hash(password): return pwd_context.hash(password) from api import settings from api.apps.auth import get_auth_client from api.db import FileType, UserTenantRole from api.db.db_models import TenantLLM from api.db.services.file_service import FileService from api.db.services.llm_service import get_init_tenant_llm from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserService, UserTenantService from api.utils import ( current_timestamp, datetime_format, download_img, get_format_time, get_uuid, ) from api.utils.api_utils import ( construct_response, get_data_error_result, get_json_result, server_error_response, validate_request, ) from api.utils.crypt import decrypt # 创建路由器 router = APIRouter() # 安全方案 # Pydantic模型 class LoginRequest(BaseModel): email: EmailStr password: str class RegisterRequest(BaseModel): nickname: str email: EmailStr password: str class UserSettingRequest(BaseModel): nickname: Optional[str] = None password: Optional[str] = None new_password: Optional[str] = None class TenantInfoRequest(BaseModel): tenant_id: str asr_id: str embd_id: str img2txt_id: str llm_id: str # 依赖项:获取当前用户 - 从 auth_dependencies 导入 @router.post("/login") async def login(request: LoginRequest): """ 用户登录端点 """ email = request.email users = UserService.query(email=email) if not users: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Email: {email} is not registered!" ) password = request.password try: password = decrypt(password) except BaseException: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Fail to crypt password" ) user = UserService.query_user(email, password) if user and hasattr(user, 'is_active') and user.is_active == "0": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="This account has been disabled, please contact the administrator!" ) elif user: response_data = user.to_json() user.access_token = get_uuid() user.update_time = (current_timestamp(),) user.update_date = (datetime_format(datetime.now()),) user.save() msg = "Welcome back!" return construct_response(data=response_data, auth=user.get_id(), message=msg) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Email and password do not match!" ) @router.get("/login/channels") async def get_login_channels(): """ 获取所有支持的身份验证渠道 """ try: channels = [] for channel, config in settings.OAUTH_CONFIG.items(): channels.append( { "channel": channel, "display_name": config.get("display_name", channel.title()), "icon": config.get("icon", "sso"), } ) return get_json_result(data=channels) except Exception as e: logging.exception(e) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Load channels failure, error: {str(e)}" ) @router.get("/login/{channel}") async def oauth_login(channel: str, request: Request): """OAuth登录""" channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid channel name: {channel}" ) auth_cli = get_auth_client(channel_config) state = get_uuid() # 在FastAPI中,我们需要使用session来存储state # 这里简化处理,实际应该使用FastAPI的session管理 auth_url = auth_cli.get_authorization_url(state) return RedirectResponse(url=auth_url) @router.get("/oauth/callback/{channel}") async def oauth_callback(channel: str, request: Request): """ 处理各种渠道的OAuth/OIDC回调 """ try: channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid channel name: {channel}" ) auth_cli = get_auth_client(channel_config) # 检查state state = request.query_params.get("state") # 在实际应用中,应该从session中获取state进行比较 if not state: return RedirectResponse(url="/?error=invalid_state") # 获取授权码 code = request.query_params.get("code") if not code: return RedirectResponse(url="/?error=missing_code") # 交换授权码获取访问令牌 token_info = auth_cli.exchange_code_for_token(code) access_token = token_info.get("access_token") if not access_token: return RedirectResponse(url="/?error=token_failed") id_token = token_info.get("id_token") # 获取用户信息 user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) if not user_info.email: return RedirectResponse(url="/?error=email_missing") # 登录或注册 users = UserService.query(email=user_info.email) user_id = get_uuid() if not users: try: try: avatar = download_img(user_info.avatar_url) except Exception as e: logging.exception(e) avatar = "" users = user_register( user_id, { "access_token": get_uuid(), "email": user_info.email, "avatar": avatar, "nickname": user_info.nickname, "login_channel": channel, "last_login_time": get_format_time(), "is_superuser": False, }, ) if not users: raise Exception(f"Failed to register {user_info.email}") if len(users) > 1: raise Exception(f"Same email: {user_info.email} exists!") # 尝试登录 user = users[0] return RedirectResponse(url=f"/?auth={user.get_id()}") except Exception as e: rollback_user_registration(user_id) logging.exception(e) return RedirectResponse(url=f"/?error={str(e)}") # 用户存在,尝试登录 user = users[0] user.access_token = get_uuid() if user and hasattr(user, 'is_active') and user.is_active == "0": return RedirectResponse(url="/?error=user_inactive") user.save() return RedirectResponse(url=f"/?auth={user.get_id()}") except Exception as e: logging.exception(e) return RedirectResponse(url=f"/?error={str(e)}") @router.get("/logout") async def log_out(current_user = Depends(get_current_user)): """ 用户登出端点 """ current_user.access_token = f"INVALID_{secrets.token_hex(16)}" current_user.save() return get_json_result(data=True) @router.post("/setting") async def setting_user(request: UserSettingRequest, current_user = Depends(get_current_user)): """ 更新用户设置 """ update_dict = {} request_data = request.dict() if request_data.get("password"): new_password = request_data.get("new_password") if not check_password_hash(current_user.password, decrypt(request_data["password"])): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Password error!" ) if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password)) for k in request_data.keys(): if k in [ "password", "new_password", "email", "status", "is_superuser", "login_channel", "is_anonymous", "is_active", "is_authenticated", "last_login_time", ]: continue update_dict[k] = request_data[k] try: UserService.update_by_id(current_user.id, update_dict) return get_json_result(data=True) except Exception as e: logging.exception(e) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Update failure!" ) @router.get("/info") async def user_profile(current_user = Depends(get_current_user)): """ 获取用户配置文件信息 """ return get_json_result(data=current_user.to_dict()) def rollback_user_registration(user_id): """回滚用户注册""" try: UserService.delete_by_id(user_id) except Exception: pass try: TenantService.delete_by_id(user_id) except Exception: pass try: u = UserTenantService.query(tenant_id=user_id) if u: UserTenantService.delete_by_id(u[0].id) except Exception: pass try: TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute() except Exception: pass def user_register(user_id, user): """用户注册""" user["id"] = user_id tenant = { "id": user_id, "name": user["nickname"] + "'s Kingdom", "llm_id": settings.CHAT_MDL, "embd_id": settings.EMBEDDING_MDL, "asr_id": settings.ASR_MDL, "parser_ids": settings.PARSERS, "img2txt_id": settings.IMAGE2TEXT_MDL, "rerank_id": settings.RERANK_MDL, } usr_tenant = { "tenant_id": user_id, "user_id": user_id, "invited_by": user_id, "role": UserTenantRole.OWNER, } file_id = get_uuid() file = { "id": file_id, "parent_id": file_id, "tenant_id": user_id, "created_by": user_id, "name": "/", "type": FileType.FOLDER.value, "size": 0, "location": "", } tenant_llm = get_init_tenant_llm(user_id) if not UserService.save(**user): return TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) FileService.insert(file) return UserService.query(email=user["email"]) @router.post("/register") async def user_add(request: RegisterRequest): """ 注册新用户 """ if not settings.REGISTER_ENABLED: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="User registration is disabled!" ) email_address = request.email # 验证邮箱地址 if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid email address: {email_address}!" ) # 检查邮箱地址是否已被使用 if UserService.query(email=email_address): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Email: {email_address} has already registered!" ) # 构建用户信息数据 nickname = request.nickname user_dict = { "access_token": get_uuid(), "email": email_address, "nickname": nickname, "password": decrypt(request.password), "login_channel": "password", "last_login_time": get_format_time(), "is_superuser": False, } user_id = get_uuid() try: users = user_register(user_id, user_dict) if not users: raise Exception(f"Fail to register {email_address}.") if len(users) > 1: raise Exception(f"Same email: {email_address} exists!") user = users[0] return construct_response( data=user.to_json(), auth=user.get_id(), message=f"{nickname}, welcome aboard!", ) except Exception as e: rollback_user_registration(user_id) logging.exception(e) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"User registration failure, error: {str(e)}" ) @router.get("/tenant_info") async def tenant_info(current_user = Depends(get_current_user)): """ 获取租户信息 """ try: tenants = TenantService.get_info_by(current_user.id) if not tenants: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found!" ) return get_json_result(data=tenants[0]) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) @router.post("/set_tenant_info") async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get_current_user)): """ 更新租户信息 """ try: req_dict = request.dict() tid = req_dict.pop("tenant_id") TenantService.update_by_id(tid, req_dict) return get_json_result(data=True) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) )