first commit
This commit is contained in:
5
rag_eval/dataset_builder/generator/__init__.py
Normal file
5
rag_eval/dataset_builder/generator/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Question generation components for draft online datasets."""
|
||||
|
||||
from .question_generator import OpenAIQuestionGenerator, QuestionGenerator
|
||||
|
||||
__all__ = ["OpenAIQuestionGenerator", "QuestionGenerator"]
|
||||
173
rag_eval/dataset_builder/generator/question_generator.py
Normal file
173
rag_eval/dataset_builder/generator/question_generator.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""LLM-backed question generator for dataset build jobs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from rag_eval.dataset_builder.models import DraftQuestionSample, ParsedDocument, SourceChunk
|
||||
from rag_eval.settings import EvaluationSettings
|
||||
|
||||
|
||||
class QuestionGenerator(ABC):
|
||||
"""Abstract interface for generating draft questions from parsed documents."""
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
document: ParsedDocument,
|
||||
*,
|
||||
max_questions: int,
|
||||
max_chunks_per_question: int,
|
||||
job_name: str,
|
||||
) -> list[DraftQuestionSample]:
|
||||
"""Generate draft question samples for one parsed document."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAIQuestionGenerator(QuestionGenerator):
|
||||
"""Generate draft questions with an OpenAI-compatible chat completion API."""
|
||||
|
||||
def __init__(self, settings: EvaluationSettings, model: str, client: OpenAI | None = None):
|
||||
"""Initialize the OpenAI-compatible client and target generation model."""
|
||||
if not settings.openai_api_key:
|
||||
raise EnvironmentError("OPENAI_API_KEY must be set before generating draft questions.")
|
||||
self.client = client or OpenAI(**settings.openai_client_kwargs)
|
||||
self.model = model
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
document: ParsedDocument,
|
||||
*,
|
||||
max_questions: int,
|
||||
max_chunks_per_question: int,
|
||||
) -> str:
|
||||
"""Build a constrained JSON-generation prompt for one document."""
|
||||
chunk_lines: list[str] = []
|
||||
for chunk in document.source_chunks:
|
||||
chunk_lines.append(
|
||||
json.dumps(
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"section_path": chunk.section_path,
|
||||
"page_start": chunk.page_start,
|
||||
"page_end": chunk.page_end,
|
||||
"text": chunk.text,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
instructions = {
|
||||
"task": "Generate reviewable online evaluation draft questions from one document only.",
|
||||
"rules": [
|
||||
"Return JSON only.",
|
||||
f"Generate at most {max_questions} samples.",
|
||||
f"Each sample may cite at most {max_chunks_per_question} chunk ids.",
|
||||
"Every sample must stay within this document and use existing chunk ids only.",
|
||||
"Allowed question_type values: fact, summary, procedure, comparison.",
|
||||
"Allowed difficulty values: easy, medium, hard.",
|
||||
],
|
||||
"output_schema": {
|
||||
"samples": [
|
||||
{
|
||||
"question": "string",
|
||||
"ground_truth": "string",
|
||||
"source_chunk_ids": ["chunk-id"],
|
||||
"question_type": "fact|summary|procedure|comparison",
|
||||
"difficulty": "easy|medium|hard",
|
||||
}
|
||||
]
|
||||
},
|
||||
"document": {
|
||||
"doc_id": document.doc_id,
|
||||
"doc_name": document.doc_name,
|
||||
"chunks": chunk_lines,
|
||||
},
|
||||
}
|
||||
return json.dumps(instructions, ensure_ascii=False, indent=2)
|
||||
|
||||
def _build_sample(
|
||||
self,
|
||||
*,
|
||||
document: ParsedDocument,
|
||||
payload: dict[str, Any],
|
||||
index: int,
|
||||
job_name: str,
|
||||
) -> DraftQuestionSample:
|
||||
"""Convert one model output object into the internal draft sample model."""
|
||||
chunk_lookup: dict[str, SourceChunk] = {item.chunk_id: item for item in document.source_chunks}
|
||||
source_chunk_ids = [str(item).strip() for item in payload.get("source_chunk_ids") or [] if str(item).strip()]
|
||||
chunks = [chunk_lookup[item] for item in source_chunk_ids if item in chunk_lookup]
|
||||
|
||||
section_path = chunks[0].section_path if chunks else ""
|
||||
page_start = min((chunk.page_start for chunk in chunks), default=0)
|
||||
page_end = max((chunk.page_end for chunk in chunks), default=0)
|
||||
language = "zh" if any("\u4e00" <= char <= "\u9fff" for char in payload.get("question", "")) else "en"
|
||||
return DraftQuestionSample(
|
||||
sample_id=f"{document.doc_id}-q{index}",
|
||||
question=str(payload.get("question", "")).strip(),
|
||||
ground_truth=str(payload.get("ground_truth", "")).strip(),
|
||||
scenario=job_name,
|
||||
language=language,
|
||||
doc_id=document.doc_id,
|
||||
doc_name=document.doc_name,
|
||||
section_path=section_path,
|
||||
page_start=page_start,
|
||||
page_end=page_end,
|
||||
source_chunk_ids=source_chunk_ids,
|
||||
question_type=str(payload.get("question_type", "fact")).strip() or "fact",
|
||||
difficulty=str(payload.get("difficulty", "medium")).strip() or "medium",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_response_payload(content: str) -> list[dict[str, Any]]:
|
||||
"""Parse the model response into a list of sample payload dictionaries."""
|
||||
try:
|
||||
payload = json.loads(content or "{}")
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("Question generator returned invalid JSON.") from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Question generator response must be a JSON object.")
|
||||
samples = payload.get("samples") or []
|
||||
if not isinstance(samples, list):
|
||||
raise ValueError("Question generator response field 'samples' must be a list.")
|
||||
|
||||
normalized_samples: list[dict[str, Any]] = []
|
||||
for item in samples:
|
||||
if isinstance(item, dict):
|
||||
normalized_samples.append(item)
|
||||
return normalized_samples
|
||||
|
||||
def generate(
|
||||
self,
|
||||
document: ParsedDocument,
|
||||
*,
|
||||
max_questions: int,
|
||||
max_chunks_per_question: int,
|
||||
job_name: str,
|
||||
) -> list[DraftQuestionSample]:
|
||||
"""Generate draft questions for one parsed document."""
|
||||
prompt = self._build_prompt(
|
||||
document,
|
||||
max_questions=max_questions,
|
||||
max_chunks_per_question=max_chunks_per_question,
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You generate structured draft question banks from source documents."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
content = response.choices[0].message.content or "{}"
|
||||
payload = self._parse_response_payload(content)
|
||||
return [
|
||||
self._build_sample(document=document, payload=item, index=index, job_name=job_name)
|
||||
for index, item in enumerate(payload[:max_questions], start=1)
|
||||
]
|
||||
87
rag_eval/dataset_builder/generator/validators.py
Normal file
87
rag_eval/dataset_builder/generator/validators.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Validation and deduplication helpers for generated draft question samples."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from rag_eval.dataset_builder.models import DraftQuestionSample, ParsedDocument
|
||||
|
||||
|
||||
ALLOWED_QUESTION_TYPES = {"fact", "summary", "procedure", "comparison"}
|
||||
ALLOWED_DIFFICULTIES = {"easy", "medium", "hard"}
|
||||
|
||||
|
||||
def validate_draft_sample(
|
||||
sample: DraftQuestionSample,
|
||||
*,
|
||||
document: ParsedDocument,
|
||||
max_source_chunks_per_question: int | None = None,
|
||||
) -> list[str]:
|
||||
"""Validate one generated sample against the document and enum constraints."""
|
||||
errors: list[str] = []
|
||||
if not sample.question.strip():
|
||||
errors.append("question is empty")
|
||||
if not sample.ground_truth.strip():
|
||||
errors.append("ground_truth is empty")
|
||||
if not sample.source_chunk_ids:
|
||||
errors.append("source_chunk_ids is empty")
|
||||
if (
|
||||
max_source_chunks_per_question is not None
|
||||
and len(sample.source_chunk_ids) > max_source_chunks_per_question
|
||||
):
|
||||
errors.append(
|
||||
f"source_chunk_ids exceeds limit: {len(sample.source_chunk_ids)} > {max_source_chunks_per_question}"
|
||||
)
|
||||
|
||||
existing_chunk_ids = {chunk.chunk_id for chunk in document.source_chunks}
|
||||
for chunk_id in sample.source_chunk_ids:
|
||||
if chunk_id not in existing_chunk_ids:
|
||||
errors.append(f"unknown source chunk: {chunk_id}")
|
||||
|
||||
if sample.doc_id != document.doc_id:
|
||||
errors.append("sample doc_id does not match source document")
|
||||
if sample.question_type not in ALLOWED_QUESTION_TYPES:
|
||||
errors.append(f"unsupported question_type: {sample.question_type}")
|
||||
if sample.difficulty not in ALLOWED_DIFFICULTIES:
|
||||
errors.append(f"unsupported difficulty: {sample.difficulty}")
|
||||
return errors
|
||||
|
||||
|
||||
def normalize_question_text(text: str) -> str:
|
||||
"""Normalize question text for exact-match deduplication."""
|
||||
return re.sub(r"\s+", " ", text).strip().lower()
|
||||
|
||||
|
||||
def dedupe_samples(samples: list[DraftQuestionSample]) -> list[DraftQuestionSample]:
|
||||
"""Drop duplicate questions and enforce one output per chunk group per document."""
|
||||
deduped: list[DraftQuestionSample] = []
|
||||
seen_questions: set[tuple[str, str]] = set()
|
||||
seen_chunk_groups: set[tuple[str, tuple[str, ...]]] = set()
|
||||
seen_chunk_answers: list[tuple[str, tuple[str, ...], str]] = []
|
||||
|
||||
for sample in samples:
|
||||
question_key = (sample.doc_id, normalize_question_text(sample.question))
|
||||
if question_key in seen_questions:
|
||||
continue
|
||||
|
||||
chunk_key = tuple(sample.source_chunk_ids)
|
||||
chunk_group_key = (sample.doc_id, chunk_key)
|
||||
if chunk_group_key in seen_chunk_groups:
|
||||
continue
|
||||
answer_key = normalize_question_text(sample.ground_truth)
|
||||
duplicate = False
|
||||
for existing_doc_id, existing_chunk_key, existing_answer in seen_chunk_answers:
|
||||
if existing_doc_id != sample.doc_id or existing_chunk_key != chunk_key:
|
||||
continue
|
||||
if SequenceMatcher(None, existing_answer, answer_key).ratio() >= 0.9:
|
||||
duplicate = True
|
||||
break
|
||||
if duplicate:
|
||||
continue
|
||||
|
||||
seen_questions.add(question_key)
|
||||
seen_chunk_groups.add(chunk_group_key)
|
||||
seen_chunk_answers.append((sample.doc_id, chunk_key, answer_key))
|
||||
deduped.append(sample)
|
||||
return deduped
|
||||
Reference in New Issue
Block a user