Files
TERES_web_frontend/src/pages/setting/components/Dialog/SystemModelDialog.tsx

308 lines
11 KiB
TypeScript

import React, { useEffect, useMemo } from 'react';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
Button,
FormControl,
InputLabel,
Select,
MenuItem,
Box,
Typography,
CircularProgress,
ListSubheader,
} from '@mui/material';
import { useForm, Controller } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { LlmSvgIcon } from '@/components/AppSvgIcon';
import { IconMap, type LLMFactory } from '@/constants/llm';
import type { ITenantInfo } from '@/interfaces/database/knowledge';
import type { LlmModelType } from '@/constants/knowledge';
import type { IMyLlmModel, IThirdOAIModel } from '@/interfaces/database/llm';
/** A single selectable model inside a provider group. */
export interface ModelOption {
  /** Value written to the form field when this option is chosen. */
  value: string;
  /** Text displayed for the option in the menu. */
  label: string;
  /** When true the option is listed but cannot be selected. */
  disabled: boolean;
  /** Model record the option was built from. */
  model: IThirdOAIModel;
}

/** A group heading (provider/factory name) together with its model options. */
export interface ModelGroup {
  /** Group heading; the dialog also casts it to an LLMFactory to pick an icon. */
  label: string;
  options: ModelOption[];
}

/**
 * Option group shape consumed by the dialog props.
 * Kept as an alias of ModelGroup so the two definitions cannot drift apart.
 */
type AllModelOptionItem = ModelGroup;

/** Form payload submitted by the dialog (exported for use by other files). */
export interface SystemModelFormData extends Partial<ITenantInfo> { }

/** Props for the system default-model settings dialog. */
export interface SystemModelDialogProps {
  /** Whether the dialog is shown. */
  open: boolean;
  /** Called to close the dialog (cancel, dismiss, or after a successful submit). */
  onClose: () => void;
  /** While true, the dialog's action buttons are disabled. */
  loading: boolean;
  /** Accepted for API compatibility; not currently read by the dialog. */
  editMode?: boolean;
  /** Awaited on submit; the dialog closes only if the returned promise resolves. */
  onSubmit: (data: SystemModelFormData) => Promise<void>;
  /** Values used to populate the form each time the dialog opens. */
  initialData?: Partial<ITenantInfo>;
  /**
   * Option groups keyed by model type. The dialog reads the keys
   * `llmOptions`, `embeddingOptions`, `image2textOptions`,
   * `speech2textOptions`, `ttsOptions` and `rerankOptions`.
   */
  allModelOptions: Record<string, AllModelOptionItem[]>;
}
/** Form field names of the model selectors rendered by the dialog. */
type ModelFieldName = 'llm_id' | 'embd_id' | 'img2txt_id' | 'asr_id' | 'rerank_id' | 'tts_id';

/** Declarative description of one model <Select> in the dialog. */
interface ModelFieldConfig {
  /** Form field (and ITenantInfo key) the selector is bound to. */
  name: ModelFieldName;
  /** Translated label shown on the input. */
  label: string;
  /** Option groups listed in the selector's menu. */
  groups: AllModelOptionItem[];
  /** Translated validation message; when present the field is required. */
  requiredMessage?: string;
}

/**
 * Resolve the icon name for a provider/factory, falling back to 'default'
 * when IconMap has no entry for it.
 */
const getFactoryIconName = (factoryName: LLMFactory) => IconMap[factoryName] || 'default';

/**
 * Build the values used to (re)initialise the form when the dialog opens.
 *
 * Every field of `initialData` is kept as-is, except the six model fields,
 * which are normalised to '' when missing or null. This keeps each MUI
 * Select controlled with a string value (undefined makes it uncontrolled,
 * null triggers a React warning).
 */
function toFormValues(initialData?: Partial<ITenantInfo>): Partial<ITenantInfo> {
  return {
    ...initialData,
    llm_id: initialData?.llm_id ?? '',
    embd_id: initialData?.embd_id ?? '',
    img2txt_id: initialData?.img2txt_id ?? '',
    asr_id: initialData?.asr_id ?? '',
    tts_id: initialData?.tts_id ?? '',
    rerank_id: initialData?.rerank_id ?? '',
  };
}

/**
 * Flatten option groups into Select children: one ListSubheader per group,
 * followed by that group's MenuItems with the group's provider icon.
 * Item keys are prefixed with the group label so that the same model value
 * offered by two groups does not produce duplicate React keys.
 */
function renderModelGroups(groups: AllModelOptionItem[]): React.ReactNode[] {
  return groups.flatMap((group) => {
    const iconName = getFactoryIconName(group.label as LLMFactory);
    return [
      <ListSubheader key={`group:${group.label}`}>{group.label}</ListSubheader>,
      ...group.options.map((option) => (
        <MenuItem
          key={`${group.label}:${option.value}`}
          value={option.value}
          disabled={option.disabled}
        >
          <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
            <LlmSvgIcon name={iconName} sx={{ width: 20, height: 20, color: 'primary.main' }} />
            {option.label}
          </Box>
        </MenuItem>
      )),
    ];
  });
}

/**
 * Dialog for choosing the system default models (chat, embedding,
 * image-to-text, speech-to-text, rerank, TTS).
 *
 * Chat and embedding models are required. On confirm the form values are
 * passed to `onSubmit`. If it resolves the dialog closes; if it rejects
 * the error is logged and the dialog stays open. `editMode` is accepted by
 * the props type but is not used here.
 */
function SystemModelDialog({
  open,
  onClose,
  onSubmit,
  loading,
  initialData,
  allModelOptions,
}: SystemModelDialogProps) {
  const { t } = useTranslation();
  const { control, handleSubmit, reset, formState: { errors } } = useForm<ITenantInfo>({
    defaultValues: {},
  });

  // One entry per selector, in display order. Option lists come from allModelOptions.
  const modelFields = useMemo<ModelFieldConfig[]>(() => [
    {
      name: 'llm_id',
      label: t('setting.chatModel'),
      groups: allModelOptions?.llmOptions ?? [],
      requiredMessage: t('setting.chatModelRequired'),
    },
    {
      name: 'embd_id',
      label: t('setting.embeddingModel'),
      groups: allModelOptions?.embeddingOptions ?? [],
      requiredMessage: t('setting.embeddingModelRequired'),
    },
    {
      name: 'img2txt_id',
      label: t('setting.img2txtModel'),
      groups: allModelOptions?.image2textOptions ?? [],
    },
    {
      name: 'asr_id',
      label: t('setting.speech2txtModel'),
      groups: allModelOptions?.speech2textOptions ?? [],
    },
    {
      name: 'rerank_id',
      label: t('setting.rerankModel'),
      groups: allModelOptions?.rerankOptions ?? [],
    },
    {
      name: 'tts_id',
      label: t('setting.ttsModel'),
      groups: allModelOptions?.ttsOptions ?? [],
    },
  ], [allModelOptions, t]);

  // Re-populate the form whenever the dialog is opened or initialData changes while open.
  // NOTE(review): if the parent passes a new initialData object on every render,
  // this also resets in-progress edits — confirm the caller memoises it.
  useEffect(() => {
    if (open) {
      reset(toFormValues(initialData));
    }
  }, [open, initialData, reset]);

  const handleFormSubmit = async (data: ITenantInfo) => {
    try {
      await onSubmit(data);
      onClose();
    } catch (error) {
      // Keep the dialog open so the user can retry; the caller owns user-facing feedback.
      console.error(t('setting.submitFailed'), error);
    }
  };

  return (
    // Block backdrop/Escape dismissal while a submit is in flight, matching the disabled Cancel button.
    <Dialog open={open} onClose={loading ? undefined : onClose} maxWidth="sm" fullWidth>
      <DialogTitle>
        {t('setting.setDefaultModel')}
      </DialogTitle>
      <DialogContent>
        <Box component="form" sx={{ mt: 2 }}>
          {modelFields.map(({ name, label, groups, requiredMessage }) => {
            const fieldError = errors[name];
            return (
              <Controller
                key={name}
                name={name}
                control={control}
                rules={requiredMessage ? { required: requiredMessage } : {}}
                render={({ field }) => (
                  <FormControl fullWidth margin="normal" error={!!fieldError}>
                    <InputLabel>{label}</InputLabel>
                    <Select {...field} label={label}>
                      {renderModelGroups(groups)}
                    </Select>
                    {fieldError && (
                      <Typography variant="caption" color="error" sx={{ mt: 1, ml: 2 }}>
                        {fieldError.message}
                      </Typography>
                    )}
                  </FormControl>
                )}
              />
            );
          })}
        </Box>
      </DialogContent>
      <DialogActions>
        <Button onClick={onClose} disabled={loading}>
          {t('setting.cancel')}
        </Button>
        <Button
          onClick={handleSubmit(handleFormSubmit)}
          variant="contained"
          disabled={loading}
          startIcon={loading ? <CircularProgress size={20} /> : null}
        >
          {t('setting.confirm')}
        </Button>
      </DialogActions>
    </Dialog>
  );
}
export default SystemModelDialog;