Files
TERES_web_frontend/src/pages/knowledge/testing.tsx

444 lines
15 KiB
TypeScript

import React, { useEffect, useState } from 'react';
import { useParams, useNavigate } from 'react-router-dom';
import { useForm, Controller } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import {
Box,
Container,
Typography,
Paper,
TextField,
Button,
Slider,
FormControl,
InputLabel,
Select,
MenuItem,
FormControlLabel,
Switch,
Grid,
Pagination,
Checkbox,
ListItemText,
OutlinedInput,
ListSubheader,
Chip,
} from '@mui/material';
import { useKnowledgeDetail } from '@/hooks/knowledge-hooks';
import { useRerankModelOptions } from '@/hooks/llm-hooks';
import knowledgeService from '@/services/knowledge_service';
import type { ITestRetrievalRequestBody } from '@/interfaces/request/knowledge';
import type { INextTestingResult } from '@/interfaces/database/knowledge';
import KnowledgeBreadcrumbs from './components/KnowledgeBreadcrumbs';
import TestChunkResult from './components/TestChunkResult';
import { useSnackbar } from '@/components/Provider/SnackbarProvider';
import { toLower } from 'lodash';
import { t } from 'i18next';
/**
 * Languages selectable for cross-language retrieval.
 *
 * `label` is exposed as a getter so the i18next translation is resolved each
 * time the label is read (i.e. during render) rather than once when this module
 * is imported. Translating at import time can run before translation resources
 * are loaded and never picks up a later language switch.
 */
const options = (
  [
    ['en', 'knowledgeTesting.languages.english'],
    ['zh', 'knowledgeTesting.languages.chinese'],
    ['ja', 'knowledgeTesting.languages.japanese'],
    ['ko', 'knowledgeTesting.languages.korean'],
    ['fr', 'knowledgeTesting.languages.french'],
    ['de', 'knowledgeTesting.languages.german'],
    ['es', 'knowledgeTesting.languages.spanish'],
    ['vi', 'knowledgeTesting.languages.vietnamese'],
  ] as const
).map(([value, labelKey]) => ({
  value,
  get label() {
    return t(labelKey);
  },
}));
/**
 * Shape of the react-hook-form state behind the retrieval-test panel.
 */
interface TestFormData {
  /** Free-text query to run against the knowledge base (required in the form). */
  question: string;

  /** Slider value, edited in the UI between 0 and 1 in steps of 0.1. */
  similarity_threshold: number;
  /** Slider value, edited in the UI between 0 and 1 in steps of 0.1. */
  vector_similarity_weight: number;

  /** Selected rerank model; the empty string is the "no rerank" menu entry. */
  rerank_id?: string;
  /** Result count shown in the form only while a rerank model is selected. */
  top_k?: number;

  /** Documents the retrieval is restricted to. */
  doc_ids?: Array<string>;
  /** Language codes chosen for cross-language search. */
  cross_languages?: Array<string>;

  /** Whether the knowledge-graph toggle is on. */
  use_kg?: boolean;
}
/**
 * Builds the body for `knowledgeService.retrievalTest` from the form values.
 *
 * Optional parameters are only attached when they carry a usable value, so an
 * unset rerank model, an empty language list or an empty document filter is
 * omitted from the request entirely. Highlighting is always requested so that
 * every page of results is rendered the same way.
 *
 * @param kbId - ID of the knowledge base under test.
 * @param data - Current form values.
 * @param page - 1-based page of results to request.
 * @param size - Number of chunks per page.
 * @param docIds - Documents to restrict retrieval to; an empty list means no filter.
 * @returns The request body.
 */
function buildRetrievalRequest(
  kbId: string,
  data: TestFormData,
  page: number,
  size: number,
  docIds: string[],
): ITestRetrievalRequestBody {
  const requestBody: ITestRetrievalRequestBody = {
    question: data.question,
    similarity_threshold: data.similarity_threshold,
    vector_similarity_weight: data.vector_similarity_weight,
    kb_id: kbId,
    page,
    size,
    highlight: true,
  };

  if (data.rerank_id) {
    requestBody.rerank_id = data.rerank_id;
  }

  // The Top-K <input type="number"> is registered without `valueAsNumber`, so
  // once edited its value is a string; coerce so the API always gets a number.
  const topK = Number(data.top_k);
  if (Number.isFinite(topK) && topK > 0) {
    requestBody.top_k = topK;
  }

  if (data.use_kg !== undefined) {
    requestBody.use_kg = data.use_kg;
  }

  if (docIds.length > 0) {
    requestBody.doc_ids = docIds;
  }

  if (data.cross_languages && data.cross_languages.length > 0) {
    requestBody.cross_languages = data.cross_languages;
  }

  return requestBody;
}

/**
 * Retrieval-testing page for the knowledge base identified by the route `:id`.
 *
 * The left panel holds the retrieval parameters (similarity threshold, vector
 * similarity weight, rerank model, Top-K, cross-language search, knowledge
 * graph); the right panel shows paginated chunk results, which can be filtered
 * by document.
 */
function KnowledgeBaseTesting() {
  const { t } = useTranslation();
  const { id } = useParams<{ id: string }>();

  // Page state
  const [testResult, setTestResult] = useState<INextTestingResult | null>(null);
  const [testing, setTesting] = useState(false);
  const [page, setPage] = useState(1);
  const [pageSize] = useState(10);
  const [selectedDocIds, setSelectedDocIds] = useState<string[]>([]);
  // Sequence number of the most recently started retrieval request. Replies to
  // older requests are dropped so a slow response cannot overwrite newer results
  // or clear the loading flag of a request that is still in flight.
  const latestRequestRef = React.useRef(0);
  const { showMessage } = useSnackbar();

  // Knowledge base details (used for the breadcrumb and subtitle)
  const { knowledge: knowledgeDetail, loading: detailLoading } = useKnowledgeDetail(id || '');
  // Rerank model options for the rerank selector
  const { options: rerankOptions, loading: rerankLoading } = useRerankModelOptions();

  // Form setup
  const {
    control,
    handleSubmit,
    watch,
    register,
    setValue,
    getValues,
    formState: { errors },
  } = useForm<TestFormData>({
    defaultValues: {
      question: '',
      similarity_threshold: 0.2,
      vector_similarity_weight: 0.3,
      rerank_id: '',
      top_k: 1024,
      use_kg: false,
      cross_languages: [],
      doc_ids: [],
    },
  });

  /**
   * Sends one retrieval request and, if it is still the latest request when it
   * resolves, stores its results together with the page they belong to.
   *
   * @param data - Form values to build the request from.
   * @param targetPage - Page to request; only committed to state on success.
   * @param docIds - Document filter to apply.
   * @param messages - Already-translated failure text and optional success text.
   */
  const runRetrieval = async (
    data: TestFormData,
    targetPage: number,
    docIds: string[],
    messages: { failure: string; success?: string },
  ): Promise<void> => {
    if (!id) return;

    latestRequestRef.current += 1;
    const requestId = latestRequestRef.current;
    setTesting(true);

    try {
      const response = await knowledgeService.retrievalTest(
        buildRetrievalRequest(id, data, targetPage, pageSize, docIds),
      );
      if (requestId !== latestRequestRef.current) return; // superseded by a newer request

      if (response.data.code !== 0) {
        throw new Error(response.data.message || messages.failure);
      }

      setTestResult(response.data.data);
      setPage(targetPage);
      if (messages.success) {
        showMessage.success(messages.success);
      }
    } catch (error: unknown) {
      if (requestId !== latestRequestRef.current) return;
      const message =
        error instanceof Error && error.message ? error.message : messages.failure;
      showMessage.error(message);
    } finally {
      if (requestId === latestRequestRef.current) {
        setTesting(false);
      }
    }
  };

  // Form submit: a new test always starts from the first page.
  const handleTestSubmit = (data: TestFormData) =>
    runRetrieval(data, 1, selectedDocIds, {
      failure: t('knowledgeTesting.retrievalTestFailed'),
      success: t('knowledgeTesting.retrievalTestComplete'),
    });

  // Pagination: re-run the same query for the requested page.
  const handlePageChange = (_event: React.ChangeEvent<unknown>, value: number) => {
    void runRetrieval(getValues(), value, selectedDocIds, {
      failure: t('knowledgeTesting.paginationRequestFailed'),
    });
  };

  // Document filter: remember the selection and re-run the test from page 1.
  const handleDocumentFilter = (docIds: string[]) => {
    setSelectedDocIds(docIds);
    setValue('doc_ids', docIds);
    void runRetrieval(getValues(), 1, docIds, {
      failure: t('knowledgeTesting.retrievalTestFailed'),
      success: t('knowledgeTesting.retrievalTestComplete'),
    });
  };

  if (detailLoading) {
    return (
      <Container maxWidth="lg" sx={{ py: 4 }}>
        <Typography>{t('common.loading')}</Typography>
      </Container>
    );
  }

  return (
    <Container maxWidth="lg" sx={{ py: 4 }}>
      {/* Breadcrumb navigation */}
      <KnowledgeBreadcrumbs
        kbItems={[
          {
            label: t('knowledgeTesting.knowledgeBase'),
            path: '/knowledge',
          },
          {
            label: knowledgeDetail?.name || t('knowledgeTesting.knowledgeBaseDetail'),
            path: `/knowledge/${id}`,
          },
          {
            label: t('knowledgeTesting.testing'),
          },
        ]}
      />

      <Box sx={{ mb: 3 }}>
        <Typography variant="h4" gutterBottom>
          {t('knowledgeTesting.knowledgeBaseTesting')}
        </Typography>
        <Typography variant="subtitle1" color="text.secondary">
          {knowledgeDetail?.name}
        </Typography>
      </Box>

      <Grid container spacing={3} direction="row">
        {/* Test configuration form */}
        <Grid size={4}>
          <Paper sx={{ p: 3, position: 'sticky', top: 20 }}>
            <Typography variant="h6" gutterBottom>
              {t('knowledgeTesting.testConfiguration')}
            </Typography>

            <Box component="form" onSubmit={handleSubmit(handleTestSubmit)} sx={{ mt: 2 }}>
              <TextField
                {...register('question', { required: t('knowledgeTesting.pleaseEnterTestQuestion') })}
                label={t('knowledgeTesting.testQuestion')}
                multiline
                rows={3}
                fullWidth
                margin="normal"
                error={!!errors.question}
                helperText={errors.question?.message}
                placeholder={t('knowledgeTesting.testQuestionPlaceholder')}
              />

              <Box sx={{ mt: 3 }}>
                <Typography gutterBottom>
                  {t('knowledgeTesting.similarityThreshold')}: {watch('similarity_threshold')}
                </Typography>
                <Controller
                  name="similarity_threshold"
                  control={control}
                  render={({ field }) => (
                    <Slider
                      name={field.name}
                      value={field.value}
                      onChange={(_, value) => field.onChange(value as number)}
                      onBlur={field.onBlur}
                      min={0}
                      max={1}
                      step={0.1}
                      marks
                      valueLabelDisplay="auto"
                    />
                  )}
                />
              </Box>

              <Box sx={{ mt: 3 }}>
                <Typography gutterBottom>
                  {t('knowledgeTesting.vectorSimilarityWeight')}: {watch('vector_similarity_weight')}
                </Typography>
                <Controller
                  name="vector_similarity_weight"
                  control={control}
                  render={({ field }) => (
                    <Slider
                      name={field.name}
                      value={field.value}
                      onChange={(_, value) => field.onChange(value as number)}
                      onBlur={field.onBlur}
                      min={0}
                      max={1}
                      step={0.1}
                      marks
                      valueLabelDisplay="auto"
                    />
                  )}
                />
              </Box>

              <FormControl fullWidth margin="normal">
                <InputLabel>{t('knowledgeTesting.rerankModel')}</InputLabel>
                <Controller
                  name="rerank_id"
                  control={control}
                  render={({ field }) => (
                    <Select
                      {...field}
                      label={t('knowledgeTesting.rerankModel')}
                      disabled={rerankLoading}
                    >
                      <MenuItem value="">
                        <em>{t('knowledgeTesting.noRerank')}</em>
                      </MenuItem>
                      {rerankOptions.map((group) => [
                        <ListSubheader key={group.label}>{group.label}</ListSubheader>,
                        ...group.options.map((option) => (
                          <MenuItem key={option.value} value={option.value} disabled={option.disabled}>
                            {option.label}
                          </MenuItem>
                        )),
                      ])}
                    </Select>
                  )}
                />
              </FormControl>

              {/* Top-K is only shown while a rerank model is selected */}
              {watch('rerank_id') && (
                <TextField
                  {...register('top_k', {
                    required: t('knowledgeTesting.pleaseEnterResultCount'),
                    min: { value: 1, message: t('knowledgeTesting.minValue1') },
                    max: { value: 2048, message: t('knowledgeTesting.maxValue2048') },
                  })}
                  label="Top-K"
                  type="number"
                  fullWidth
                  margin="normal"
                  inputProps={{ min: 1, max: 2048 }}
                  error={!!errors.top_k}
                  helperText={errors.top_k?.message || t('knowledgeTesting.useWithRerankModel')}
                />
              )}

              <FormControl fullWidth margin="normal">
                <InputLabel>{t('knowledgeTesting.crossLanguageSearch')}</InputLabel>
                <Controller
                  name="cross_languages"
                  control={control}
                  render={({ field }) => (
                    <Select
                      {...field}
                      multiple
                      label={t('knowledgeTesting.crossLanguageSearch')}
                      input={<OutlinedInput label={t('knowledgeTesting.crossLanguageSearch')} />}
                      renderValue={(selected) => (
                        <Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.5 }}>
                          {selected.map((value) => {
                            const option = options.find((opt) => opt.value === value);
                            return <Chip key={value} label={option?.label || value} size="small" />;
                          })}
                        </Box>
                      )}
                    >
                      {options.map((option) => (
                        <MenuItem key={option.value} value={option.value}>
                          <Checkbox checked={(field.value ?? []).includes(option.value)} />
                          <ListItemText primary={option.label} />
                        </MenuItem>
                      ))}
                    </Select>
                  )}
                />
              </FormControl>

              <Controller
                name="use_kg"
                control={control}
                render={({ field }) => (
                  <FormControlLabel
                    control={
                      <Switch
                        name={field.name}
                        checked={!!field.value}
                        onChange={(_, checked) => field.onChange(checked)}
                        onBlur={field.onBlur}
                        inputRef={field.ref}
                      />
                    }
                    label={t('knowledgeTesting.useKnowledgeGraph')}
                  />
                )}
              />

              <Button
                type="submit"
                variant="contained"
                fullWidth
                disabled={testing}
                sx={{ mt: 2 }}
              >
                {testing ? t('knowledgeTesting.testing') : t('knowledgeTesting.startTest')}
              </Button>
            </Box>
          </Paper>
        </Grid>

        <Grid size={8}>
          {testResult && (
            <TestChunkResult
              result={testResult}
              loading={testing}
              page={page}
              pageSize={pageSize}
              onDocumentFilter={handleDocumentFilter}
              selectedDocIds={selectedDocIds}
            />
          )}

          {/* Pagination — page count derived from the same pageSize sent to the API */}
          {testResult && testResult.total > pageSize && (
            <Box sx={{ display: 'flex', justifyContent: 'center', mt: 3 }}>
              <Pagination
                count={Math.ceil(testResult.total / pageSize)}
                page={page}
                onChange={handlePageChange}
                color="primary"
                size="large"
                showFirstButton
                showLastButton
              />
            </Box>
          )}
        </Grid>
      </Grid>
    </Container>
  );
}
export default KnowledgeBaseTesting;