#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""User account endpoints (FastAPI).

Covers password login/logout, OAuth/OIDC login (generic channels plus the
legacy GitHub and Feishu callbacks), registration, user settings and profile,
tenant model settings, and password reset via image captcha + emailed OTP.
"""
import hmac
import json
import logging
import os
import re
import secrets
import string
import time
from datetime import datetime
from typing import Optional, Dict, Any
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr

from api.apps.models.auth_dependencies import get_current_user

try:
    from werkzeug.security import check_password_hash, generate_password_hash
except ImportError:
    # werkzeug is not installed: fall back to passlib (bcrypt), exposing the
    # same call signatures as werkzeug's helpers.
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

    def check_password_hash(hashed, password):
        return pwd_context.verify(password, hashed)

    def generate_password_hash(password):
        return pwd_context.hash(password)

from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.file_service import FileService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.utils import (
    current_timestamp,
    datetime_format,
    download_img,
    get_format_time,
    get_uuid,
)
from api.utils.api_utils import (
    construct_response,
    get_data_error_result,
    get_json_result,
    server_error_response,
    validate_request,
)
from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.utils.web_utils import (
    send_email_html,
    OTP_LENGTH,
    OTP_TTL_SECONDS,
    ATTEMPT_LIMIT,
    ATTEMPT_LOCK_SECONDS,
    RESEND_COOLDOWN_SECONDS,
    otp_keys,
    hash_code,
    captcha_key,
)

router = APIRouter()

# Seconds an OAuth ``state`` issued by oauth_login stays valid in Redis; the
# user must complete the provider round-trip within this window.
OAUTH_STATE_TTL_SECONDS = 600

# Timeout (seconds) for outbound HTTP calls to OAuth providers so a slow or
# unreachable provider cannot hang a request indefinitely.
OAUTH_HTTP_TIMEOUT_SECONDS = 30

# User columns that POST /setting must never write directly.
_PROTECTED_SETTING_FIELDS = frozenset(
    {
        "password",
        "new_password",
        "email",
        "status",
        "is_superuser",
        "login_channel",
        "is_anonymous",
        "is_active",
        "is_authenticated",
        "last_login_time",
    }
)


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class LoginRequest(BaseModel):
    """Body of ``POST /login``; ``password`` is encrypted by the client."""

    email: EmailStr
    password: str


class RegisterRequest(BaseModel):
    """Body of ``POST /register``; ``password`` is encrypted by the client."""

    nickname: str
    email: EmailStr
    password: str


class UserSettingRequest(BaseModel):
    """Body of ``POST /setting``; every field is optional."""

    language: Optional[str] = None
    # Current password (client-encrypted); required to set a new one.
    password: Optional[str] = None
    # Replacement password (client-encrypted).
    new_password: Optional[str] = None


class TenantInfoRequest(BaseModel):
    """Body of ``POST /set_tenant_info``: the tenant and its default model IDs."""

    tenant_id: str
    asr_id: str
    embd_id: str
    img2txt_id: str
    llm_id: str


class ForgetOtpRequest(BaseModel):
    """Body of ``POST /forget/otp``."""

    email: str
    captcha: str


class ForgetPasswordRequest(BaseModel):
    """Body of ``POST /forget``."""

    email: str
    otp: str
    new_password: str
    confirm_new_password: str


# ---------------------------------------------------------------------------
# Small helpers
# ---------------------------------------------------------------------------
def _oauth_state_key(state):
    """Return the Redis key that records which channel *state* was issued for."""
    return f"oauth_state:{state}"


def _error_redirect(message):
    """Redirect to ``/`` with *message* URL-encoded into the ``error`` query parameter."""
    return RedirectResponse(url=f"/?error={quote(str(message), safe='')}")


def _is_disabled(user):
    """Return True when *user* has an ``is_active`` attribute equal to ``"0"``."""
    return hasattr(user, "is_active") and user.is_active == "0"


# ---------------------------------------------------------------------------
# Password login
# ---------------------------------------------------------------------------
@router.post("/login")
async def login(request: LoginRequest):
    """Log a user in with email and (client-encrypted) password.

    On success a fresh access token is assigned and saved, and the response is
    built by ``construct_response`` with ``auth=user.get_id()``.

    Raises:
        HTTPException 401: the email is unknown or the password does not match.
        HTTPException 403: the account is disabled.
        HTTPException 500: the password could not be decrypted.
    """
    email = request.email
    if not UserService.query(email=email):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Email: {email} is not registered!",
        )

    try:
        password = decrypt(request.password)
    except Exception as e:
        # Exception, not BaseException: let KeyboardInterrupt/SystemExit propagate.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Fail to crypt password",
        ) from e

    user = UserService.query_user(email, password)
    if user and _is_disabled(user):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="This account has been disabled, please contact the administrator!",
        )
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Email and password do not match!",
        )

    response_data = user.to_json()
    user.access_token = get_uuid()
    # Plain scalars (the previous code assigned 1-tuples by accident).
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Welcome back!"
    return construct_response(data=response_data, auth=user.get_id(), message=msg)


# ---------------------------------------------------------------------------
# Generic OAuth / OIDC login
# ---------------------------------------------------------------------------
@router.get("/login/channels")
async def get_login_channels():
    """List configured authentication channels with display name and icon.

    Missing ``display_name`` defaults to the title-cased channel name; a
    missing ``icon`` defaults to ``"sso"``.
    """
    try:
        channels = [
            {
                "channel": channel,
                "display_name": config.get("display_name", channel.title()),
                "icon": config.get("icon", "sso"),
            }
            for channel, config in settings.OAUTH_CONFIG.items()
        ]
        return get_json_result(data=channels)
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Load channels failure, error: {str(e)}",
        )


@router.get("/login/{channel}")
async def oauth_login(channel: str, request: Request):
    """Start an OAuth/OIDC login by redirecting to *channel*'s authorization URL.

    A random ``state`` is stored in Redis (value = channel, TTL
    ``OAUTH_STATE_TTL_SECONDS``) so the callback can reject forged requests.

    Raises:
        HTTPException 400: *channel* is not configured.
    """
    channel_config = settings.OAUTH_CONFIG.get(channel)
    if not channel_config:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid channel name: {channel}",
        )
    auth_cli = get_auth_client(channel_config)

    state = get_uuid()
    # Redis instead of a server-side session: shared by all workers and
    # expires on its own if the user abandons the flow.
    REDIS_CONN.set(_oauth_state_key(state), channel, OAUTH_STATE_TTL_SECONDS)

    auth_url = auth_cli.get_authorization_url(state)
    return RedirectResponse(url=auth_url)


@router.get("/oauth/callback/{channel}")
async def oauth_callback(channel: str, request: Request):
    """Handle the OAuth/OIDC redirect back from provider *channel*.

    Verifies the ``state`` issued by ``oauth_login``, exchanges ``code`` for
    tokens, fetches the user profile, then logs in the existing user or
    registers a new one. Every outcome is a redirect to ``/`` carrying either
    ``?auth=<id>`` or ``?error=<reason>``.
    """
    try:
        channel_config = settings.OAUTH_CONFIG.get(channel)
        if not channel_config:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid channel name: {channel}",
            )
        auth_cli = get_auth_client(channel_config)

        # CSRF protection: accept only a state we issued for this channel.
        state = request.query_params.get("state")
        if not state:
            return RedirectResponse(url="/?error=invalid_state")
        state_key = _oauth_state_key(state)
        expected_channel = REDIS_CONN.get(state_key)
        # Consume immediately so a state can never be replayed.
        REDIS_CONN.delete(state_key)
        if expected_channel != channel:
            return RedirectResponse(url="/?error=invalid_state")

        code = request.query_params.get("code")
        if not code:
            return RedirectResponse(url="/?error=missing_code")

        # Exchange the authorization code for tokens.
        token_info = auth_cli.exchange_code_for_token(code)
        access_token = token_info.get("access_token")
        if not access_token:
            return RedirectResponse(url="/?error=token_failed")
        id_token = token_info.get("id_token")

        user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
        if not user_info.email:
            return RedirectResponse(url="/?error=email_missing")

        users = UserService.query(email=user_info.email)
        user_id = get_uuid()
        if not users:
            # First login through this channel: register a new account.
            try:
                try:
                    avatar = download_img(user_info.avatar_url)
                except Exception as e:
                    logging.exception(e)
                    avatar = ""
                users = user_register(
                    user_id,
                    {
                        "access_token": get_uuid(),
                        "email": user_info.email,
                        "avatar": avatar,
                        "nickname": user_info.nickname,
                        "login_channel": channel,
                        "last_login_time": get_format_time(),
                        "is_superuser": False,
                    },
                )
                if not users:
                    raise Exception(f"Failed to register {user_info.email}")
                if len(users) > 1:
                    raise Exception(f"Same email: {user_info.email} exists!")

                user = users[0]
                return RedirectResponse(url=f"/?auth={user.get_id()}")
            except Exception as e:
                rollback_user_registration(user_id)
                logging.exception(e)
                return _error_redirect(e)

        # Existing user: refuse disabled accounts, otherwise rotate the token.
        user = users[0]
        if _is_disabled(user):
            return RedirectResponse(url="/?error=user_inactive")
        user.access_token = get_uuid()
        user.save()
        return RedirectResponse(url=f"/?auth={user.get_id()}")
    except Exception as e:
        logging.exception(e)
        return _error_redirect(e)


# ---------------------------------------------------------------------------
# Session / profile
# ---------------------------------------------------------------------------
@router.get("/logout")
async def log_out(current_user=Depends(get_current_user)):
    """Log out by overwriting the user's access token with an unusable random value."""
    current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
    current_user.save()
    return get_json_result(data=True)


@router.post("/setting")
async def setting_user(request: UserSettingRequest, current_user=Depends(get_current_user)):
    """Update the current user's settings.

    Only fields present in the request body are written; omitted fields keep
    their stored values. To change the password, send the current
    ``password`` and a ``new_password`` (both client-encrypted).

    Raises:
        HTTPException 401: ``password`` does not match the stored hash.
        HTTPException 500: the database update failed.
    """
    # exclude_unset: omitted fields must not overwrite stored values with None.
    request_data = request.dict(exclude_unset=True)
    update_dict = {}

    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(current_user.password, decrypt(request_data["password"])):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Password error!",
            )
        if new_password:
            update_dict["password"] = generate_password_hash(decrypt(new_password))

    for key, value in request_data.items():
        if key not in _PROTECTED_SETTING_FIELDS:
            update_dict[key] = value

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Update failure!",
        )


@router.get("/info")
async def user_profile(current_user=Depends(get_current_user)):
    """Return the current user's profile as ``current_user.to_dict()``."""
    return get_json_result(data=current_user.to_dict())


# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
def rollback_user_registration(user_id):
    """Best-effort removal of the records ``user_register`` creates for *user_id*.

    Deletes the user, the tenant with the same id, all user-tenant links for
    that tenant, and the tenant's LLM rows. Each step runs independently;
    failures are logged, not raised, so cleanup proceeds as far as possible.
    """
    try:
        UserService.delete_by_id(user_id)
    except Exception:
        logging.exception("Rollback: failed to delete user %s", user_id)
    try:
        TenantService.delete_by_id(user_id)
    except Exception:
        logging.exception("Rollback: failed to delete tenant %s", user_id)
    try:
        for user_tenant in UserTenantService.query(tenant_id=user_id) or []:
            UserTenantService.delete_by_id(user_tenant.id)
    except Exception:
        logging.exception("Rollback: failed to delete user-tenant links for %s", user_id)
    try:
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception:
        logging.exception("Rollback: failed to delete tenant LLMs for %s", user_id)


def user_register(user_id, user):
    """Create a user plus its own tenant, owner link, LLM rows and root folder.

    Args:
        user_id: id used for both the new user and its tenant.
        user: user column values; mutated in place to add ``"id"``.

    Returns:
        The result of ``UserService.query(email=user["email"])``, or None if
        saving the user row failed (nothing else is created in that case).
    """
    user["id"] = user_id
    tenant = {
        "id": user_id,
        "name": user["nickname"] + "'s Kingdom",
        "llm_id": settings.CHAT_MDL,
        "embd_id": settings.EMBEDDING_MDL,
        "asr_id": settings.ASR_MDL,
        "parser_ids": settings.PARSERS,
        "img2txt_id": settings.IMAGE2TEXT_MDL,
        "rerank_id": settings.RERANK_MDL,
    }
    usr_tenant = {
        "tenant_id": user_id,
        "user_id": user_id,
        "invited_by": user_id,
        "role": UserTenantRole.OWNER,
    }
    # The root folder is its own parent.
    file_id = get_uuid()
    file = {
        "id": file_id,
        "parent_id": file_id,
        "tenant_id": user_id,
        "created_by": user_id,
        "name": "/",
        "type": FileType.FOLDER.value,
        "size": 0,
        "location": "",
    }
    tenant_llm = get_init_tenant_llm(user_id)

    if not UserService.save(**user):
        return
    TenantService.insert(**tenant)
    UserTenantService.insert(**usr_tenant)
    TenantLLMService.insert_many(tenant_llm)
    FileService.insert(file)
    return UserService.query(email=user["email"])


@router.post("/register")
async def user_add(request: RegisterRequest):
    """Register a new password-login user and log them in.

    Raises:
        HTTPException 403: registration is disabled in settings.
        HTTPException 400: the email is malformed or already registered.
        HTTPException 500: creating the account failed (partial records are
            rolled back).
    """
    if not settings.REGISTER_ENABLED:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User registration is disabled!",
        )

    email_address = request.email
    if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid email address: {email_address}!",
        )
    if UserService.query(email=email_address):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Email: {email_address} has already registered!",
        )

    nickname = request.nickname
    user_dict = {
        "access_token": get_uuid(),
        "email": email_address,
        "nickname": nickname,
        "password": decrypt(request.password),
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }

    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception(f"Fail to register {email_address}.")
        if len(users) > 1:
            raise Exception(f"Same email: {email_address} exists!")
        user = users[0]
        return construct_response(
            data=user.to_json(),
            auth=user.get_id(),
            message=f"{nickname}, welcome aboard!",
        )
    except Exception as e:
        rollback_user_registration(user_id)
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"User registration failure, error: {str(e)}",
        )


# ---------------------------------------------------------------------------
# Tenant settings
# ---------------------------------------------------------------------------
@router.get("/tenant_info")
async def tenant_info(current_user=Depends(get_current_user)):
    """Return the first entry of ``TenantService.get_info_by(current_user.id)``.

    Raises:
        HTTPException 404: no tenant information was found.
        HTTPException 500: the lookup failed.
    """
    try:
        tenants = TenantService.get_info_by(current_user.id)
        if not tenants:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Tenant not found!",
            )
        return get_json_result(data=tenants[0])
    except HTTPException:
        # Keep intended statuses (e.g. 404) instead of rewrapping them as 500.
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )


@router.post("/set_tenant_info")
async def set_tenant_info(request: TenantInfoRequest, current_user=Depends(get_current_user)):
    """Update a tenant's default model IDs (ASR, embedding, image-to-text, chat LLM).

    The caller must be linked to ``tenant_id`` through a user-tenant record.

    Raises:
        HTTPException 403: the caller is not a member of ``tenant_id``.
        HTTPException 500: the update failed.
    """
    try:
        req_dict = request.dict()
        tid = req_dict.pop("tenant_id")
        # Authorization: never let a user rewrite a tenant they do not belong to.
        if not UserTenantService.query(tenant_id=tid, user_id=current_user.id):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="No authorization.",
            )
        TenantService.update_by_id(tid, req_dict)
        return get_json_result(data=True)
    except HTTPException:
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )


# ---------------------------------------------------------------------------
# Legacy provider-specific OAuth callbacks
# ---------------------------------------------------------------------------
@router.get("/github_callback")
async def github_callback(code: Optional[str] = Query(None)):
    """
    **Deprecated**, Use `/oauth/callback/` instead.

    GitHub OAuth callback: exchanges *code* for a token, requires the
    ``user:email`` scope, then logs in or registers the user. Always redirects
    to ``/`` with ``?auth=`` or ``?error=``.
    """
    import requests

    if not code:
        return RedirectResponse(url="/?error=missing_code")

    res = requests.post(
        settings.GITHUB_OAUTH.get("url"),
        data={
            "client_id": settings.GITHUB_OAUTH.get("client_id"),
            "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
            "code": code,
        },
        headers={"Accept": "application/json"},
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    ).json()
    if "error" in res:
        return _error_redirect(res.get("error_description", res.get("error")))
    if "user:email" not in res.get("scope", "").split(","):
        return _error_redirect("user:email not in scope")

    access_token = res["access_token"]
    user_info = user_info_from_github(access_token)
    email_address = user_info["email"]
    if not email_address:
        return RedirectResponse(url="/?error=email_missing")

    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": access_token,
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["login"],
                    "login_channel": "github",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")
            user = users[0]
            return RedirectResponse(url=f"/?auth={user.get_id()}")
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return _error_redirect(e)

    # User has already registered, try to log in.
    user = users[0]
    if _is_disabled(user):
        return RedirectResponse(url="/?error=user_inactive")
    user.access_token = get_uuid()
    user.save()
    return RedirectResponse(url=f"/?auth={user.get_id()}")


@router.get("/feishu_callback")
async def feishu_callback(code: Optional[str] = Query(None)):
    """
    Feishu OAuth callback endpoint.

    Obtains an app access token, exchanges *code* for a user access token,
    requires the ``contact:user.email:readonly`` scope, then logs in or
    registers the user. Always redirects to ``/`` with ``?auth=`` or ``?error=``.
    """
    import requests

    if not code:
        return RedirectResponse(url="/?error=missing_code")

    app_access_token_res = requests.post(
        settings.FEISHU_OAUTH.get("app_access_token_url"),
        data=json.dumps(
            {
                "app_id": settings.FEISHU_OAUTH.get("app_id"),
                "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
            }
        ),
        headers={"Content-Type": "application/json; charset=utf-8"},
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    ).json()
    if app_access_token_res.get("code") != 0:
        return _error_redirect(app_access_token_res)

    res = requests.post(
        settings.FEISHU_OAUTH.get("user_access_token_url"),
        data=json.dumps(
            {
                "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
                "code": code,
            }
        ),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
        },
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    ).json()
    if res.get("code") != 0:
        return _error_redirect(res.get("message", "unknown_error"))
    if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split():
        return _error_redirect("contact:user.email:readonly not in scope")

    access_token = res["data"]["access_token"]
    user_info = user_info_from_feishu(access_token)
    email_address = user_info.get("email")
    if not email_address:
        return RedirectResponse(url="/?error=email_missing")

    users = UserService.query(email=email_address)
    user_id = get_uuid()
    if not users:
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": access_token,
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["en_name"],
                    "login_channel": "feishu",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )
            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")
            user = users[0]
            return RedirectResponse(url=f"/?auth={user.get_id()}")
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return _error_redirect(e)

    # User has already registered, try to log in.
    user = users[0]
    if _is_disabled(user):
        return RedirectResponse(url="/?error=user_inactive")
    user.access_token = get_uuid()
    user.save()
    return RedirectResponse(url=f"/?auth={user.get_id()}")


def user_info_from_feishu(access_token):
    """Fetch the Feishu user profile; ``"email"`` is None when absent or empty."""
    import requests

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }
    res = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info",
        headers=headers,
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    )
    user_info = res.json().get("data") or {}
    # Normalise "" / missing to None so callers can test truthiness.
    user_info["email"] = user_info.get("email") or None
    return user_info


def user_info_from_github(access_token):
    """Fetch the GitHub user profile plus the primary email (None if unavailable)."""
    import requests

    # GitHub no longer accepts ``?access_token=`` query parameters; the token
    # travels only in the Authorization header.
    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
    user_info = requests.get(
        "https://api.github.com/user",
        headers=headers,
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    ).json()
    email_info = requests.get(
        "https://api.github.com/user/emails",
        headers=headers,
        timeout=OAUTH_HTTP_TIMEOUT_SECONDS,
    ).json()

    # On error GitHub returns a dict rather than a list of email records.
    primary = None
    if isinstance(email_info, list):
        primary = next((item for item in email_info if item.get("primary")), None)
    user_info["email"] = primary["email"] if primary else None
    return user_info


# ---------------------------------------------------------------------------
# Password reset: captcha -> emailed OTP -> new password
# ---------------------------------------------------------------------------
@router.get("/forget/captcha")
async def forget_get_captcha(email: str = Query(...)):
    """
    GET /forget/captcha?email=<email>
    - Generate an image captcha and cache its text in Redis under
      ``captcha_key(email)`` with TTL = 60 seconds.
    - Returns the captcha as a PNG image (the captcha library's default format).
    """
    if not email:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    allowed = string.ascii_uppercase + string.digits
    captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
    REDIS_CONN.set(captcha_key(email), captcha_text, 60)  # Valid for 60 seconds

    from captcha.image import ImageCaptcha

    image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
    # generate() writes PNG unless another format is requested.
    img_bytes = image.generate(captcha_text).read()
    return Response(content=img_bytes, media_type="image/png")


@router.post("/forget/otp")
async def forget_send_otp(request: ForgetOtpRequest):
    """
    POST /forget/otp
    - Verify the image captcha stored at ``captcha_key(email)`` (case-insensitive)
      and delete it so it cannot be reused.
    - Enforce RESEND_COOLDOWN_SECONDS between sends.
    - Generate an email OTP (A–Z, length = OTP_LENGTH), store its salted hash
      in Redis with TTL, reset attempts/lock, and send the OTP via email.
    """
    email = request.email or ""
    captcha = (request.captcha or "").strip()

    if not email or not captcha:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    stored_captcha = REDIS_CONN.get(captcha_key(email))
    if not stored_captcha:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
    if stored_captcha.strip().lower() != captcha.lower():
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")

    # Delete captcha to prevent reuse.
    REDIS_CONN.delete(captcha_key(email))

    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    now = int(time.time())
    last_ts = REDIS_CONN.get(k_last)
    if last_ts:
        try:
            elapsed = now - int(last_ts)
        except Exception:
            elapsed = RESEND_COOLDOWN_SECONDS
        remaining = RESEND_COOLDOWN_SECONDS - elapsed
        if remaining > 0:
            return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")

    # Generate the OTP (uppercase letters only) and store only its salted hash.
    otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
    salt = os.urandom(16)
    code_hash = hash_code(otp, salt)
    REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
    REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
    REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
    REDIS_CONN.delete(k_lock)

    ttl_min = OTP_TTL_SECONDS // 60

    if not smtp_mail_server:
        logging.warning("SMTP mail server not initialized; skip sending email.")
    else:
        try:
            send_email_html(
                subject="Your Password Reset Code",
                to_email=email,
                template_key="reset_code",
                code=otp,
                ttl_min=ttl_min,
            )
        except Exception:
            logging.exception("Failed to send password reset email to %s", email)
            return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")

    return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")


@router.post("/forget")
async def forget(request: ForgetPasswordRequest):
    """
    POST: Verify email + OTP and reset password, then log the user in.
    Request JSON: { email, otp, new_password, confirm_new_password }

    Each wrong OTP increments an attempt counter; reaching ATTEMPT_LIMIT sets
    a lock for ATTEMPT_LOCK_SECONDS. A correct OTP is consumed (all OTP keys
    deleted) before the password is updated.
    """
    email = request.email or ""
    otp = (request.otp or "").strip()
    new_pwd = request.new_password
    new_pwd2 = request.confirm_new_password

    if not all([email, otp, new_pwd, new_pwd2]):
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")

    # For reset, passwords are provided as-is (no decrypt needed).
    if new_pwd != new_pwd2:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
    user = users[0]

    # Verify the OTP stored in Redis.
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    if REDIS_CONN.get(k_lock):
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")

    stored = REDIS_CONN.get(k_code)
    if not stored:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")

    try:
        stored_hash, salt_hex = str(stored).split(":", 1)
        salt = bytes.fromhex(salt_hex)
    except Exception:
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")

    # Case-insensitive: the OTP was generated uppercase. Constant-time compare
    # avoids leaking how much of the hash matched.
    calc = hash_code(otp.upper(), salt)
    if not hmac.compare_digest(str(calc).encode("utf-8"), stored_hash.encode("utf-8")):
        try:
            attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
        except Exception:
            attempts = 1
        REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
        if attempts >= ATTEMPT_LIMIT:
            REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")

    # Success: consume the OTP and reset the password.
    REDIS_CONN.delete(k_code)
    REDIS_CONN.delete(k_attempts)
    REDIS_CONN.delete(k_last)
    REDIS_CONN.delete(k_lock)

    try:
        UserService.update_user_password(user.id, new_pwd)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")

    # Auto login (same token rotation as the login endpoint).
    user.access_token = get_uuid()
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Password reset successful. Logged in."
    return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)