// TERES_web_frontend/src/pages/setting/hooks/useModelDialogs.ts

import { useState, useCallback, useMemo, useEffect } from 'react';
import { useMessage } from '@/hooks/useSnackbar';
import userService from '@/services/user_service';
import logger from '@/utils/logger';
import type {
  ApiKeyFormData,
  AzureOpenAIFormData,
  BedrockFormData,
  OllamaFormData,
} from '../components/ModelDialogs';
import type { ITenantInfo } from '@/interfaces/database/knowledge';
import { useLlmList } from '@/hooks/llm-hooks';
import type { LlmModelType } from '@/constants/knowledge';
import { useUserData } from '@/hooks/useUserData';
import type { ISetApiKeyRequestBody } from '@/interfaces/request/llm';

// Dialog state management hook
export const useDialogState = () => {
  const [open, setOpen] = useState(false);
  const [loading, setLoading] = useState(false);
  const [editMode, setEditMode] = useState(false);
  const [initialData, setInitialData] = useState<any>(null);

  const openDialog = useCallback((data?: any, isEdit = false) => {
    if (data != null) {
      setInitialData(data);
    }
    setEditMode(isEdit);
    setOpen(true);
  }, []);

  const closeDialog = useCallback(() => {
    setOpen(false);
    setInitialData(null);
    setEditMode(false);
  }, []);

  // Memoize the returned object so hooks that list it as a useCallback
  // dependency do not rebuild their callbacks on every render.
  return useMemo(
    () => ({
      open,
      loading,
      editMode,
      initialData,
      setLoading,
      openDialog,
      closeDialog,
    }),
    [open, loading, editMode, initialData, openDialog, closeDialog],
  );
};
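
// Usage sketch (illustrative only; the consuming component is an assumption,
// not part of this module). The shared state feeds a dialog component's props:
//
//   const dialog = useDialogState();
//   dialog.openDialog({ api_key: '' }, /* isEdit */ false); // open with initial data
//   dialog.closeDialog();                                   // reset and close
//   // dialog.open, dialog.loading, dialog.editMode and dialog.initialData
//   // drive the dialog component.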

// API key dialog management
export const useApiKeyDialog = () => {
  const dialogState = useDialogState();
  const showMessage = useMessage();
  const [factoryName, setFactoryName] = useState('');

  const openApiKeyDialog = useCallback((factory: string, data?: Partial<ApiKeyFormData>, isEdit = false) => {
    setFactoryName(factory);
    dialogState.openDialog(data, isEdit);
  }, [dialogState]);

  const submitApiKey = useCallback(async (data: ApiKeyFormData) => {
    dialogState.setLoading(true);
    // Avoid logging the raw key; only log which factory is being configured.
    logger.info('Submitting API key for factory:', factoryName);
    try {
      const params: ISetApiKeyRequestBody = {
        llm_factory: factoryName,
        api_key: data.api_key,
      };
      if (data.base_url && data.base_url.trim() !== '') {
        params.base_url = data.base_url;
      }
      // group_id is collected by the form but is not sent yet:
      // if (data.group_id && data.group_id.trim() !== '') {
      //   params.group_id = data.group_id;
      // }
      await userService.set_api_key(params);
      showMessage.success('API key configured successfully');
      dialogState.closeDialog();
    } catch (error) {
      logger.error('Failed to configure API key:', error);
      showMessage.error('Failed to configure API key');
    } finally {
      dialogState.setLoading(false);
    }
  }, [factoryName, dialogState, showMessage]);

  return {
    ...dialogState,
    factoryName,
    openApiKeyDialog,
    submitApiKey,
  };
};
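
// Usage sketch (illustrative only; the factory name and form values are
// placeholders): open the dialog for a specific factory, then hand
// submitApiKey to the form's onSubmit.
//
//   const apiKeyDialog = useApiKeyDialog();
//   apiKeyDialog.openApiKeyDialog('OpenAI');
//   await apiKeyDialog.submitApiKey({ api_key: 'sk-***', base_url: '' });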

// Azure OpenAI dialog management
export const useAzureOpenAIDialog = () => {
  const dialogState = useDialogState();
  const showMessage = useMessage();

  const submitAzureOpenAI = useCallback(async (data: AzureOpenAIFormData) => {
    dialogState.setLoading(true);
    try {
      // Azure OpenAI is configured through the shared set_api_key endpoint.
      await userService.set_api_key({
        llm_factory: 'AzureOpenAI',
        llm_name: data.deployment_name,
        api_key: data.api_key,
        // Endpoint and API version are collected by the form but not sent yet:
        // azure_endpoint: data.azure_endpoint,
        // api_version: data.api_version,
      });
      showMessage.success('Azure OpenAI configured successfully');
      dialogState.closeDialog();
    } catch (error) {
      logger.error('Failed to configure Azure OpenAI:', error);
      showMessage.error('Failed to configure Azure OpenAI');
      throw error;
    } finally {
      dialogState.setLoading(false);
    }
  }, [dialogState, showMessage]);

  return {
    ...dialogState,
    submitAzureOpenAI,
  };
};

// AWS Bedrock dialog management
export const useBedrockDialog = () => {
  const dialogState = useDialogState();
  const showMessage = useMessage();

  const submitBedrock = useCallback(async (data: BedrockFormData) => {
    dialogState.setLoading(true);
    try {
      // Bedrock authenticates with an access key pair rather than a single
      // API key; those credentials are collected by the form but not sent yet.
      await userService.set_api_key({
        llm_factory: 'Bedrock',
        llm_name: '',
        api_key: '',
        // access_key_id: data.access_key_id,
        // secret_access_key: data.secret_access_key,
        // region: data.region,
      });
      showMessage.success('AWS Bedrock configured successfully');
      dialogState.closeDialog();
    } catch (error) {
      logger.error('Failed to configure AWS Bedrock:', error);
      showMessage.error('Failed to configure AWS Bedrock');
      throw error;
    } finally {
      dialogState.setLoading(false);
    }
  }, [dialogState, showMessage]);

  return {
    ...dialogState,
    submitBedrock,
  };
};

// Ollama dialog management
export const useOllamaDialog = () => {
  const dialogState = useDialogState();
  const showMessage = useMessage();

  const submitOllama = useCallback(async (data: OllamaFormData) => {
    dialogState.setLoading(true);
    try {
      // Register the model through the add-LLM endpoint.
      await userService.add_llm({
        llm_factory: 'Ollama',
        llm_name: data.model_name,
        // base_url is collected by the form but not sent yet:
        // base_url: data.base_url,
      });
      showMessage.success('Ollama model added successfully');
      dialogState.closeDialog();
    } catch (error) {
      logger.error('Failed to add Ollama model:', error);
      showMessage.error('Failed to add Ollama model');
      throw error;
    } finally {
      dialogState.setLoading(false);
    }
  }, [dialogState, showMessage]);

  return {
    ...dialogState,
    submitOllama,
  };
};
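
// Usage sketch (illustrative only; the model name and URL are placeholders,
// and base_url is currently not forwarded, as noted above):
//
//   const ollamaDialog = useOllamaDialog();
//   ollamaDialog.openDialog();
//   await ollamaDialog.submitOllama({ model_name: 'llama3', base_url: 'http://localhost:11434' });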

// Delete operations management
export const useDeleteOperations = () => {
  const showMessage = useMessage();
  const [loading, setLoading] = useState(false);

  const deleteLlm = useCallback(async (factoryName: string, modelName: string) => {
    setLoading(true);
    try {
      await userService.delete_llm({
        llm_factory: factoryName,
        llm_name: modelName,
      });
      showMessage.success('Model deleted successfully');
    } catch (error) {
      logger.error('Failed to delete model:', error);
    } finally {
      setLoading(false);
    }
  }, [showMessage]);

  const deleteFactory = useCallback(async (factoryName: string) => {
    setLoading(true);
    try {
      await userService.deleteFactory({
        llm_factory: factoryName,
      });
      showMessage.success('Model factory deleted successfully');
    } catch (error) {
      logger.error('Failed to delete model factory:', error);
    } finally {
      setLoading(false);
    }
  }, [showMessage]);

  return {
    loading,
    deleteLlm,
    deleteFactory,
  };
};
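
// Usage sketch (illustrative only; factory and model names are placeholders):
//
//   const { deleteLlm, deleteFactory, loading } = useDeleteOperations();
//   await deleteLlm('Ollama', 'llama3'); // remove a single model
//   await deleteFactory('Ollama');       // remove the whole factory entry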

// System default model settings
export const useSystemModelSetting = () => {
  const dialogState = useDialogState();
  const showMessage = useMessage();
  const { data: llmList } = useLlmList();
  const { tenantInfo, fetchTenantInfo } = useUserData();

  // Load tenant info once on mount.
  useEffect(() => {
    fetchTenantInfo();
  }, []);

  // Group available models by factory for a given model type.
  const getOptionsByModelType = useCallback((modelType: LlmModelType) => {
    return Object.entries(llmList ?? {})
      .filter(([, value]) =>
        modelType
          ? value.some((x) => x.model_type.includes(modelType))
          : true,
      )
      .map(([key, value]) => {
        return {
          label: key,
          options: value
            .filter(
              (x) =>
                (modelType ? x.model_type.includes(modelType) : true) &&
                x.available,
            )
            .map((x) => ({
              label: x.llm_name,
              value: `${x.llm_name}@${x.fid}`,
              disabled: !x.available,
              model: x,
            })),
        };
      })
      .filter((x) => x.options.length > 0);
  }, [llmList]);

  const allModelOptions = useMemo(() => {
    const llmOptions = getOptionsByModelType('chat');
    const image2textOptions = getOptionsByModelType('image2text');
    const embeddingOptions = getOptionsByModelType('embedding');
    const speech2textOptions = getOptionsByModelType('speech2text');
    const rerankOptions = getOptionsByModelType('rerank');
    const ttsOptions = getOptionsByModelType('tts');
    return {
      llmOptions,
      image2textOptions,
      embeddingOptions,
      speech2textOptions,
      rerankOptions,
      ttsOptions,
    };
  }, [getOptionsByModelType]);

  const submitSystemModelSetting = useCallback(async (data: Partial<ITenantInfo>) => {
    dialogState.setLoading(true);
    logger.debug('submitSystemModelSetting data:', data);
    try {
      // Strip the role field without mutating the caller's object.
      const payload = { ...data };
      delete payload.role;
      // May need adjustment to match the actual API interface.
      await userService.setTenantInfo(payload);
      showMessage.success('System default models saved successfully');
      dialogState.closeDialog();
      fetchTenantInfo();
    } catch (error) {
      logger.error('Failed to save system default models:', error);
      showMessage.error('Failed to save system default models');
      throw error;
    } finally {
      dialogState.setLoading(false);
    }
  }, [dialogState, showMessage, fetchTenantInfo]);

  return {
    ...dialogState,
    submitSystemModelSetting,
    allModelOptions,
    initialData: tenantInfo,
  };
};
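
// Shape sketch of allModelOptions (illustrative; the factory and model names
// are placeholders). Options are grouped by factory and each option value is
// built as `${llm_name}@${fid}`:
//
//   {
//     llmOptions: [
//       { label: 'OpenAI', options: [{ label: 'gpt-4o', value: 'gpt-4o@OpenAI', disabled: false, model: {...} }] },
//     ],
//     image2textOptions: [...],
//     embeddingOptions: [...],
//     ...
//   }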

// Unified model dialog manager
export const useModelDialogs = () => {
  const apiKeyDialog = useApiKeyDialog();
  const azureDialog = useAzureOpenAIDialog();
  const bedrockDialog = useBedrockDialog();
  const ollamaDialog = useOllamaDialog();
  const systemDialog = useSystemModelSetting();
  const deleteOps = useDeleteOperations();

  // Open the dialog that matches the given factory type.
  const openFactoryDialog = useCallback((factoryName: string, data?: any, isEdit = false) => {
    switch (factoryName.toLowerCase()) {
      case 'azureopenai':
        azureDialog.openDialog(data, isEdit);
        break;
      case 'bedrock':
        bedrockDialog.openDialog(data, isEdit);
        break;
      case 'ollama':
        ollamaDialog.openDialog(data, isEdit);
        break;
      default:
        // Fall back to the generic API key dialog.
        apiKeyDialog.openApiKeyDialog(factoryName, data, isEdit);
        break;
    }
  }, [apiKeyDialog, azureDialog, bedrockDialog, ollamaDialog]);

  return {
    apiKeyDialog,
    azureDialog,
    bedrockDialog,
    ollamaDialog,
    systemDialog,
    deleteOps,
    openFactoryDialog,
  };
};
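
// Usage sketch (illustrative only; the component names and JSX handlers are
// assumptions about the consuming settings page, not part of this module):
//
//   const { openFactoryDialog, apiKeyDialog, deleteOps } = useModelDialogs();
//   // Clicking a provider card: onClick={() => openFactoryDialog(factory.name)}
//   // Deleting a configured model: onConfirm={() => deleteOps.deleteLlm(factory.name, model.llm_name)}
//   // Each *Dialog object exposes open/loading/initialData plus its submit handler
//   // for wiring the corresponding dialog component from ../components/ModelDialogs.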