Files
siemens_ragas/webapp/services/task_manager.py

162 lines
6.0 KiB
Python
Raw Permalink Normal View History

"""In-process background task manager for evaluation runs.
Evaluations run in a thread pool so the FastAPI event loop is never blocked.
The heavy rag_eval / ragas import is performed lazily inside the worker thread,
which keeps the web server bootable even when the evaluation dependencies are
broken; failures then surface as task errors in the UI instead of crashing
startup. This matches the "coarse status + logs" progress decision.
"""
from __future__ import annotations
import io
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import redirect_stderr, redirect_stdout
from datetime import datetime, timezone
from pathlib import Path
from webapp.models import TaskStatus
def _now_iso() -> str:
    """Format the present moment, expressed in UTC, as an ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
class _LineCapture(io.TextIOBase):
    """Writable text stream that forwards each completed line to a task log.

    Text is accumulated until a newline arrives; the finished line (without
    its newline) is then passed to ``sink.append_log``. A trailing partial
    line is only delivered when ``flush`` is called.
    """

    def __init__(self, sink: "EvaluationTask") -> None:
        """Bind the capture stream to the owning task.

        Args:
            sink: Object exposing ``append_log(line)``; it receives every
                captured line.
        """
        super().__init__()
        self._sink = sink
        self._buffer = ""  # text received since the last newline

    def writable(self) -> bool:
        """Report that this stream accepts writes.

        ``io.TextIOBase`` answers ``False`` by default, so callers that probe
        ``writable()`` before writing would otherwise treat the stream as
        read-only even though ``write`` is implemented.
        """
        return True

    def write(self, text: str) -> int:
        """Buffer *text* and push every completed line into the task log.

        Args:
            text: The characters written to the stream.

        Returns:
            The number of characters accepted, always ``len(text)``.
        """
        self._buffer += text
        # Everything before the final "\n" is complete; the last piece is the
        # (possibly empty) unfinished line that stays buffered.
        *finished, self._buffer = self._buffer.split("\n")
        for line in finished:
            self._sink.append_log(line)
        return len(text)

    def flush(self) -> None:
        """Deliver any buffered partial line to the task log and clear it."""
        if self._buffer:
            self._sink.append_log(self._buffer)
            self._buffer = ""
class EvaluationTask:
    """Live, mutable record describing one background evaluation run."""

    def __init__(self, task_id: str, scenario_path: str) -> None:
        """Create the record in the ``queued`` state for *scenario_path*."""
        self._lock = threading.Lock()
        # Identity of the run.
        self.task_id = task_id
        self.scenario_path = scenario_path
        # Timestamps: creation is stamped now, completion stays empty for now.
        self.created_at = _now_iso()
        self.finished_at = ""
        # Progress and outcome.
        self.status = "queued"
        self.run_id: str | None = None
        self.error: str | None = None
        self.logs: list[str] = []

    def append_log(self, line: str) -> None:
        """Add *line* to the end of the log while holding the task lock."""
        with self._lock:
            self.logs.append(line)

    def snapshot(self) -> TaskStatus:
        """Build a ``TaskStatus`` from the current state under the task lock.

        The log list is copied, so later appends to this task do not change
        the list handed to the returned object.
        """
        with self._lock:
            state = {
                "task_id": self.task_id,
                "scenario_path": self.scenario_path,
                "status": self.status,
                "logs": self.logs.copy(),
                "run_id": self.run_id,
                "error": self.error,
                "created_at": self.created_at,
                "finished_at": self.finished_at,
            }
            return TaskStatus(**state)
class TaskManager:
    """Owns the worker thread pool and the in-memory registry of tasks.

    Nothing in this class removes tasks from the registry, so every submitted
    task (including its captured log lines) stays in memory for the lifetime
    of the manager.
    """

    # redirect_stdout/redirect_stderr replace the process-wide sys.stdout and
    # sys.stderr. Two overlapping redirections would write into each other's
    # capture and could restore the wrong stream on exit, so the redirected
    # section is serialised. The lock is class-level because the streams are
    # shared by every manager instance in the process.
    _redirect_lock = threading.Lock()

    def __init__(self, max_workers: int = 2) -> None:
        """Create a task manager backed by a small thread pool.

        Args:
            max_workers: Maximum number of worker threads in the pool.
        """
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._tasks: dict[str, EvaluationTask] = {}
        self._lock = threading.Lock()  # guards the task registry

    def submit(self, scenario_path: str) -> str:
        """Register and schedule a new evaluation task.

        Args:
            scenario_path: Scenario file to evaluate; relative paths are
                resolved later by ``_to_absolute``.

        Returns:
            The 12-character hexadecimal id of the new task.
        """
        task_id = uuid.uuid4().hex[:12]
        task = EvaluationTask(task_id=task_id, scenario_path=scenario_path)
        with self._lock:
            self._tasks[task_id] = task
        self._executor.submit(self._run, task)
        return task_id

    def get(self, task_id: str) -> TaskStatus | None:
        """Return a snapshot of one task, or None if the id is unknown."""
        with self._lock:
            task = self._tasks.get(task_id)
            return task.snapshot() if task is not None else None

    def list_tasks(self) -> list[TaskStatus]:
        """Return snapshots of all known tasks, newest first."""
        with self._lock:
            tasks = list(self._tasks.values())
        # Snapshots are taken outside the registry lock; each task has its own.
        return sorted(
            (task.snapshot() for task in tasks),
            key=lambda item: item.created_at,
            reverse=True,
        )

    def _run(self, task: EvaluationTask) -> None:
        """Execute one evaluation end to end inside a worker thread.

        Output written to stdout/stderr while the scenario runs is captured
        into the task log. Only one task at a time holds the stream
        redirection; a second task stays in the ``running`` state while it
        waits for its turn.

        The task always ends in a terminal state: ``completed`` on success,
        ``failed`` otherwise, with ``finished_at`` written before the status
        so a reader that sees a terminal status also sees the timestamp.

        Args:
            task: The task record to update as the evaluation progresses.
        """
        task.status = "running"
        task.append_log(f"[{_now_iso()}] 开始评估: {task.scenario_path}")
        capture = _LineCapture(task)
        # Pessimistic default: whatever escapes the handlers below must not
        # leave the task reported as "running" forever.
        final_status = "failed"
        try:
            # Lazy import keeps the web server bootable if ragas is unavailable.
            task.append_log("加载评估引擎 (rag_eval / ragas)...")
            from rag_eval.execution.runner import run_scenario

            absolute_path = self._to_absolute(task.scenario_path)
            task.append_log(f"运行场景文件: {absolute_path}")
            with self._redirect_lock:
                with redirect_stdout(capture), redirect_stderr(capture):
                    result = run_scenario(str(absolute_path))
            capture.flush()
            task.run_id = getattr(result, "run_id", None)
            output_dir = getattr(getattr(result, "scenario", None), "output_dir", "")
            task.append_log(f"[{_now_iso()}] 评估完成。run_id={task.run_id}")
            if output_dir:
                task.append_log(f"结果目录: {output_dir}")
            final_status = "completed"
        except (Exception, SystemExit) as exc:  # noqa: BLE001 - surface any failure to the UI
            # SystemExit is included because it is not an Exception subclass:
            # a sys.exit() inside the evaluation would otherwise skip this
            # handler and the failure would never be recorded on the task.
            capture.flush()
            error_type = type(exc).__name__
            task.error = f"{error_type}: {exc}"
            task.append_log(f"[{_now_iso()}] 评估失败 [{error_type}]: {exc}")
        finally:
            task.finished_at = _now_iso()
            task.status = final_status

    def _to_absolute(self, scenario_path: str) -> Path:
        """Resolve a scenario path against the repository root if relative.

        Absolute paths are returned unchanged. Relative paths are joined to
        the directory two levels above the directory containing this module
        and then resolved.
        """
        candidate = Path(scenario_path)
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return (repo_root / candidate).resolve()
# Module-level singleton shared by the FastAPI routes.
# Built when this module is imported, using TaskManager's default pool size.
task_manager = TaskManager()