308 lines
11 KiB
TypeScript
308 lines
11 KiB
TypeScript
import React, { useEffect, useMemo } from 'react';
|
|
import {
|
|
Dialog,
|
|
DialogTitle,
|
|
DialogContent,
|
|
DialogActions,
|
|
Button,
|
|
FormControl,
|
|
InputLabel,
|
|
Select,
|
|
MenuItem,
|
|
Box,
|
|
Typography,
|
|
CircularProgress,
|
|
ListSubheader,
|
|
} from '@mui/material';
|
|
import { useForm, Controller } from 'react-hook-form';
|
|
import { useTranslation } from 'react-i18next';
|
|
import { LlmSvgIcon } from '@/components/AppSvgIcon';
|
|
import { IconMap, type LLMFactory } from '@/constants/llm';
|
|
import type { ITenantInfo } from '@/interfaces/database/knowledge';
|
|
import type { LlmModelType } from '@/constants/knowledge';
|
|
import type { IMyLlmModel, IThirdOAIModel } from '@/interfaces/database/llm';
|
|
|
|
/**
 * One group of model options as supplied through `allModelOptions`.
 * The dialog renders it as a subheader (the group label) followed by its options.
 */
interface AllModelOptionItem {
  /** Group heading; the dialog also uses it as the factory name for icon lookup. */
  label: string;
  /** Options listed under this heading. */
  options: Array<{
    /** Human-readable name shown in the menu. */
    label: string;
    /** Form value stored when this option is selected. */
    value: string;
    /** When true the option is shown but cannot be selected. */
    disabled: boolean;
    /** Full model record associated with this option. */
    model: IThirdOAIModel;
  }>;
}
|
|
|
|
/**
 * Values submitted by the system default model dialog: any subset of the
 * tenant info fields. Exported for use by other modules.
 *
 * Declared as a type alias rather than an empty interface that only extends
 * `Partial<ITenantInfo>`; the two are structurally identical.
 */
export type SystemModelFormData = Partial<ITenantInfo>;
|
|
|
|
/**
 * Props for the system default model settings dialog.
 */
export interface SystemModelDialogProps {
  /** Whether the dialog is visible. */
  open: boolean;

  /** True while a submission is in flight; the dialog disables its action buttons. */
  loading: boolean;

  /**
   * Marks the dialog as editing an existing configuration.
   * NOTE(review): SystemModelDialog destructures this but does not currently read it.
   */
  editMode?: boolean;

  /** Values used to pre-fill the form while the dialog is open. */
  initialData?: Partial<ITenantInfo>;

  /**
   * Grouped model options keyed by model type. The dialog reads the keys
   * `llmOptions`, `embeddingOptions`, `image2textOptions`,
   * `speech2textOptions`, `rerankOptions` and `ttsOptions`.
   */
  allModelOptions: Record<string, AllModelOptionItem[]>;

  /** Receives the form values on confirm; the dialog closes once it resolves. */
  onSubmit: (data: SystemModelFormData) => Promise<void>;

  /** Requests that the dialog be closed. */
  onClose: () => void;
}
|
|
|
|
|
|
/** A single selectable model inside a {@link ModelGroup}. */
export interface ModelOption {
  /** Human-readable name shown in the menu. */
  label: string;
  /** Value stored in the form when this option is chosen. */
  value: string;
  /** When true the option is rendered but cannot be selected. */
  disabled: boolean;
  /** Full model record associated with this option. */
  model: IThirdOAIModel;
}

/**
 * A labelled set of model options. The label is presumably the provider /
 * factory name — TODO confirm against the code that builds these groups.
 */
export interface ModelGroup {
  /** Group heading. */
  label: string;
  /** Options listed under the heading. */
  options: Array<ModelOption>;
}
|
|
|
|
/** Form fields that hold the ID of one of the tenant's default models. */
type ModelFieldName = 'llm_id' | 'embd_id' | 'img2txt_id' | 'asr_id' | 'rerank_id' | 'tts_id';

/** Render configuration for one model selector in the dialog. */
interface ModelFieldConfig {
  /** Form field bound to the select. */
  name: ModelFieldName;
  /** Translated label shown on the select. */
  label: React.ReactNode;
  /** Grouped options; each group renders as a subheader plus its options. */
  groups: AllModelOptionItem[];
  /** Validation message; when undefined the field is optional. */
  requiredMessage?: string;
}

/** Maps a factory name to its icon name, falling back to the generic icon. */
const getFactoryIconName = (factoryName: LLMFactory) => IconMap[factoryName] || 'default';

/**
 * Builds form values from (possibly partial) tenant info.
 *
 * Every model field is normalised to a string ('' when missing or nullish) so
 * that each MUI `Select` is always controlled. Binding `undefined` would make
 * MUI warn about switching between uncontrolled and controlled and would
 * leave the field blank.
 */
const toFormValues = (data?: Partial<ITenantInfo>) => ({
  ...data,
  llm_id: data?.llm_id ?? '',
  embd_id: data?.embd_id ?? '',
  img2txt_id: data?.img2txt_id ?? '',
  asr_id: data?.asr_id ?? '',
  tts_id: data?.tts_id ?? '',
  rerank_id: data?.rerank_id ?? '',
});

/**
 * Renders option groups as a flat list of children for a MUI `Select`:
 * a `ListSubheader` per group followed by that group's `MenuItem`s.
 * Keys are prefixed with the group label so they stay unique in the flat list.
 */
const renderModelGroups = (groups: AllModelOptionItem[]) =>
  groups.flatMap((group) => {
    const iconName = getFactoryIconName(group.label as LLMFactory);
    return [
      <ListSubheader key={`group:${group.label}`}>{group.label}</ListSubheader>,
      ...group.options.map((option) => (
        <MenuItem
          key={`${group.label}:${option.value}`}
          value={option.value}
          disabled={option.disabled}
        >
          <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
            <LlmSvgIcon name={iconName} sx={{ width: 20, height: 20, color: 'primary.main' }} />
            {option.label}
          </Box>
        </MenuItem>
      )),
    ];
  });

/**
 * Dialog for choosing the tenant's system default models: chat, embedding,
 * image-to-text, speech-to-text, rerank and TTS.
 *
 * The chat and embedding models are required; the others are optional.
 * On confirm, `onSubmit` is awaited and the dialog closes only if it resolves.
 * A rejection is logged and the dialog stays open. While `loading` is true the
 * dialog cannot be dismissed and both action buttons are disabled.
 */
function SystemModelDialog({
  open,
  onClose,
  onSubmit,
  loading,
  initialData,
  editMode = false,
  allModelOptions,
}: SystemModelDialogProps) {
  const { t } = useTranslation();
  const { control, handleSubmit, reset } = useForm<ITenantInfo>({
    defaultValues: toFormValues(),
  });

  // Re-seed the form whenever the dialog is open and its initial data changes.
  // NOTE(review): if the parent passes a new `initialData` object on every
  // render, this will discard in-progress edits — confirm it is referentially stable.
  useEffect(() => {
    if (open) {
      reset(toFormValues(initialData));
    }
  }, [open, initialData, reset]);

  // One entry per select, in display order. Labels are translated here (with
  // literal keys) so typed i18n resources can still check them.
  const modelFields = useMemo<ModelFieldConfig[]>(
    () => [
      {
        name: 'llm_id',
        label: t('setting.chatModel'),
        groups: allModelOptions?.llmOptions ?? [],
        requiredMessage: t('setting.chatModelRequired'),
      },
      {
        name: 'embd_id',
        label: t('setting.embeddingModel'),
        groups: allModelOptions?.embeddingOptions ?? [],
        requiredMessage: t('setting.embeddingModelRequired'),
      },
      {
        name: 'img2txt_id',
        label: t('setting.img2txtModel'),
        groups: allModelOptions?.image2textOptions ?? [],
      },
      {
        name: 'asr_id',
        label: t('setting.speech2txtModel'),
        groups: allModelOptions?.speech2textOptions ?? [],
      },
      {
        name: 'rerank_id',
        label: t('setting.rerankModel'),
        groups: allModelOptions?.rerankOptions ?? [],
      },
      {
        name: 'tts_id',
        label: t('setting.ttsModel'),
        groups: allModelOptions?.ttsOptions ?? [],
      },
    ],
    [t, allModelOptions],
  );

  // Ignore close requests (backdrop click, Escape, Cancel) while a submit is in flight.
  const handleDialogClose = () => {
    if (!loading) {
      onClose();
    }
  };

  const handleFormSubmit = async (data: ITenantInfo) => {
    try {
      await onSubmit(data);
      // Call onClose directly: `loading` captured in this closure may be stale by now.
      onClose();
    } catch (error) {
      console.error(t('setting.submitFailed'), error);
    }
  };

  return (
    <Dialog open={open} onClose={handleDialogClose} maxWidth="sm" fullWidth>
      <DialogTitle>{t('setting.setDefaultModel')}</DialogTitle>
      <DialogContent>
        <Box component="form" sx={{ mt: 2 }}>
          {modelFields.map(({ name, label, groups, requiredMessage }) => {
            const labelId = `system-model-${name}-label`;
            return (
              <Controller
                key={name}
                name={name}
                control={control}
                rules={requiredMessage === undefined ? {} : { required: requiredMessage }}
                render={({ field, fieldState: { error } }) => (
                  <FormControl fullWidth margin="normal" error={!!error}>
                    <InputLabel id={labelId}>{label}</InputLabel>
                    <Select {...field} labelId={labelId} label={label}>
                      {renderModelGroups(groups)}
                    </Select>
                    {error && (
                      <Typography variant="caption" color="error" sx={{ mt: 1, ml: 2 }}>
                        {error.message}
                      </Typography>
                    )}
                  </FormControl>
                )}
              />
            );
          })}
        </Box>
      </DialogContent>
      <DialogActions>
        <Button onClick={handleDialogClose} disabled={loading}>
          {t('setting.cancel')}
        </Button>
        <Button
          onClick={handleSubmit(handleFormSubmit)}
          variant="contained"
          disabled={loading}
          startIcon={loading ? <CircularProgress size={20} /> : null}
        >
          {t('setting.confirm')}
        </Button>
      </DialogActions>
    </Dialog>
  );
}

export default SystemModelDialog;