feat(models): implement model configuration dialogs for Azure, Bedrock and Ollama

This commit is contained in:
2025-10-24 15:40:34 +08:00
parent a9b47f776b
commit edba1f049e
6 changed files with 603 additions and 144 deletions

View File

@@ -9,9 +9,9 @@ export interface ISetApiKeyRequestBody {
export interface IAddLlmRequestBody { export interface IAddLlmRequestBody {
llm_factory: string; // Ollama llm_factory: string; // Ollama
llm_name: string; llm_name: string;
model_type: string; model_type: string;// chat|embedding|speech2text|image2text
api_base?: string; // chat|embedding|speech2text|image2text api_base: string;
api_key: string; api_key?: string;
max_tokens: number; max_tokens: number;
} }

View File

@@ -11,14 +11,26 @@ import {
IconButton, IconButton,
InputAdornment, InputAdornment,
CircularProgress, CircularProgress,
MenuItem,
Select,
FormControl,
InputLabel,
FormHelperText,
Link,
} from '@mui/material'; } from '@mui/material';
import { Visibility, VisibilityOff } from '@mui/icons-material'; import { Visibility, VisibilityOff } from '@mui/icons-material';
import { Controller, useForm } from 'react-hook-form'; import { Controller, useForm } from 'react-hook-form';
import type { IAddLlmRequestBody } from '@/interfaces/request/llm';
// 模型类型选项
const MODEL_TYPE_OPTIONS = [
{ value: 'chat', label: 'Chat' },
{ value: 'embedding', label: 'Embedding' },
{ value: 'image2text', label: 'Image2Text' },
];
// 表单数据接口 // 表单数据接口
export interface AzureOpenAIFormData { export interface AzureOpenAIFormData extends IAddLlmRequestBody {
api_key: string;
endpoint: string;
api_version: string; api_version: string;
} }
@@ -52,16 +64,28 @@ function AzureOpenAIDialog ({
formState: { errors }, formState: { errors },
} = useForm<AzureOpenAIFormData>({ } = useForm<AzureOpenAIFormData>({
defaultValues: { defaultValues: {
model_type: 'embedding',
llm_name: 'gpt-3.5-turbo',
api_base: '',
api_key: '', api_key: '',
endpoint: '',
api_version: '2024-02-01', api_version: '2024-02-01',
max_tokens: 4096,
llm_factory: 'Azure-OpenAI',
}, },
}); });
// 当对话框打开或初始数据变化时重置表单 // 当对话框打开或初始数据变化时重置表单
useEffect(() => { useEffect(() => {
if (open) { if (open) {
reset(initialData || { api_key: '', endpoint: '', api_version: '2024-02-01' }); reset({
model_type: 'embedding',
llm_name: 'gpt-3.5-turbo',
api_base: '',
api_key: '',
api_version: '2024-02-01',
max_tokens: 4096,
llm_factory: initialData?.llm_factory || 'Azure-OpenAI',
});
} }
}, [open, initialData, reset]); }, [open, initialData, reset]);
@@ -80,19 +104,86 @@ function AzureOpenAIDialog ({
</DialogTitle> </DialogTitle>
<DialogContent> <DialogContent>
<Box component="form" sx={{ mt: 2 }}> <Box component="form" sx={{ mt: 2 }}>
{/* 模型类型选择 */}
<Controller <Controller
name="api_key" name="model_type"
control={control} control={control}
rules={{ required: 'API Key 是必填项' }} rules={{ required: '模型类型是必填项' }}
render={({ field }) => (
<FormControl fullWidth margin="normal" error={!!errors.model_type}>
<InputLabel></InputLabel>
<Select
{...field}
label="模型类型"
>
{MODEL_TYPE_OPTIONS.map((option) => (
<MenuItem key={option.value} value={option.value}>
{option.label}
</MenuItem>
))}
</Select>
{errors.model_type && (
<FormHelperText>{errors.model_type.message}</FormHelperText>
)}
</FormControl>
)}
/>
{/* 模型名称 */}
<Controller
name="llm_name"
control={control}
rules={{ required: '模型名称是必填项' }}
render={({ field }) => ( render={({ field }) => (
<TextField <TextField
{...field} {...field}
fullWidth fullWidth
label="API Key" label="模型名称"
margin="normal"
error={!!errors.llm_name}
helperText={errors.llm_name?.message || '请输入模型名称'}
placeholder="gpt-3.5-turbo"
/>
)}
/>
{/* 基础 URL */}
<Controller
name="api_base"
control={control}
rules={{
required: '基础 URL 是必填项',
pattern: {
value: /^https?:\/\/.+/,
message: '基础 URL 必须是有效的 URL'
}
}}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="基础 Url"
margin="normal"
error={!!errors.api_base}
helperText={errors.api_base?.message || 'Azure OpenAI 服务的端点 URL'}
placeholder="https://your-resource.openai.azure.com/"
/>
)}
/>
{/* API Key */}
<Controller
name="api_key"
control={control}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="API-Key"
type={showApiKey ? 'text' : 'password'} type={showApiKey ? 'text' : 'password'}
margin="normal" margin="normal"
error={!!errors.api_key} error={!!errors.api_key}
helperText={errors.api_key?.message} helperText={errors.api_key?.message || '输入api key如果是本地部署的模型请忽略'}
InputProps={{ InputProps={{
endAdornment: ( endAdornment: (
<InputAdornment position="end"> <InputAdornment position="end">
@@ -110,29 +201,7 @@ function AzureOpenAIDialog ({
)} )}
/> />
<Controller {/* API Version */}
name="endpoint"
control={control}
rules={{
required: 'Endpoint 是必填项',
pattern: {
value: /^https?:\/\/.+/,
message: 'Endpoint 必须是有效的 URL'
}
}}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="Endpoint"
margin="normal"
error={!!errors.endpoint}
helperText={errors.endpoint?.message || 'Azure OpenAI 服务的端点 URL'}
placeholder="https://your-resource.openai.azure.com/"
/>
)}
/>
<Controller <Controller
name="api_version" name="api_version"
control={control} control={control}
@@ -149,9 +218,34 @@ function AzureOpenAIDialog ({
/> />
)} )}
/> />
{/* 最大token数 */}
<Controller
name="max_tokens"
control={control}
rules={{
required: '最大token数是必填项',
min: { value: 1, message: '最大token数必须大于0' },
max: { value: 100000, message: '最大token数不能超过100000' }
}}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="最大token数"
type="number"
margin="normal"
error={!!errors.max_tokens}
helperText={errors.max_tokens?.message || '设置了模型输出的最大长度以token单词片段的数量表示'}
onChange={(e) => field.onChange(parseInt(e.target.value) || 0)}
/>
)}
/>
</Box> </Box>
</DialogContent> </DialogContent>
<DialogActions> <DialogActions>
{/* 右侧按钮组 */}
<Box sx={{ display: 'flex', gap: 1 }}>
<Button onClick={onClose} disabled={loading}> <Button onClick={onClose} disabled={loading}>
</Button> </Button>
@@ -163,6 +257,7 @@ function AzureOpenAIDialog ({
> >
</Button> </Button>
</Box>
</DialogActions> </DialogActions>
</Dialog> </Dialog>
); );

View File

@@ -15,25 +15,63 @@ import {
Select, Select,
MenuItem, MenuItem,
CircularProgress, CircularProgress,
FormHelperText,
Link,
} from '@mui/material'; } from '@mui/material';
import { Visibility, VisibilityOff } from '@mui/icons-material'; import { Visibility, VisibilityOff } from '@mui/icons-material';
import { Controller, useForm } from 'react-hook-form'; import { Controller, useForm } from 'react-hook-form';
import type { IAddLlmRequestBody } from '@/interfaces/request/llm';
// AWS Bedrock 支持的区域列表 // AWS Bedrock 支持的区域列表
export const BEDROCK_REGIONS = [ export const BEDROCK_REGIONS = [
{ value: 'us-east-1', label: 'US East (N. Virginia)' }, 'us-east-2',
{ value: 'us-west-2', label: 'US West (Oregon)' }, 'us-east-1',
{ value: 'ap-southeast-2', label: 'Asia Pacific (Sydney)' }, 'us-west-1',
{ value: 'ap-northeast-1', label: 'Asia Pacific (Tokyo)' }, 'us-west-2',
{ value: 'eu-central-1', label: 'Europe (Frankfurt)' }, 'af-south-1',
{ value: 'eu-west-3', label: 'Europe (Paris)' }, 'ap-east-1',
'ap-south-2',
'ap-southeast-3',
'ap-southeast-5',
'ap-southeast-4',
'ap-south-1',
'ap-northeast-3',
'ap-northeast-2',
'ap-southeast-1',
'ap-southeast-2',
'ap-east-2',
'ap-southeast-7',
'ap-northeast-1',
'ca-central-1',
'ca-west-1',
'eu-central-1',
'eu-west-1',
'eu-west-2',
'eu-south-1',
'eu-west-3',
'eu-south-2',
'eu-north-1',
'eu-central-2',
'il-central-1',
'mx-central-1',
'me-south-1',
'me-central-1',
'sa-east-1',
'us-gov-east-1',
'us-gov-west-1',
];
// 模型类型选项
const MODEL_TYPE_OPTIONS = [
{ value: 'chat', label: 'Chat' },
{ value: 'embedding', label: 'Embedding' },
]; ];
// 表单数据接口 // 表单数据接口
export interface BedrockFormData { export interface BedrockFormData extends IAddLlmRequestBody {
access_key_id: string; bedrock_ak: string;
secret_access_key: string; bedrock_sk: string;
region: string; bedrock_region: string;
} }
// 对话框 Props 接口 // 对话框 Props 接口
@@ -67,16 +105,28 @@ function BedrockDialog ({
formState: { errors }, formState: { errors },
} = useForm<BedrockFormData>({ } = useForm<BedrockFormData>({
defaultValues: { defaultValues: {
access_key_id: '', model_type: 'chat',
secret_access_key: '', llm_name: '',
region: 'us-east-1', bedrock_ak: '',
bedrock_sk: '',
bedrock_region: 'us-east-1',
max_tokens: 4096,
llm_factory: 'Bedrock',
}, },
}); });
// 当对话框打开或初始数据变化时重置表单 // 当对话框打开或初始数据变化时重置表单
useEffect(() => { useEffect(() => {
if (open) { if (open) {
reset(initialData || { access_key_id: '', secret_access_key: '', region: 'us-east-1' }); reset({
model_type: 'chat',
llm_name: '',
bedrock_ak: '',
bedrock_sk: '',
bedrock_region: 'us-east-1',
max_tokens: 4096,
llm_factory: initialData?.llm_factory || 'Bedrock',
});
} }
}, [open, initialData, reset]); }, [open, initialData, reset]);
@@ -92,26 +142,73 @@ function BedrockDialog ({
setShowSecretKey(!showSecretKey); setShowSecretKey(!showSecretKey);
}; };
const docInfo = {
url: 'https://console.aws.amazon.com/',
text: '如何集成 Bedrock',
};
return ( return (
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth> <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
<DialogTitle> <DialogTitle>
{editMode ? '编辑' : '配置'} AWS Bedrock {editMode ? '编辑' : '添加'} LLM
</DialogTitle> </DialogTitle>
<DialogContent> <DialogContent>
<Box component="form" sx={{ mt: 2 }}> <Box component="form" sx={{ mt: 2 }}>
{/* 模型类型 */}
<Controller <Controller
name="access_key_id" name="model_type"
control={control} control={control}
rules={{ required: 'Access Key ID 是必填项' }} rules={{ required: '模型类型是必填项' }}
render={({ field }) => (
<FormControl fullWidth margin="normal" error={!!errors.model_type}>
<InputLabel>* </InputLabel>
<Select {...field} label="* 模型类型">
{MODEL_TYPE_OPTIONS.map((option) => (
<MenuItem key={option.value} value={option.value}>
{option.label}
</MenuItem>
))}
</Select>
{errors.model_type && (
<FormHelperText>{errors.model_type.message}</FormHelperText>
)}
</FormControl>
)}
/>
{/* 模型名称 */}
<Controller
name="llm_name"
control={control}
rules={{ required: '模型名称是必填项' }}
render={({ field }) => ( render={({ field }) => (
<TextField <TextField
{...field} {...field}
fullWidth fullWidth
label="Access Key ID" label="* 模型名称"
margin="normal"
placeholder="请输入模型名称"
error={!!errors.llm_name}
helperText={errors.llm_name?.message}
/>
)}
/>
{/* ACCESS KEY */}
<Controller
name="bedrock_ak"
control={control}
rules={{ required: 'ACCESS KEY 是必填项' }}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="* ACCESS KEY"
type={showAccessKey ? 'text' : 'password'} type={showAccessKey ? 'text' : 'password'}
margin="normal" margin="normal"
error={!!errors.access_key_id} placeholder="请输入 ACCESS KEY"
helperText={errors.access_key_id?.message} error={!!errors.bedrock_ak}
helperText={errors.bedrock_ak?.message}
InputProps={{ InputProps={{
endAdornment: ( endAdornment: (
<InputAdornment position="end"> <InputAdornment position="end">
@@ -129,19 +226,21 @@ function BedrockDialog ({
)} )}
/> />
{/* SECRET KEY */}
<Controller <Controller
name="secret_access_key" name="bedrock_sk"
control={control} control={control}
rules={{ required: 'Secret Access Key 是必填项' }} rules={{ required: 'SECRET KEY 是必填项' }}
render={({ field }) => ( render={({ field }) => (
<TextField <TextField
{...field} {...field}
fullWidth fullWidth
label="Secret Access Key" label="* SECRET KEY"
type={showSecretKey ? 'text' : 'password'} type={showSecretKey ? 'text' : 'password'}
margin="normal" margin="normal"
error={!!errors.secret_access_key} placeholder="请输入 SECRET KEY"
helperText={errors.secret_access_key?.message} error={!!errors.bedrock_sk}
helperText={errors.bedrock_sk?.message}
InputProps={{ InputProps={{
endAdornment: ( endAdornment: (
<InputAdornment position="end"> <InputAdornment position="end">
@@ -159,32 +258,64 @@ function BedrockDialog ({
)} )}
/> />
{/* AWS Region */}
<Controller <Controller
name="region" name="bedrock_region"
control={control} control={control}
rules={{ required: 'Region 是必填项' }} rules={{ required: 'AWS Region 是必填项' }}
render={({ field }) => ( render={({ field }) => (
<FormControl fullWidth margin="normal" error={!!errors.region}> <FormControl fullWidth margin="normal" error={!!errors.bedrock_region}>
<InputLabel>Region</InputLabel> <InputLabel>* AWS Region</InputLabel>
<Select {...field} label="Region"> <Select {...field} label="* AWS Region">
{BEDROCK_REGIONS.map((region) => ( {BEDROCK_REGIONS.map((region) => (
<MenuItem key={region.value} value={region.value}> <MenuItem key={region} value={region}>
{region.label} {region}
</MenuItem> </MenuItem>
))} ))}
</Select> </Select>
{errors.region && ( {errors.bedrock_region && (
<Typography variant="caption" color="error" sx={{ mt: 1, ml: 2 }}> <FormHelperText>{errors.bedrock_region.message}</FormHelperText>
{errors.region.message}
</Typography>
)} )}
</FormControl> </FormControl>
)} )}
/> />
{/* 最大token数 */}
<Controller
name="max_tokens"
control={control}
rules={{
required: '最大token数是必填项',
min: { value: 1, message: '最大token数必须大于0' },
}}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="* 最大token数"
type="number"
margin="normal"
placeholder="这设置了模型输出的最大长度以token单词或词片段的数量来衡量"
error={!!errors.max_tokens}
helperText={errors.max_tokens?.message}
onChange={(e) => field.onChange(Number(e.target.value))}
/>
)}
/>
</Box> </Box>
</DialogContent> </DialogContent>
<DialogActions> <DialogActions>
<Button onClick={onClose} disabled={loading}> <Box sx={{ display: 'flex', justifyContent: 'space-between', width: '100%' }}>
<Link
href={docInfo.url}
target="_blank"
rel="noopener noreferrer"
sx={{ alignSelf: 'center', textDecoration: 'none', ml:2 }}
>
{docInfo.text}
</Link>
<Box>
<Button onClick={onClose} disabled={loading} sx={{ mr: 1 }}>
</Button> </Button>
<Button <Button
@@ -195,6 +326,8 @@ function BedrockDialog ({
> >
</Button> </Button>
</Box>
</Box>
</DialogActions> </DialogActions>
</Dialog> </Dialog>
); );

View File

@@ -1,4 +1,4 @@
import React, { useEffect } from 'react'; import React, { useEffect, useMemo } from 'react';
import { import {
Dialog, Dialog,
DialogTitle, DialogTitle,
@@ -9,14 +9,28 @@ import {
Box, Box,
Typography, Typography,
CircularProgress, CircularProgress,
MenuItem,
Select,
FormControl,
InputLabel,
FormHelperText,
Link,
} from '@mui/material'; } from '@mui/material';
import { Controller, useForm } from 'react-hook-form'; import { Controller, useForm } from 'react-hook-form';
import logger from '@/utils/logger';
import { LLM_FACTORY_LIST, type LLMFactory } from '@/constants/llm';
// 表单数据接口 // 表单数据接口
export interface OllamaFormData { export interface OllamaFormData {
base_url: string; model_type: string;
llm_name: string;
api_base: string;
api_key?: string;
max_tokens: number;
llm_factory: string;
} }
// 对话框 Props 接口 // 对话框 Props 接口
export interface OllamaDialogProps { export interface OllamaDialogProps {
open: boolean; open: boolean;
@@ -27,8 +41,47 @@ export interface OllamaDialogProps {
editMode?: boolean; editMode?: boolean;
} }
const llmFactoryToUrlMap: { [x: string]: string } = {
[LLM_FACTORY_LIST.Ollama]:
'https://github.com/infiniflow/ragflow/blob/main/docs/guides/models/deploy_local_llm.mdx',
[LLM_FACTORY_LIST.Xinference]:
'https://inference.readthedocs.io/en/latest/user_guide',
[LLM_FACTORY_LIST.ModelScope]:
'https://www.modelscope.cn/docs/model-service/API-Inference/intro',
[LLM_FACTORY_LIST.LocalAI]: 'https://localai.io/docs/getting-started/models/',
[LLM_FACTORY_LIST.LMStudio]: 'https://lmstudio.ai/docs/basics',
[LLM_FACTORY_LIST.OpenAiAPICompatible]:
'https://platform.openai.com/docs/models/gpt-4',
[LLM_FACTORY_LIST.TogetherAI]: 'https://docs.together.ai/docs/deployment-options',
[LLM_FACTORY_LIST.Replicate]: 'https://replicate.com/docs/topics/deployments',
[LLM_FACTORY_LIST.OpenRouter]: 'https://openrouter.ai/docs',
[LLM_FACTORY_LIST.HuggingFace]:
'https://huggingface.co/docs/text-embeddings-inference/quick_tour',
[LLM_FACTORY_LIST.GPUStack]: 'https://docs.gpustack.ai/latest/quickstart',
[LLM_FACTORY_LIST.VLLM]: 'https://docs.vllm.ai/en/latest/',
} as const;
function getURLByFactory(factory: LLMFactory) {
const url = llmFactoryToUrlMap[factory];
return {
textTip: `如何集成 ${factory}`,
url,
}
}
// 模型类型选项
const MODEL_TYPE_OPTIONS = [
{ value: 'chat', label: 'Chat' },
{ value: 'embedding', label: 'Embedding' },
{ value: 'rerank', label: 'Rerank' },
{ value: 'image2text', label: 'Image2Text' },
{ value: 'speech2text', label: 'Speech2Text' },
];
/** /**
* Ollama 配置对话框 * Ollama / local llm 配置对话框
*/ */
function OllamaDialog({ function OllamaDialog({
open, open,
@@ -45,14 +98,60 @@ function OllamaDialog ({
formState: { errors }, formState: { errors },
} = useForm<OllamaFormData>({ } = useForm<OllamaFormData>({
defaultValues: { defaultValues: {
base_url: 'http://localhost:11434', model_type: 'chat',
llm_name: '',
api_base: 'http://localhost:11434',
api_key: '',
max_tokens: 4096,
llm_factory: 'Ollama',
}, },
}); });
const modelTypeOptions = useMemo(() => {
const factory = initialData?.llm_factory || LLM_FACTORY_LIST.Ollama;
if (factory == LLM_FACTORY_LIST.HuggingFace) {
return [
{ value: 'embedding', label: 'Embedding' },
{ value: 'chat', label: 'Chat' },
{ value: 'rerank', label: 'Rerank' },
]
} else if (factory == LLM_FACTORY_LIST.Xinference) {
return [
{ value: 'chat', label: 'Chat' },
{ value: 'embedding', label: 'Embedding' },
{ value: 'rerank', label: 'Rerank' },
{ value: 'image2text', label: 'Image2Text' },
{ value: 'speech2text', label: 'Speech2Text' },
{ value: 'tts', label: 'TTS' },
]
} else if (factory == LLM_FACTORY_LIST.ModelScope) {
return [
{ value: 'chat', label: 'Chat' },
]
} else if (factory == LLM_FACTORY_LIST.GPUStack) {
return [
{ value: 'chat', label: 'Chat' },
{ value: 'embedding', label: 'Embedding' },
{ value: 'rerank', label: 'Rerank' },
{ value: 'image2text', label: 'Image2Text' },
]
}
return MODEL_TYPE_OPTIONS;
}, [initialData])
logger.debug('OllamaDialog', { open, initialData, editMode });
// 当对话框打开或初始数据变化时重置表单 // 当对话框打开或初始数据变化时重置表单
useEffect(() => { useEffect(() => {
if (open) { if (open) {
reset(initialData || { base_url: 'http://localhost:11434' }); reset({
model_type: 'chat',
llm_name: '',
api_base: initialData?.api_base,
api_key: initialData?.api_key,
max_tokens: initialData?.max_tokens,
llm_factory: initialData?.llm_factory || 'Ollama',
});
} }
}, [open, initialData, reset]); }, [open, initialData, reset]);
@@ -60,38 +159,155 @@ function OllamaDialog ({
onSubmit(data); onSubmit(data);
}; };
// 获取文档链接信息
const docInfo = getURLByFactory((initialData?.llm_factory || LLM_FACTORY_LIST.Ollama) as LLMFactory);
return ( return (
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth> <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
<DialogTitle> <DialogTitle>
{editMode ? '编辑' : '配置'} Ollama {editMode ? `编辑 ${initialData?.llm_factory || LLM_FACTORY_LIST.Ollama}` : `配置 ${initialData?.llm_factory || LLM_FACTORY_LIST.Ollama}`}
</DialogTitle> </DialogTitle>
<DialogContent> <DialogContent>
<Box component="form" sx={{ mt: 2 }}> <Box component="form" sx={{ mt: 2 }}>
{/* 模型类型选择 */}
<Controller <Controller
name="base_url" name="model_type"
control={control}
rules={{ required: '模型类型是必填项' }}
render={({ field }) => (
<FormControl fullWidth margin="normal" error={!!errors.model_type}>
<InputLabel> *</InputLabel>
<Select
{...field}
label="模型类型 *"
>
{modelTypeOptions.map((option) => (
<MenuItem key={option.value} value={option.value}>
{option.label}
</MenuItem>
))}
</Select>
{errors.model_type && (
<FormHelperText>{errors.model_type.message}</FormHelperText>
)}
</FormControl>
)}
/>
{/* 模型名称 */}
<Controller
name="llm_name"
control={control}
rules={{ required: '模型名称是必填项' }}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="模型名称"
margin="normal"
required
error={!!errors.llm_name}
helperText={errors.llm_name?.message || '请输入模型名称'}
placeholder="例如: llama2, mistral"
/>
)}
/>
{/* 基础 URL */}
<Controller
name="api_base"
control={control} control={control}
rules={{ rules={{
required: 'Base URL 是必填项', required: '基础 URL 是必填项',
pattern: { pattern: {
value: /^https?:\/\/.+/, value: /^https?:\/\/.+/,
message: 'Base URL 必须是有效的 URL' message: '基础 URL 必须是有效的 URL'
} }
}} }}
render={({ field }) => ( render={({ field }) => (
<TextField <TextField
{...field} {...field}
fullWidth fullWidth
label="Base URL" label="基础 URL"
margin="normal" margin="normal"
error={!!errors.base_url} required
helperText={errors.base_url?.message || 'Ollama 服务的基础 URL'} error={!!errors.api_base}
placeholder="http://localhost:11434" helperText={errors.api_base?.message || '基础 URL'}
placeholder="http://localhost:8888"
/>
)}
/>
{/* API Key (可选) */}
<Controller
name="api_key"
control={control}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="API Key"
margin="normal"
error={!!errors.api_key}
helperText={errors.api_key?.message || 'API Key (可选)'}
placeholder="如果需要认证,请输入 API Key"
/>
)}
/>
{/* 最大 Token 数 */}
<Controller
name="max_tokens"
control={control}
rules={{
required: '最大 Token 数是必填项',
min: {
value: 1,
message: '最大 Token 数必须大于 0'
},
max: {
value: 100000,
message: '最大 Token 数不能超过 100000'
}
}}
render={({ field }) => (
<TextField
{...field}
fullWidth
label="最大 Token 数"
margin="normal"
type="number"
required
error={!!errors.max_tokens}
helperText={errors.max_tokens?.message || '模型支持的最大 Token 数'}
placeholder="4096"
onChange={(e) => field.onChange(parseInt(e.target.value) || 0)}
/> />
)} )}
/> />
</Box> </Box>
</DialogContent> </DialogContent>
<DialogActions> <DialogActions>
<Box sx={{ display: 'flex', justifyContent: 'space-between', width: '100%', alignItems: 'center' }}>
{/* 左侧文档链接 */}
<Link
href={docInfo.url}
target="_blank"
rel="noopener noreferrer"
sx={{
ml: 2,
fontSize: '16',
textDecoration: 'none',
'&:hover': {
textDecoration: 'underline'
}
}}
>
{docInfo.textTip}
</Link>
{/* 右侧按钮组 */}
<Box sx={{ display: 'flex', gap: 1 }}>
<Button onClick={onClose} disabled={loading}> <Button onClick={onClose} disabled={loading}>
</Button> </Button>
@@ -103,6 +319,8 @@ function OllamaDialog ({
> >
</Button> </Button>
</Box>
</Box>
</DialogActions> </DialogActions>
</Dialog> </Dialog>
); );

View File

@@ -107,23 +107,25 @@ export const useAzureOpenAIDialog = () => {
dialogState.setLoading(true); dialogState.setLoading(true);
try { try {
// 调用 Azure OpenAI 特定的 API // 调用 Azure OpenAI 特定的 API
await userService.set_api_key({ await userService.add_llm({
llm_factory: 'AzureOpenAI', llm_factory: data.llm_factory,
// llm_name: data.deployment_name, llm_name: data.llm_name,
model_type: data.model_type,
api_base: data.api_base,
api_key: data.api_key, api_key: data.api_key,
// azure_endpoint: data.azure_endpoint, // @ts-ignore
// api_version: data.api_version, api_version: data.api_version,
max_tokens: data.max_tokens,
}); });
showMessage.success('Azure OpenAI 配置成功'); showMessage.success('Azure OpenAI 配置成功');
dialogState.closeDialog(); dialogState.closeDialog();
} catch (error) { } catch (error) {
logger.error('Azure OpenAI 配置失败:', error); logger.error('Azure OpenAI 配置失败:', error);
showMessage.error('Azure OpenAI 配置失败');
throw error; throw error;
} finally { } finally {
dialogState.setLoading(false); dialogState.setLoading(false);
} }
}, [dialogState]); }, [dialogState, showMessage]);
return { return {
...dialogState, ...dialogState,
@@ -140,13 +142,15 @@ export const useBedrockDialog = () => {
dialogState.setLoading(true); dialogState.setLoading(true);
try { try {
// 调用 Bedrock 特定的 API // 调用 Bedrock 特定的 API
await userService.set_api_key({ await userService.add_llm({
llm_factory: 'Bedrock', llm_factory: data.llm_factory,
llm_name: '', llm_name: data.llm_name,
api_key: '', // Bedrock 使用 access key model_type: data.model_type,
// access_key_id: data.access_key_id, // @ts-ignore
// secret_access_key: data.secret_access_key, bedrock_ak: data.bedrock_ak,
// region: data.region, bedrock_sk: data.bedrock_sk,
bedrock_region: data.bedrock_region,
max_tokens: data.max_tokens,
}); });
showMessage.success('AWS Bedrock 配置成功'); showMessage.success('AWS Bedrock 配置成功');
dialogState.closeDialog(); dialogState.closeDialog();
@@ -175,9 +179,12 @@ export const useOllamaDialog = () => {
try { try {
// 调用添加 LLM 的 API // 调用添加 LLM 的 API
await userService.add_llm({ await userService.add_llm({
llm_factory: 'Ollama', llm_factory: data.llm_factory,
// llm_name: data.model_name, llm_name: data.llm_name,
// base_url: data.base_url, model_type: data.model_type,
api_base: data.api_base,
api_key: data.api_key || '',
max_tokens: data.max_tokens,
}); });
showMessage.success('Ollama 模型添加成功'); showMessage.success('Ollama 模型添加成功');
dialogState.closeDialog(); dialogState.closeDialog();
@@ -188,7 +195,7 @@ export const useOllamaDialog = () => {
} finally { } finally {
dialogState.setLoading(false); dialogState.setLoading(false);
} }
}, [dialogState]); }, [dialogState, showMessage]);
return { return {
...dialogState, ...dialogState,

View File

@@ -69,7 +69,7 @@ function ModelsPage() {
const { llmFactory, myLlm, refreshLlmModel } = useLlmModelSetting(); const { llmFactory, myLlm, refreshLlmModel } = useLlmModelSetting();
const modelDialogs = useModelDialogs(refreshLlmModel); const modelDialogs = useModelDialogs(refreshLlmModel);
// 折叠状态管理 - 使用 Map 来管理每个工厂的折叠状态 // 折叠状态管理 - 使用 Map 来管理每个工厂的折叠状态,默认所有工厂都是折叠的
const [collapsedFactories, setCollapsedFactories] = useState<Record<string, boolean>>({}); const [collapsedFactories, setCollapsedFactories] = useState<Record<string, boolean>>({});
// 切换工厂折叠状态 // 切换工厂折叠状态
@@ -120,11 +120,17 @@ function ModelsPage() {
// 然后有很多自定义的配置项,需要单独用 dialog 来配置 // 然后有很多自定义的配置项,需要单独用 dialog 来配置
const factoryName = factory.name as LLMFactory; const factoryName = factory.name as LLMFactory;
if (LocalLlmFactories.includes(factoryName)) { if (LocalLlmFactories.includes(factoryName)) {
// modelDialogs.localLlmDialog.openLocalLlmDialog(factoryName); modelDialogs.ollamaDialog.openDialog({
llm_factory: factory.name,
});
} else if (factoryName == LLM_FACTORY_LIST.AzureOpenAI) { } else if (factoryName == LLM_FACTORY_LIST.AzureOpenAI) {
modelDialogs.azureDialog.openDialog({
llm_factory: factory.name,
});
} else if (factoryName == LLM_FACTORY_LIST.Bedrock) { } else if (factoryName == LLM_FACTORY_LIST.Bedrock) {
modelDialogs.bedrockDialog.openDialog({
llm_factory: factory.name,
});
} else if (factoryName == LLM_FACTORY_LIST.BaiduYiYan) { } else if (factoryName == LLM_FACTORY_LIST.BaiduYiYan) {
} else if (factoryName == LLM_FACTORY_LIST.GoogleCloud) { } else if (factoryName == LLM_FACTORY_LIST.GoogleCloud) {
@@ -241,7 +247,7 @@ function ModelsPage() {
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}> <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
{/* 折叠/展开图标 */} {/* 折叠/展开图标 */}
<IconButton size="small"> <IconButton size="small">
{collapsedFactories[factoryName] ? <ExpandMoreIcon /> : <ExpandLessIcon />} {collapsedFactories[factoryName] ? <ExpandLessIcon />: <ExpandMoreIcon /> }
</IconButton> </IconButton>
<Box> <Box>
{/* 模型工厂名称 */} {/* 模型工厂名称 */}
@@ -281,7 +287,7 @@ function ModelsPage() {
</Box> </Box>
</Box> </Box>
{/* 模型列表 - 使用 Collapse 组件包装 */} {/* 模型列表 - 使用 Collapse 组件包装 */}
<Collapse in={!collapsedFactories[factoryName]} timeout="auto" unmountOnExit> <Collapse in={collapsedFactories[factoryName]} timeout="auto" unmountOnExit>
<Box sx={{ mt: 2 }}> <Box sx={{ mt: 2 }}>
<Grid container spacing={2}> <Grid container spacing={2}>
{group.llm.map((model) => ( {group.llm.map((model) => (