import { useState, useCallback, useMemo, useEffect } from 'react';
import { useMessage } from '@/hooks/useSnackbar';
import userService from '@/services/user_service';
import logger from '@/utils/logger';
import type {
  ApiKeyFormData,
  AzureOpenAIFormData,
  BedrockFormData,
  OllamaFormData,
} from '../components/ModelDialogs';
import type { ITenantInfo } from '@/interfaces/database/knowledge';
import { useLlmList } from '@/hooks/llm-hooks';
import type { LlmModelType } from '@/constants/knowledge';
import { useUserData } from '@/hooks/useUserData';
import type { ISetApiKeyRequestBody } from '@/interfaces/request/llm';

/**
 * Generic dialog state: open flag, loading flag, edit-mode flag and the data
 * used to pre-fill the dialog.
 *
 * `openDialog` and `closeDialog` are memoized with empty dependency lists, so
 * their identities are stable across renders. The returned object itself is a
 * new literal on every render, so consumers should depend on the individual
 * members rather than on the object.
 */
export const useDialogState = () => {
  const [open, setOpen] = useState(false);
  const [loading, setLoading] = useState(false);
  const [editMode, setEditMode] = useState(false);
  // `any` because this state is shared by dialogs with different form shapes.
  const [initialData, setInitialData] = useState<any>(null);

  /** Opens the dialog; `data` (when not null/undefined) becomes the initial form data. */
  const openDialog = useCallback((data?: any, isEdit = false) => {
    if (data != null) {
      setInitialData(data);
    }
    setEditMode(isEdit);
    setOpen(true);
  }, []);

  /** Closes the dialog and resets the initial data and edit mode. */
  const closeDialog = useCallback(() => {
    setOpen(false);
    setInitialData(null);
    setEditMode(false);
  }, []);

  return {
    open,
    loading,
    editMode,
    initialData,
    setLoading,
    openDialog,
    closeDialog,
  };
};

/**
 * API Key dialog management.
 *
 * Tracks which factory the dialog was opened for and submits the key through
 * `userService.set_api_key`.
 */
export const useApiKeyDialog = () => {
  const dialogState = useDialogState();
  // Depend on the stable members; `dialogState` itself changes every render.
  const { setLoading, openDialog, closeDialog } = dialogState;
  const showMessage = useMessage();
  const [factoryName, setFactoryName] = useState('');

  /** Remembers the target factory and opens the dialog. */
  const openApiKeyDialog = useCallback(
    (factory: string, data?: Partial<ApiKeyFormData>, isEdit = false) => {
      setFactoryName(factory);
      openDialog(data, isEdit);
    },
    [openDialog],
  );

  /**
   * Submits the API key for the current factory.
   *
   * Failures are logged but not rethrown, and no error message is shown here
   * (unlike the other submit handlers in this file).
   */
  const submitApiKey = useCallback(
    async (data: ApiKeyFormData) => {
      setLoading(true);
      // Never write the secret itself to the log; mask it.
      logger.info('提交 API Key:', { ...data, api_key: '***' });
      try {
        const params: ISetApiKeyRequestBody = {
          llm_factory: factoryName,
          api_key: data.api_key,
        };
        // Only send base_url when it is not blank.
        if (data.base_url && data.base_url.trim() !== '') {
          params.base_url = data.base_url;
        }
        // TODO: forward data.group_id once the request body supports it
        // (the assignment was commented out in the original code).
        await userService.set_api_key(params);
        showMessage.success('API Key 配置成功');
        closeDialog();
      } catch (error) {
        logger.error('API Key 配置失败:', error);
      } finally {
        setLoading(false);
      }
    },
    [factoryName, setLoading, closeDialog, showMessage],
  );

  return {
    ...dialogState,
    factoryName,
    openApiKeyDialog,
    submitApiKey,
  };
};

/**
 * Azure OpenAI dialog management.
 *
 * Submits the deployment name and API key through `userService.set_api_key`.
 */
export const useAzureOpenAIDialog = () => {
  const dialogState = useDialogState();
  const { setLoading, closeDialog } = dialogState;
  const showMessage = useMessage();

  /**
   * Submits the Azure OpenAI configuration.
   *
   * @throws Rethrows the service error after logging it and showing an error message.
   */
  const submitAzureOpenAI = useCallback(
    async (data: AzureOpenAIFormData) => {
      setLoading(true);
      try {
        // Call the Azure OpenAI specific API.
        await userService.set_api_key({
          llm_factory: 'AzureOpenAI',
          llm_name: data.deployment_name,
          api_key: data.api_key,
          // azure_endpoint: data.azure_endpoint,
          // api_version: data.api_version,
        });
        showMessage.success('Azure OpenAI 配置成功');
        closeDialog();
      } catch (error) {
        logger.error('Azure OpenAI 配置失败:', error);
        showMessage.error('Azure OpenAI 配置失败');
        throw error;
      } finally {
        setLoading(false);
      }
    },
    [setLoading, closeDialog, showMessage],
  );

  return {
    ...dialogState,
    submitAzureOpenAI,
  };
};

/**
 * AWS Bedrock dialog management.
 *
 * NOTE(review): the form data is currently not sent — the credential fields
 * are commented out and an empty `llm_name` / `api_key` is submitted. Confirm
 * the intended request shape.
 */
export const useBedrockDialog = () => {
  const dialogState = useDialogState();
  const { setLoading, closeDialog } = dialogState;
  const showMessage = useMessage();

  /**
   * Submits the Bedrock configuration.
   *
   * @throws Rethrows the service error after logging it and showing an error message.
   */
  const submitBedrock = useCallback(
    async (data: BedrockFormData) => {
      setLoading(true);
      try {
        // Call the Bedrock specific API.
        await userService.set_api_key({
          llm_factory: 'Bedrock',
          llm_name: '',
          api_key: '', // Bedrock uses an access key
          // access_key_id: data.access_key_id,
          // secret_access_key: data.secret_access_key,
          // region: data.region,
        });
        showMessage.success('AWS Bedrock 配置成功');
        closeDialog();
      } catch (error) {
        logger.error('AWS Bedrock 配置失败:', error);
        showMessage.error('AWS Bedrock 配置失败');
        throw error;
      } finally {
        setLoading(false);
      }
    },
    [setLoading, closeDialog, showMessage],
  );

  return {
    ...dialogState,
    submitBedrock,
  };
};

/**
 * Ollama dialog management.
 *
 * Adds a model through `userService.add_llm`.
 */
export const useOllamaDialog = () => {
  const dialogState = useDialogState();
  const { setLoading, closeDialog } = dialogState;
  const showMessage = useMessage();

  /**
   * Adds the Ollama model named in the form.
   *
   * @throws Rethrows the service error after logging it and showing an error message.
   */
  const submitOllama = useCallback(
    async (data: OllamaFormData) => {
      setLoading(true);
      try {
        // Call the add-LLM API.
        await userService.add_llm({
          llm_factory: 'Ollama',
          llm_name: data.model_name,
          // base_url: data.base_url,
        });
        showMessage.success('Ollama 模型添加成功');
        closeDialog();
      } catch (error) {
        logger.error('Ollama 模型添加失败:', error);
        showMessage.error('Ollama 模型添加失败');
        throw error;
      } finally {
        setLoading(false);
      }
    },
    [setLoading, closeDialog, showMessage],
  );

  return {
    ...dialogState,
    submitOllama,
  };
};

/**
 * Delete operations for models and model factories.
 *
 * Both operations share one `loading` flag. Failures are logged but not
 * rethrown.
 */
export const useDeleteOperations = () => {
  const showMessage = useMessage();
  const [loading, setLoading] = useState(false);

  /** Deletes a single model of the given factory. */
  const deleteLlm = useCallback(
    async (factoryName: string, modelName: string) => {
      setLoading(true);
      try {
        await userService.delete_llm({
          llm_factory: factoryName,
          llm_name: modelName,
        });
        showMessage.success('模型删除成功');
      } catch (error) {
        logger.error('模型删除失败:', error);
      } finally {
        setLoading(false);
      }
    },
    [showMessage],
  );

  /** Deletes a whole model factory. */
  const deleteFactory = useCallback(
    async (factoryName: string) => {
      setLoading(true);
      try {
        await userService.deleteFactory({
          llm_factory: factoryName,
        });
        showMessage.success('模型工厂删除成功');
      } catch (error) {
        logger.error('模型工厂删除失败:', error);
      } finally {
        setLoading(false);
      }
    },
    [showMessage],
  );

  return {
    loading,
    deleteLlm,
    deleteFactory,
  };
};

/**
 * System default model settings.
 *
 * Exposes grouped select options per model type and a submit handler that
 * saves the tenant info. The returned `initialData` is the tenant info, which
 * overrides the `initialData` coming from the generic dialog state.
 */
export const useSystemModelSetting = () => {
  const dialogState = useDialogState();
  const { setLoading, closeDialog } = dialogState;
  const showMessage = useMessage();
  const { data: llmList } = useLlmList();
  const { tenantInfo, fetchTenantInfo } = useUserData();

  // Fetch once on mount. `fetchTenantInfo` is deliberately left out of the
  // dependency list: its identity stability is not visible here, and
  // re-running on every identity change could cause a refetch loop.
  useEffect(() => {
    fetchTenantInfo();
  }, []);

  /**
   * Builds grouped options for one model type.
   *
   * Groups (keyed by the `llmList` key) are kept only when they contain at
   * least one available model of that type. A falsy `modelType` matches all
   * types.
   */
  const getOptionsByModelType = useCallback(
    (modelType: LlmModelType) => {
      // The list may not be loaded yet; avoid Object.entries(undefined).
      if (!llmList) {
        return [];
      }
      return Object.entries(llmList)
        .filter(([, value]) =>
          modelType ? value.some((x) => x.model_type.includes(modelType)) : true,
        )
        .map(([key, value]) => {
          return {
            label: key,
            options: value
              .filter(
                (x) =>
                  (modelType ? x.model_type.includes(modelType) : true) && x.available,
              )
              .map((x) => ({
                label: x.llm_name,
                value: `${x.llm_name}@${x.fid}`,
                disabled: !x.available,
                model: x,
              })),
          };
        })
        .filter((x) => x.options.length > 0);
    },
    [llmList],
  );

  /** Options for every supported model type, recomputed when the list changes. */
  const allModelOptions = useMemo(() => {
    const llmOptions = getOptionsByModelType('chat');
    const image2textOptions = getOptionsByModelType('image2text');
    const embeddingOptions = getOptionsByModelType('embedding');
    const speech2textOptions = getOptionsByModelType('speech2text');
    const rerankOptions = getOptionsByModelType('rerank');
    const ttsOptions = getOptionsByModelType('tts');

    return {
      llmOptions,
      image2textOptions,
      embeddingOptions,
      speech2textOptions,
      rerankOptions,
      ttsOptions,
    };
  }, [getOptionsByModelType]);

  /**
   * Saves the system default models and refreshes the tenant info.
   *
   * `role` is stripped from the payload without mutating the caller's object.
   *
   * @throws Rethrows the service error after logging it and showing an error message.
   */
  const submitSystemModelSetting = useCallback(
    async (data: Partial<ITenantInfo>) => {
      setLoading(true);
      logger.debug('submitSystemModelSetting data:', data);
      try {
        const { role: _role, ...payload } = data;
        // Adjust here to match the actual API contract if it changes.
        await userService.setTenantInfo({
          ...payload,
        });
        showMessage.success('系统默认模型设置成功');
        closeDialog();
        fetchTenantInfo();
      } catch (error) {
        logger.error('系统默认模型设置失败:', error);
        showMessage.error('系统默认模型设置失败');
        throw error;
      } finally {
        setLoading(false);
      }
    },
    [setLoading, closeDialog, showMessage, fetchTenantInfo],
  );

  return {
    ...dialogState,
    submitSystemModelSetting,
    allModelOptions,
    initialData: tenantInfo,
  };
};

/**
 * Unified manager that bundles all model dialogs and the delete operations,
 * plus a dispatcher that opens the right dialog for a factory name.
 */
export const useModelDialogs = () => {
  const apiKeyDialog = useApiKeyDialog();
  const azureDialog = useAzureOpenAIDialog();
  const bedrockDialog = useBedrockDialog();
  const ollamaDialog = useOllamaDialog();
  const systemDialog = useSystemModelSetting();
  const deleteOps = useDeleteOperations();

  // Pull out the openers so the callback below depends on function
  // identities rather than on the per-render dialog objects.
  const openAzureDialog = azureDialog.openDialog;
  const openBedrockDialog = bedrockDialog.openDialog;
  const openOllamaDialog = ollamaDialog.openDialog;
  const openApiKeyDialog = apiKeyDialog.openApiKeyDialog;

  /** Opens the dialog matching the factory type (case-insensitive match). */
  const openFactoryDialog = useCallback(
    (factoryName: string, data?: any, isEdit = false) => {
      switch (factoryName.toLowerCase()) {
        case 'azureopenai':
          openAzureDialog(data, isEdit);
          break;
        case 'bedrock':
          openBedrockDialog(data, isEdit);
          break;
        case 'ollama':
          openOllamaDialog(data, isEdit);
          break;
        default:
          // Fall back to the API Key dialog.
          openApiKeyDialog(factoryName, data, isEdit);
          break;
      }
    },
    [openApiKeyDialog, openAzureDialog, openBedrockDialog, openOllamaDialog],
  );

  return {
    apiKeyDialog,
    azureDialog,
    bedrockDialog,
    ollamaDialog,
    systemDialog,
    deleteOps,
    openFactoryDialog,
  };
};