Files

171 lines
6.1 KiB
Python
Raw Permalink Normal View History

"""Online evaluation adapter for the Siemens medical-imaging PDF question bank.
Functionally identical to apps/pdf_question_bank/adapter.py, but uses a
Siemens-specific system prompt that:

- Instructs the model to answer in the same language as the question
  (important for Chinese CT documentation).
- Emphasises citation of source chunks and refusal when evidence is absent.
- Adds domain context (medical imaging / CT terminology).

The adapter contract is the same as for all other adapters:

    run(question, **kwargs) -> {"answer": str, "contexts": [str], "raw_response": {}}
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from openai import OpenAI
from rag_eval.settings import EvaluationSettings
from rag_eval.shared.utils import parse_contexts
# ── chunk cache (module-level, lives for the process lifetime) ────────────────
# Maps a resolved artifact path to its {chunk_id: row payload} lookup. It is
# filled by _load_source_chunks, which returns the stored lookup on later calls
# for the same resolved path instead of re-reading the file.
_CHUNK_CACHE: dict[Path, dict[str, dict[str, Any]]] = {}
def _resolve_source_chunks_path(source_chunks_path: str) -> Path:
    """Return the absolute path of the chunk artifact to read.

    If the given path exists it is returned (resolved) as-is. Otherwise, when
    the file's parent directory is named ``latest``, the sibling directories
    of ``latest`` are scanned in descending name order and the first one that
    contains a file with the same name wins.

    Args:
        source_chunks_path: Path to the chunk artifact, possibly under a
            ``latest`` directory that does not exist.

    Returns:
        The resolved path of an existing artifact file.

    Raises:
        FileNotFoundError: If the path does not exist and no fallback applies
            or matches. The exception argument is always the resolved form of
            the requested path.
    """
    target = Path(source_chunks_path).resolve()
    if target.exists():
        return target

    # The fallback only applies to ".../<root>/latest/<file>" layouts.
    runs_root = target.parent.parent
    if target.parent.name != "latest" or not runs_root.exists():
        raise FileNotFoundError(target)

    run_dirs = [
        entry
        for entry in runs_root.iterdir()
        if entry.is_dir() and entry.name != "latest"
    ]
    # Descending name order: with timestamp-named runs the newest comes first.
    run_dirs.sort(key=lambda entry: entry.name, reverse=True)

    fallbacks = (run_dir / target.name for run_dir in run_dirs)
    found = next((path for path in fallbacks if path.exists()), None)
    if found is None:
        raise FileNotFoundError(target)
    return found
def _load_source_chunks(source_chunks_path: str) -> dict[str, dict[str, Any]]:
    """Load the JSONL chunk artifact and index its rows by ``chunk_id``.

    The path is resolved with ``_resolve_source_chunks_path``. A successfully
    parsed lookup is stored in ``_CHUNK_CACHE`` under the resolved path and
    returned directly on later calls for the same path.

    Args:
        source_chunks_path: Path to a JSONL file with one JSON object per
            non-blank line; each object must carry a ``chunk_id``.

    Returns:
        Mapping of stripped ``chunk_id`` string to the full row payload. If
        an id occurs more than once, the last row wins.

    Raises:
        FileNotFoundError: Propagated from path resolution.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        ValueError: If a row is not a JSON object, or its ``chunk_id`` is
            absent, null, or blank.
    """
    resolved = _resolve_source_chunks_path(source_chunks_path)
    cached = _CHUNK_CACHE.get(resolved)
    if cached is not None:
        return cached

    lookup: dict[str, dict[str, Any]] = {}
    with resolved.open(encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, 1):
            text = line.strip()
            if not text:
                continue
            payload = json.loads(text)
            # A list/scalar row would otherwise fail on `.get` with an
            # AttributeError that names neither the row nor the file.
            if not isinstance(payload, dict):
                raise ValueError(
                    f"source_chunks.jsonl row {lineno} is not a JSON object: {resolved}"
                )
            # A JSON null must count as missing; str(None) would yield the
            # truthy key "None".
            raw_id = payload.get("chunk_id")
            chunk_id = "" if raw_id is None else str(raw_id).strip()
            if not chunk_id:
                raise ValueError(f"source_chunks.jsonl row {lineno} missing chunk_id: {resolved}")
            lookup[chunk_id] = payload

    # Cache only after the whole file parsed cleanly.
    _CHUNK_CACHE[resolved] = lookup
    return lookup
def _resolve_chunk_ids(raw: Any) -> list[str]:
    """Turn the raw ``source_chunk_ids`` value into a list of usable ids.

    The value is handed to ``parse_contexts`` and falsy entries are dropped
    from its result.

    Args:
        raw: The ``source_chunk_ids`` value as supplied by the caller.

    Returns:
        The truthy ids, in the order ``parse_contexts`` produced them.

    Raises:
        ValueError: If no truthy id remains.
    """
    chunk_ids = [chunk_id for chunk_id in parse_contexts(raw) if chunk_id]
    if chunk_ids:
        return chunk_ids
    raise ValueError("source_chunk_ids is required for siemens_pdf_qa adapter.")
def _build_messages(
    question: str,
    contexts: list[str],
    metadata: dict[str, Any],
) -> list[dict[str, str]]:
    """Build the Siemens-domain grounded chat prompt for the answer model.

    Args:
        question: The question text, placed verbatim in the user message.
        contexts: Evidence texts; rendered as ``[chunk N] <text>`` lines
            numbered from 1 in the given order.
        metadata: Document metadata. Only ``doc_name``, ``section_path``,
            ``page_start`` and ``page_end`` are read; other keys are ignored.

    Returns:
        A two-element message list: the system prompt followed by the user
        prompt (question, document metadata, evidence chunks, instruction).
    """
    evidence_lines = [f"[chunk {i}] {ctx}" for i, ctx in enumerate(contexts, 1)]

    # Join the page bounds with "-" so that e.g. 12 and 15 render as "12-15"
    # rather than the ambiguous "1215"; missing or blank bounds are skipped.
    page_bounds = [
        str(bound).strip()
        for bound in (metadata.get("page_start"), metadata.get("page_end"))
        if bound is not None and str(bound).strip()
    ]
    page_range = "-".join(page_bounds)

    meta_lines = [
        f"doc_name: {metadata.get('doc_name', '')}",
        f"section_path: {metadata.get('section_path', '')}",
        f"page_range: {page_range}",
    ]

    # Siemens-specific system prompt: bilingual awareness, medical domain,
    # strict grounding in the supplied evidence, explicit refusal wording.
    # NOTE(review): the first fragment has no separator between the Chinese
    # role description and the English product name, and no closing
    # punctuation — confirm whether characters were lost. Left unchanged here.
    system_prompt = (
        "你是西门子医疗影像知识库的问答助手Siemens Healthineers CT Knowledge Base QA"
        "请严格根据下方【证据片段】回答问题,不得使用片段之外的任何知识。"
        "若证据不足以回答,请明确说明「根据现有资料无法回答」。"
        "请用与问题相同的语言(中文或英文)作答,简洁准确,必要时引用片段编号。"
    )
    user_prompt = "\n".join([
        "【问题】",
        question,
        "",
        "【文档元信息】",
        *meta_lines,
        "",
        "【证据片段】",
        *evidence_lines,
        "",
        "请基于以上证据片段作答。",
    ])
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
def run(
    question: str,
    *,
    source_chunks_path: str,
    model: str | None = None,
    client: OpenAI | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Answer one question from its cited chunks via an OpenAI-compatible model.

    This is the adapter contract entry point used by the online evaluation
    runner.

    Args:
        question: The question to answer.
        source_chunks_path: Path to the JSONL chunk artifact (see
            ``_resolve_source_chunks_path`` for the ``latest`` fallback).
        model: Model name; when falsy, ``EvaluationSettings().ragas_judge_model``
            is used instead.
        client: Client to call; when ``None`` one is built from
            ``EvaluationSettings().openai_client_kwargs``.
        **kwargs: Row fields. ``source_chunk_ids`` is required; ``doc_id`` and
            ``doc_name`` are echoed in the result; the whole mapping is passed
            to ``_build_messages`` as document metadata.

    Returns:
        ``{"answer": str, "contexts": [str], "raw_response": {...}}`` where
        ``raw_response`` holds ``resolved_chunk_ids``, ``doc_id``,
        ``doc_name``, ``model`` and ``response_text``.

    Raises:
        ValueError: If no chunk ids are given, an id is absent from the
            artifact, the resolved chunks carry no text, or no model name is
            available.
        FileNotFoundError: If the chunk artifact cannot be located.
    """
    chunk_ids = _resolve_chunk_ids(kwargs.get("source_chunk_ids"))
    chunk_lookup = _load_source_chunks(source_chunks_path)

    missing = [cid for cid in chunk_ids if cid not in chunk_lookup]
    if missing:
        raise ValueError("source_chunk_ids not found in artifact: " + ", ".join(missing))

    contexts: list[str] = []
    for cid in chunk_ids:
        raw_text = chunk_lookup[cid].get("text")
        # A null text must count as empty; str(None) would inject the literal
        # "None" into the evidence sent to the model.
        text = "" if raw_text is None else str(raw_text).strip()
        if text:
            contexts.append(text)
    if not contexts:
        raise ValueError("resolved source chunks contain no usable text.")

    settings = EvaluationSettings()
    # `or ""` keeps a missing configured model on the ValueError path below
    # instead of failing on None.strip().
    target_model = (model or settings.ragas_judge_model or "").strip()
    if not target_model:
        raise ValueError("A model name is required for siemens_pdf_qa adapter.")

    llm_client = client or OpenAI(**settings.openai_client_kwargs)
    completion = llm_client.chat.completions.create(
        model=target_model,
        messages=_build_messages(question, contexts, kwargs),
        temperature=0,
    )
    # The message content may be None; normalise to an empty string.
    answer = str(completion.choices[0].message.content or "").strip()

    return {
        "answer": answer,
        "contexts": contexts,
        "raw_response": {
            "resolved_chunk_ids": chunk_ids,
            "doc_id": kwargs.get("doc_id", ""),
            "doc_name": kwargs.get("doc_name", ""),
            "model": target_model,
            "response_text": answer,
        },
    }