444 lines
15 KiB
TypeScript
444 lines
15 KiB
TypeScript
import React, { useEffect, useState } from 'react';
|
|
import { useParams, useNavigate } from 'react-router-dom';
|
|
import { useForm, Controller } from 'react-hook-form';
|
|
import { useTranslation } from 'react-i18next';
|
|
import {
|
|
Box,
|
|
Container,
|
|
Typography,
|
|
Paper,
|
|
TextField,
|
|
Button,
|
|
Slider,
|
|
FormControl,
|
|
InputLabel,
|
|
Select,
|
|
MenuItem,
|
|
FormControlLabel,
|
|
Switch,
|
|
Grid,
|
|
Pagination,
|
|
Checkbox,
|
|
ListItemText,
|
|
OutlinedInput,
|
|
ListSubheader,
|
|
Chip,
|
|
} from '@mui/material';
|
|
import { useKnowledgeDetail } from '@/hooks/knowledge-hooks';
|
|
import { useRerankModelOptions } from '@/hooks/llm-hooks';
|
|
import knowledgeService from '@/services/knowledge_service';
|
|
import type { ITestRetrievalRequestBody } from '@/interfaces/request/knowledge';
|
|
import type { INextTestingResult } from '@/interfaces/database/knowledge';
|
|
import KnowledgeBreadcrumbs from './components/KnowledgeBreadcrumbs';
|
|
import TestChunkResult from './components/TestChunkResult';
|
|
import { useSnackbar } from '@/components/Provider/SnackbarProvider';
|
|
import { toLower } from 'lodash';
|
|
import { t } from 'i18next';
|
|
|
|
// Language choices offered by the cross-language search selector.
//
// Each `label` is a getter rather than a precomputed string: the global
// i18next `t` is invoked when the label is read (i.e. during render), not when
// this module is first evaluated. Calling `t` eagerly at import time would
// freeze the labels to whatever language -- or untranslated key -- was
// available at that moment and they would never follow a later language
// switch.
const options = [
  {
    value: 'en',
    get label() {
      return t('knowledgeTesting.languages.english');
    },
  },
  {
    value: 'zh',
    get label() {
      return t('knowledgeTesting.languages.chinese');
    },
  },
  {
    value: 'ja',
    get label() {
      return t('knowledgeTesting.languages.japanese');
    },
  },
  {
    value: 'ko',
    get label() {
      return t('knowledgeTesting.languages.korean');
    },
  },
  {
    value: 'fr',
    get label() {
      return t('knowledgeTesting.languages.french');
    },
  },
  {
    value: 'de',
    get label() {
      return t('knowledgeTesting.languages.german');
    },
  },
  {
    value: 'es',
    get label() {
      return t('knowledgeTesting.languages.spanish');
    },
  },
  {
    value: 'vi',
    get label() {
      return t('knowledgeTesting.languages.vietnamese');
    },
  },
];
|
|
|
|
// Values held by the retrieval-test form.
interface TestFormData {
  // ---- always present -------------------------------------------------

  /** Free-text question to run against the knowledge base. */
  question: string;
  /** Similarity threshold chosen with the first slider. */
  similarity_threshold: number;
  /** Vector similarity weight chosen with the second slider. */
  vector_similarity_weight: number;

  // ---- optional tuning ------------------------------------------------

  /** Identifier of the selected rerank model; an empty string means "no rerank". */
  rerank_id?: string;
  /** Top-K value; the input is only rendered while a rerank model is selected. */
  top_k?: number;
  /** Whether the "use knowledge graph" switch is on. */
  use_kg?: boolean;

  // ---- optional scoping -----------------------------------------------

  /** Language codes picked in the cross-language search multi-select. */
  cross_languages?: string[];
  /** Document ids the results are filtered to. */
  doc_ids?: string[];
}
|
|
|
|
/**
 * Pulls a displayable message out of whatever a failed request threw.
 *
 * @param error - The value caught by a `catch` clause.
 * @returns The `message` string when one is present, otherwise an empty string
 *   so callers can fall back with `||`.
 */
function extractErrorMessage(error: unknown): string {
  if (typeof error === 'object' && error !== null && 'message' in error) {
    const { message } = error as { message?: unknown };
    return typeof message === 'string' ? message : '';
  }
  return '';
}

/**
 * Builds the retrieval-test request body shared by a fresh search and by
 * pagination. Optional fields are only added when they carry a value.
 *
 * @param params.kbId - Knowledge base id taken from the route.
 * @param params.form - Current form values.
 * @param params.page - 1-based page to request.
 * @param params.size - Page size to request.
 * @param params.docIds - Document ids to restrict the search to (may be empty).
 * @returns The request body to pass to `knowledgeService.retrievalTest`.
 */
function buildRetrievalRequest(params: {
  kbId: string;
  form: TestFormData;
  page: number;
  size: number;
  docIds: string[];
}): ITestRetrievalRequestBody {
  const { kbId, form, page, size, docIds } = params;

  const body: ITestRetrievalRequestBody = {
    question: form.question,
    similarity_threshold: form.similarity_threshold,
    vector_similarity_weight: form.vector_similarity_weight,
    kb_id: kbId,
    page,
    size,
  };

  if (form.rerank_id) {
    body.rerank_id = form.rerank_id;
  }
  // Truthiness check also drops NaN, which an emptied numeric input yields.
  if (form.top_k) {
    body.top_k = form.top_k;
  }
  if (form.use_kg !== undefined) {
    body.use_kg = form.use_kg;
  }
  if (docIds.length > 0) {
    body.doc_ids = docIds;
  }
  if (form.cross_languages && form.cross_languages.length > 0) {
    body.cross_languages = form.cross_languages;
  }

  return body;
}

/**
 * Retrieval-testing page for one knowledge base.
 *
 * Renders a configuration form (question, similarity sliders, rerank model,
 * Top-K, cross-language selection, knowledge-graph switch) next to the result
 * list, and re-queries the retrieval endpoint on submit, on page change and
 * when the document filter changes. The knowledge base id comes from the
 * `id` route parameter.
 */
function KnowledgeBaseTesting() {
  const { t } = useTranslation();
  const { id } = useParams<{ id: string }>();
  const navigate = useNavigate();

  // Local UI state
  const [testResult, setTestResult] = useState<INextTestingResult | null>(null);
  const [testing, setTesting] = useState(false);
  const [page, setPage] = useState(1);
  const [pageSize] = useState(10);
  const [selectedDocIds, setSelectedDocIds] = useState<string[]>([]);

  const { showMessage } = useSnackbar();

  // Knowledge base detail (used for the title and breadcrumbs)
  const { knowledge: knowledgeDetail, loading: detailLoading } = useKnowledgeDetail(id || '');

  // Rerank model choices
  const { options: rerankOptions, loading: rerankLoading } = useRerankModelOptions();

  // Form setup
  const {
    control,
    handleSubmit,
    watch,
    register,
    setValue,
    getValues,
    formState: { errors },
  } = useForm<TestFormData>({
    defaultValues: {
      question: '',
      similarity_threshold: 0.2,
      vector_similarity_weight: 0.3,
      rerank_id: '',
      top_k: 1024,
      use_kg: false,
      cross_languages: [],
      doc_ids: [],
    },
  });

  /**
   * Sends one retrieval request and stores its result, toggling the `testing`
   * flag around it. Failures are reported through the snackbar.
   *
   * @param requestBody - Body to send.
   * @param fallbackMessage - Message shown when the failure carries none.
   * @param onSuccess - Optional extra work to run after the result is stored.
   */
  const runRetrieval = async (
    requestBody: ITestRetrievalRequestBody,
    fallbackMessage: string,
    onSuccess?: () => void,
  ): Promise<void> => {
    setTesting(true);
    try {
      const response = await knowledgeService.retrievalTest(requestBody);
      if (response.data.code !== 0) {
        throw new Error(response.data.message || fallbackMessage);
      }
      setTestResult(response.data.data);
      onSuccess?.();
    } catch (error: unknown) {
      showMessage.error(extractErrorMessage(error) || fallbackMessage);
    } finally {
      setTesting(false);
    }
  };

  /**
   * Runs a fresh search from the given form values.
   *
   * @param data - Form values to search with.
   * @param withSelectedDocs - When true the document filter is read from
   *   `data.doc_ids` instead of the `selectedDocIds` state, because that state
   *   has not been re-rendered yet right after a filter change.
   */
  const handleTestSubmitFunc = async (data: TestFormData, withSelectedDocs: boolean = false) => {
    if (!id) return;

    const docIds = withSelectedDocs ? data.doc_ids || [] : selectedDocIds;

    const requestBody = buildRetrievalRequest({
      kbId: id,
      form: data,
      // A fresh search always starts on the first page; the pager is reset to
      // 1 on success, so requesting the previously viewed page would show
      // mismatched results.
      page: 1,
      size: pageSize,
      docIds,
    });

    await runRetrieval(requestBody, t('knowledgeTesting.retrievalTestFailed'), () => {
      setPage(1);
      showMessage.success(t('knowledgeTesting.retrievalTestComplete'));
    });
  };

  // Form submit handler: a plain search that uses the current document filter state.
  const handleTestSubmit = async (data: TestFormData) => {
    return handleTestSubmitFunc(data, false);
  };

  // Pager handler: re-runs the current form values for the requested page.
  const handlePageChange = async (_event: React.ChangeEvent<unknown>, value: number) => {
    if (!id) return;

    setPage(value);

    const requestBody: ITestRetrievalRequestBody = {
      ...buildRetrievalRequest({
        kbId: id,
        form: getValues(),
        page: value,
        size: pageSize,
        docIds: selectedDocIds,
      }),
      // NOTE(review): only page changes request highlighting; the initial
      // search does not. Kept as-is -- confirm whether the asymmetry is intended.
      highlight: true,
    };

    await runRetrieval(requestBody, t('knowledgeTesting.paginationRequestFailed'));
  };

  // Document filter handler: remembers the selection and immediately re-runs the search.
  const handleDocumentFilter = (docIds: string[]) => {
    setSelectedDocIds(docIds);
    setValue('doc_ids', docIds);
    // Errors are handled inside the call, so the promise is intentionally not awaited.
    void handleTestSubmitFunc(getValues(), true);
  };

  // Navigates back to the knowledge base detail page (not wired to any element below).
  const handleBackToDetail = () => {
    navigate(`/knowledge/${id}`);
  };

  if (detailLoading) {
    return (
      <Container maxWidth="lg" sx={{ py: 4 }}>
        <Typography>{t('common.loading')}</Typography>
      </Container>
    );
  }

  return (
    <Container maxWidth="lg" sx={{ py: 4 }}>
      {/* Breadcrumb navigation */}
      <KnowledgeBreadcrumbs
        kbItems={[
          {
            label: t('knowledgeTesting.knowledgeBase'),
            path: '/knowledge',
          },
          {
            label: knowledgeDetail?.name || t('knowledgeTesting.knowledgeBaseDetail'),
            path: `/knowledge/${id}`,
          },
          {
            label: t('knowledgeTesting.testing'),
          },
        ]}
      />

      <Box sx={{ mb: 3 }}>
        <Typography variant="h4" gutterBottom>
          {t('knowledgeTesting.knowledgeBaseTesting')}
        </Typography>
        <Typography variant="subtitle1" color="text.secondary">
          {knowledgeDetail?.name}
        </Typography>
      </Box>

      <Grid container spacing={3} direction="row">
        {/* Test form */}
        <Grid size={4}>
          <Paper sx={{ p: 3, position: 'sticky', top: 20 }}>
            <Typography variant="h6" gutterBottom>
              {t('knowledgeTesting.testConfiguration')}
            </Typography>

            <Box component="form" onSubmit={handleSubmit(handleTestSubmit)} sx={{ mt: 2 }}>
              <TextField
                {...register('question', { required: t('knowledgeTesting.pleaseEnterTestQuestion') })}
                label={t('knowledgeTesting.testQuestion')}
                multiline
                rows={3}
                fullWidth
                margin="normal"
                error={!!errors.question}
                helperText={errors.question?.message}
                placeholder={t('knowledgeTesting.testQuestionPlaceholder')}
              />

              <Box sx={{ mt: 3 }}>
                <Typography gutterBottom>
                  {t('knowledgeTesting.similarityThreshold')}: {watch('similarity_threshold')}
                </Typography>
                <Slider
                  {...register('similarity_threshold')}
                  value={watch('similarity_threshold')}
                  onChange={(_, value) => setValue('similarity_threshold', value as number)}
                  min={0}
                  max={1}
                  step={0.1}
                  marks
                  valueLabelDisplay="auto"
                />
              </Box>

              <Box sx={{ mt: 3 }}>
                <Typography gutterBottom>
                  {t('knowledgeTesting.vectorSimilarityWeight')}: {watch('vector_similarity_weight')}
                </Typography>
                <Slider
                  {...register('vector_similarity_weight')}
                  value={watch('vector_similarity_weight')}
                  onChange={(_, value) => setValue('vector_similarity_weight', value as number)}
                  min={0}
                  max={1}
                  step={0.1}
                  marks
                  valueLabelDisplay="auto"
                />
              </Box>

              <FormControl fullWidth margin="normal">
                <InputLabel>{t('knowledgeTesting.rerankModel')}</InputLabel>
                <Controller
                  name="rerank_id"
                  control={control}
                  render={({ field }) => (
                    <Select
                      {...field}
                      label={t('knowledgeTesting.rerankModel')}
                      disabled={rerankLoading}
                    >
                      <MenuItem value="">
                        <em>{t('knowledgeTesting.noRerank')}</em>
                      </MenuItem>
                      {rerankOptions.map((group) => [
                        <ListSubheader key={group.label}>{group.label}</ListSubheader>,
                        ...group.options.map((option) => (
                          <MenuItem key={option.value} value={option.value} disabled={option.disabled}>
                            {option.label}
                          </MenuItem>
                        )),
                      ])}
                    </Select>
                  )}
                />
              </FormControl>

              {/* Top-K field: only shown while a rerank model is selected */}
              {watch('rerank_id') ? (
                <TextField
                  {...register('top_k', {
                    // Keep the form value numeric; a number input otherwise reports strings.
                    valueAsNumber: true,
                    required: t('knowledgeTesting.pleaseEnterResultCount'),
                    min: { value: 1, message: t('knowledgeTesting.minValue1') },
                    max: { value: 2048, message: t('knowledgeTesting.maxValue2048') },
                  })}
                  label="Top-K"
                  type="number"
                  fullWidth
                  margin="normal"
                  inputProps={{ min: 1, max: 2048 }}
                  error={!!errors.top_k}
                  helperText={errors.top_k?.message || t('knowledgeTesting.useWithRerankModel')}
                />
              ) : null}

              <FormControl fullWidth margin="normal">
                <InputLabel>{t('knowledgeTesting.crossLanguageSearch')}</InputLabel>
                <Controller
                  name="cross_languages"
                  control={control}
                  render={({ field }) => (
                    <Select
                      {...field}
                      multiple
                      label={t('knowledgeTesting.crossLanguageSearch')}
                      input={<OutlinedInput label={t('knowledgeTesting.crossLanguageSearch')} />}
                      renderValue={(selected) => (
                        <Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.5 }}>
                          {selected.map((value) => {
                            const option = options.find((opt) => opt.value === value);
                            return (
                              <Chip key={value} label={option?.label || value} size="small" />
                            );
                          })}
                        </Box>
                      )}
                    >
                      {options.map((option) => (
                        <MenuItem key={option.value} value={option.value}>
                          <Checkbox checked={(field.value ?? []).includes(option.value)} />
                          <ListItemText primary={option.label} />
                        </MenuItem>
                      ))}
                    </Select>
                  )}
                />
              </FormControl>

              <FormControlLabel
                control={
                  <Switch
                    {...register('use_kg')}
                    checked={watch('use_kg')}
                  />
                }
                label={t('knowledgeTesting.useKnowledgeGraph')}
              />

              <Button
                type="submit"
                variant="contained"
                fullWidth
                disabled={testing}
                sx={{ mt: 2 }}
              >
                {testing ? t('knowledgeTesting.testing') : t('knowledgeTesting.startTest')}
              </Button>
            </Box>
          </Paper>
        </Grid>

        <Grid size={8}>
          {testResult && (
            <TestChunkResult
              result={testResult}
              loading={testing}
              page={page}
              pageSize={pageSize}
              onDocumentFilter={handleDocumentFilter}
              selectedDocIds={selectedDocIds}
            />
          )}

          {/* Pager: derived from the same page size that is sent to the server */}
          {testResult && testResult.total > pageSize && (
            <Box sx={{ display: 'flex', justifyContent: 'center', mt: 3 }}>
              <Pagination
                count={Math.ceil(testResult.total / pageSize)}
                page={page}
                onChange={handlePageChange}
                color="primary"
                size="large"
                showFirstButton
                showLastButton
              />
            </Box>
          )}
        </Grid>
      </Grid>
    </Container>
  );
}

export default KnowledgeBaseTesting;