171 lines
6.1 KiB
Python
171 lines
6.1 KiB
Python
|
|
"""Online evaluation adapter for the Siemens medical-imaging PDF question bank.
|
|||
|
|
|
|||
|
|
Mirrors apps/pdf_question_bank/adapter.py except that it uses a
Siemens-specific system prompt that:

- Instructs the model to answer in the same language as the question
  (important for Chinese CT documentation).
- Emphasises citing source chunks and refusing to answer when evidence is absent.
- Adds domain context (medical imaging / CT terminology).

The adapter contract is the same as for all other adapters:

    run(question, **kwargs) -> {"answer": str, "contexts": [str], "raw_response": {...}}
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from openai import OpenAI
|
|||
|
|
|
|||
|
|
from rag_eval.settings import EvaluationSettings
|
|||
|
|
from rag_eval.shared.utils import parse_contexts
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── chunk cache (module-level, lives for the process lifetime) ────────────────
# Maps a resolved artifact path to its parsed lookup of chunk_id -> row payload.
# Nothing in this module evicts or invalidates entries, so a file that is edited
# after its first load keeps serving the old contents until the process
# restarts. It is a plain dict with no locking around reads/writes.
_CHUNK_CACHE: dict[Path, dict[str, dict[str, Any]]] = {}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_source_chunks_path(source_chunks_path: str) -> Path:
    """Resolve the chunk artifact path; fall back to the latest timestamped run.

    If the requested file exists, its fully resolved path is returned. Otherwise,
    when the file sits directly inside a directory named ``latest``, the sibling
    run directories of ``latest`` are searched in descending name order (so
    timestamp-named runs are tried newest first) and the first one containing a
    file with the same name is returned.

    Args:
        source_chunks_path: Path to the chunk artifact, e.g.
            ``<artifact_root>/latest/source_chunks.jsonl``.

    Returns:
        Path of an existing artifact file.

    Raises:
        FileNotFoundError: the file is missing and either the path is not under a
            ``latest`` directory or no run directory contains the file.
    """
    requested = Path(source_chunks_path).absolute()
    resolved = requested.resolve()
    if resolved.exists():
        return resolved

    # Decide on the *unresolved* path first: when ``latest`` is a symlink,
    # ``resolve()`` replaces it with the target run's name, which previously
    # made the fallback unreachable. The resolved form is still accepted so any
    # path that used to trigger the fallback keeps doing so.
    if requested.parent.name == "latest":
        artifact_root = requested.parent.parent.resolve()
    elif resolved.parent.name == "latest":
        artifact_root = resolved.parent.parent
    else:
        raise FileNotFoundError(resolved)
    if not artifact_root.is_dir():
        raise FileNotFoundError(resolved)

    wanted_name = requested.name
    run_dirs = sorted(
        (d for d in artifact_root.iterdir() if d.is_dir() and d.name != "latest"),
        key=lambda p: p.name,
        reverse=True,  # lexicographically largest (newest timestamp) first
    )
    for run_dir in run_dirs:
        candidate = run_dir / wanted_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(resolved)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_source_chunks(source_chunks_path: str) -> dict[str, dict[str, Any]]:
    """Load and cache source chunks by chunk_id.

    The artifact is read as JSON Lines: every non-blank line must be a JSON
    object with a non-empty ``chunk_id``. The resulting ``chunk_id -> row`` map
    is stored in ``_CHUNK_CACHE`` under the resolved path, so repeated calls for
    the same file return the same dict without re-reading it. A later row whose
    ``chunk_id`` repeats an earlier one replaces it.

    Args:
        source_chunks_path: Artifact path, resolved via
            ``_resolve_source_chunks_path``.

    Returns:
        Mapping of stripped ``chunk_id`` string to the full row payload.

    Raises:
        FileNotFoundError: the artifact cannot be located.
        ValueError: a row is not a JSON object or has no usable ``chunk_id``;
            malformed JSON raises ``json.JSONDecodeError`` (a ValueError subclass).
    """
    resolved = _resolve_source_chunks_path(source_chunks_path)
    cached = _CHUNK_CACHE.get(resolved)
    if cached is not None:
        return cached

    lookup: dict[str, dict[str, Any]] = {}
    # utf-8-sig transparently skips a leading BOM, which json.loads would
    # otherwise reject on the first row; it is identical to utf-8 without one.
    with resolved.open(encoding="utf-8-sig") as fh:
        for lineno, line in enumerate(fh, 1):
            text = line.strip()
            if not text:
                continue
            payload = json.loads(text)
            if not isinstance(payload, dict):
                raise ValueError(
                    f"source_chunks.jsonl row {lineno} is not a JSON object: {resolved}"
                )
            raw_id = payload.get("chunk_id")
            # JSON null counts as missing; str(None) would silently yield "None".
            chunk_id = "" if raw_id is None else str(raw_id).strip()
            if not chunk_id:
                raise ValueError(f"source_chunks.jsonl row {lineno} missing chunk_id: {resolved}")
            lookup[chunk_id] = payload
    _CHUNK_CACHE[resolved] = lookup
    return lookup
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_chunk_ids(raw: Any) -> list[str]:
    """Parse the source_chunk_ids column into a list of non-empty id strings.

    Each parsed entry is coerced to ``str`` and stripped, matching the
    ``str(...).strip()`` normalisation ``_load_source_chunks`` applies to the
    ``chunk_id`` keys it indexes by. Ids with stray whitespace or a non-string
    type (e.g. ``42``) therefore still find their chunk. ``None`` and blank
    entries are dropped; order and duplicates are preserved.

    Args:
        raw: The raw ``source_chunk_ids`` value, in whatever form
            ``parse_contexts`` accepts.

    Returns:
        The normalised id list (never empty).

    Raises:
        ValueError: no usable id remains.
    """
    normalized: list[str] = []
    for item in parse_contexts(raw):
        if item is None:
            continue
        chunk_id = str(item).strip()
        if chunk_id:
            normalized.append(chunk_id)
    if not normalized:
        raise ValueError("source_chunk_ids is required for siemens_pdf_qa adapter.")
    return normalized
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_messages(
    question: str,
    contexts: list[str],
    metadata: dict[str, Any],
) -> list[dict[str, str]]:
    """Assemble the two-message chat prompt (system + user) for the answer model.

    The system message (written in Chinese and sent verbatim) does the following:
    - Casts the model as the Siemens Healthineers CT knowledge-base QA assistant.
    - Tells it to answer strictly from the evidence chunks.
    - When evidence is insufficient, tells it to state explicitly that the
      question cannot be answered from the available material.
    - Asks for a concise reply in the question's language (Chinese or English),
      citing chunk numbers where needed.

    The user message contains, one item per line:
    - The question.
    - The document metadata: ``doc_name``, ``section_path`` and
      ``page_start``–``page_end`` taken from *metadata*. Absent keys render
      as empty strings.
    - The evidence chunks, labelled ``[chunk 1]``, ``[chunk 2]``, ... in
      *contexts* order.
    """
    system_parts = (
        "你是西门子医疗影像知识库的问答助手(Siemens Healthineers CT Knowledge Base QA)。",
        "请严格根据下方【证据片段】回答问题,不得使用片段之外的任何知识。",
        "若证据不足以回答,请明确说明「根据现有资料无法回答」。",
        "请用与问题相同的语言(中文或英文)作答,简洁准确,必要时引用片段编号。",
    )
    system_prompt = "".join(system_parts)

    def meta(key: str) -> Any:
        return metadata.get(key, "")

    user_lines = ["【问题】", question, "", "【文档元信息】"]
    user_lines.append(f"doc_name: {meta('doc_name')}")
    user_lines.append(f"section_path: {meta('section_path')}")
    user_lines.append(f"page_range: {meta('page_start')}–{meta('page_end')}")
    user_lines.extend(("", "【证据片段】"))
    user_lines.extend(
        f"[chunk {number}] {chunk}" for number, chunk in enumerate(contexts, start=1)
    )
    user_lines.extend(("", "请基于以上证据片段作答。"))

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "\n".join(user_lines)},
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run(
    question: str,
    *,
    source_chunks_path: str,
    model: str | None = None,
    client: OpenAI | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Answer one question by resolving cited chunks and calling an OpenAI-compatible model.

    This is the adapter contract entry point used by the online evaluation runner.

    Args:
        question: Question text, placed verbatim in the user prompt.
        source_chunks_path: Path to the chunk artifact (see
            ``_resolve_source_chunks_path`` for the ``latest`` fallback).
        model: Answer model name. Defaults to
            ``EvaluationSettings().ragas_judge_model``.
        client: Pre-built OpenAI client. When omitted, one is created from
            ``EvaluationSettings().openai_client_kwargs``.
        **kwargs: Row fields.
            - ``source_chunk_ids`` is required.
            - ``doc_name``, ``section_path``, ``page_start`` and ``page_end``
              fill the prompt's metadata section.
            - ``doc_id`` and ``doc_name`` are echoed in ``raw_response``.

    Returns:
        ``{"answer", "contexts", "raw_response"}``.
        - ``contexts`` holds the non-empty chunk texts in the order the ids
          were given.
        - ``raw_response`` records the requested ids (``resolved_chunk_ids``).
        - It also records the ids that actually produced a context
          (``context_chunk_ids``), aligned with ``contexts`` and with the
          prompt's ``[chunk N]`` labels.
        - It also carries ``doc_id``, ``doc_name``, the model name and the
          answer text.

    Raises:
        ValueError: no ids, unknown ids, no chunk with usable text, or no model name.
        FileNotFoundError: the chunk artifact cannot be located.
    """
    chunk_ids = _resolve_chunk_ids(kwargs.get("source_chunk_ids"))
    chunk_lookup = _load_source_chunks(source_chunks_path)

    missing = [cid for cid in chunk_ids if cid not in chunk_lookup]
    if missing:
        raise ValueError("source_chunk_ids not found in artifact: " + ", ".join(missing))

    contexts: list[str] = []
    context_chunk_ids: list[str] = []
    for cid in chunk_ids:
        raw_text = chunk_lookup[cid].get("text")
        # JSON null means "no text": str(None) would inject the literal "None"
        # into the evidence the model is told to rely on exclusively.
        text = "" if raw_text is None else str(raw_text).strip()
        if text:
            contexts.append(text)
            context_chunk_ids.append(cid)
    if not contexts:
        raise ValueError("resolved source chunks contain no usable text.")

    settings = EvaluationSettings()
    # ``or ""`` routes an unset judge model to the explicit ValueError below
    # instead of an AttributeError from ``None.strip()``.
    target_model = (model or settings.ragas_judge_model or "").strip()
    if not target_model:
        raise ValueError("A model name is required for siemens_pdf_qa adapter.")

    llm_client = client or OpenAI(**settings.openai_client_kwargs)
    completion = llm_client.chat.completions.create(
        model=target_model,
        messages=_build_messages(question, contexts, kwargs),
        temperature=0,  # minimise sampling variance between evaluation runs
    )
    answer = str(completion.choices[0].message.content or "").strip()

    return {
        "answer": answer,
        "contexts": contexts,
        "raw_response": {
            "resolved_chunk_ids": chunk_ids,
            "context_chunk_ids": context_chunk_ids,
            "doc_id": kwargs.get("doc_id", ""),
            "doc_name": kwargs.get("doc_name", ""),
            "model": target_model,
            "response_text": answer,
        },
    }
|