From 43e47de0659c8449fd4d788aca02642b6568988d Mon Sep 17 00:00:00 2001 From: "alexgangxi@163.com" Date: Mon, 6 Apr 2026 00:46:01 +0800 Subject: [PATCH] feat(runner): add reusable timing collector for grading workflows --- openjudge/runner/grading_runner.py | 253 ++++++++++++++++++---------- openjudge/utils/timer.py | 126 ++++++++++++++ tests/runner/test_grading_runner.py | 93 ++++++++++ tests/utils/test_timer.py | 80 +++++++++ 4 files changed, 464 insertions(+), 88 deletions(-) create mode 100644 openjudge/utils/timer.py create mode 100644 tests/utils/test_timer.py diff --git a/openjudge/runner/grading_runner.py b/openjudge/runner/grading_runner.py index 5a6cb3baf..54742aedc 100644 --- a/openjudge/runner/grading_runner.py +++ b/openjudge/runner/grading_runner.py @@ -12,6 +12,7 @@ import asyncio import copy +from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Callable, Dict, List, Tuple, Union @@ -29,6 +30,7 @@ SemaphoreResourceExecutor, ) from openjudge.utils.mapping import parse_data_with_mapper +from openjudge.utils.timer import TimingCollector @dataclass @@ -155,6 +157,8 @@ def __init__( aggregators: Union[BaseAggregator, Callable, List[Union[BaseAggregator, Callable]], None] = None, show_progress: bool = True, executor: BaseResourceExecutor | None = None, + enable_timing: bool = False, + timing_collector: TimingCollector | None = None, ) -> None: """Initialize the grading runner. @@ -169,6 +173,10 @@ def __init__( show_progress: Whether to display a progress bar during execution. Defaults to True. executor: Optional execution resource to manage task execution. Defaults to LocalController if not provided. + enable_timing: Whether to collect latency metrics for the grading workflow. + Defaults to False. + timing_collector: Optional collector for storing timing records. When not + provided and ``enable_timing=True``, a collector is created automatically. Example: >>> # Initialize with multiple graders @@ -182,6 +190,10 @@ def __init__( self.max_concurrency = max_concurrency self.show_progress = show_progress self.executor = executor or SemaphoreResourceExecutor(max_concurrency) + self.enable_timing = enable_timing or timing_collector is not None + self.timing_collector = timing_collector or ( + TimingCollector() if self.enable_timing else None + ) # Handle aggregators if not aggregators: @@ -198,6 +210,8 @@ async def _arun( grader: BaseGrader, mapper: Dict[str, str] | Callable | None, executor: BaseResourceExecutor, + timing_collector: TimingCollector | None = None, + timing_metadata: dict[str, Any] | None = None, ) -> GraderResult: """Run a single evaluation asynchronously. @@ -236,21 +250,48 @@ async def _arun( ... } >>> result = await GradingRunner._arun(data, ContextGrader(), custom_mapper) """ - try: - data = parse_data_with_mapper(data, mapper) - # Create an isolated grader instance for this evaluation to prevent state sharing - isolated_grader = copy.deepcopy(grader) - - # The grader itself handles the mapping internally - return await isolated_grader.aevaluate(executor=executor, **data) - except Exception as e: - error_msg = f"Error in {grader.name} during evaluation: {str(e)}" - logger.error(error_msg) - return GraderError( - name=grader.name, - reason=f"Error in {grader.name} during evaluation", - error=error_msg, + timing_context = ( + timing_collector.measure( + "grading_runner.single_evaluation", + metadata={"grader_name": grader.name, **(timing_metadata or {})}, ) + if timing_collector + else nullcontext() + ) + + with timing_context: + try: + data = parse_data_with_mapper(data, mapper) + # Create an isolated grader instance for this evaluation to prevent state sharing + isolated_grader = copy.deepcopy(grader) + + # The grader itself handles the mapping internally + return await isolated_grader.aevaluate(executor=executor, **data) + except Exception as e: + error_msg = f"Error in {grader.name} during evaluation: {str(e)}" + logger.error(error_msg) + return GraderError( + name=grader.name, + reason=f"Error in {grader.name} during evaluation", + error=error_msg, + ) + + def get_timing_records(self, name: str | None = None) -> list: + """Return collected timing records for the grading workflow.""" + if self.timing_collector is None: + return [] + return self.timing_collector.get_records(name=name) + + def get_timing_summary(self) -> dict[str, dict[str, float | int]]: + """Return aggregate timing metrics collected by the runner.""" + if self.timing_collector is None: + return {} + return self.timing_collector.get_summary() + + def clear_timing_records(self) -> None: + """Clear previously collected timing records.""" + if self.timing_collector is not None: + self.timing_collector.clear() async def arun( self, @@ -324,59 +365,85 @@ async def arun( ... else: ... print(f" Sample {i}: Error - {result.error}") """ - # Create a dictionary to store result lists for each grader - grader_results: RunnerResult = {name: [] for name in self.grader_configs.keys()} - - # Create coroutines for all evaluators and all samples - all_coroutines = [] - coroutine_info = [] # Track (grader_name, sample_index) for each coroutine - - # Use the executor from self - executor = self.executor - - # Execute executor lifecycle - for name, config in self.grader_configs.items(): - grader = config.grader - mapper = config.mapper - assert grader is not None - - # Create coroutines for the current evaluator on all samples - for i, case in enumerate(dataset): - all_coroutines.append( - self._arun(data=case, grader=grader, mapper=mapper, executor=executor), - ) - coroutine_info.append( - (name, i), - ) # Record grader name and sample index - - # Execute all evaluator-sample coroutines concurrently - if self.show_progress: - all_results = await tqdm_asyncio.gather( - *all_coroutines, - desc="Evaluating a dataset", - total=len(all_coroutines), + timing_context = ( + self.timing_collector.measure( + "grading_runner.dataset", + metadata={ + "dataset_size": len(dataset), + "grader_count": len(self.grader_configs), + }, ) - else: - all_results = await asyncio.gather(*all_coroutines) - - # Initialize lists for all graders - for name in self.grader_configs.keys(): - grader_results[name] = [None] * len(dataset) - - # Organize results by grader - for (grader_name, sample_index), result in zip(coroutine_info, all_results): - grader_results[grader_name][sample_index] = result - - # Aggregate results - if self.aggregators: - for aggregator in self.aggregators: - aggregator_name = aggregator.__name__ - grader_results[aggregator_name] = [None] * len(dataset) - for i in range(len(dataset)): - grader_results[aggregator_name][i] = aggregator( - {grader_name: grader_results[grader_name][i] for grader_name in self.grader_configs.keys()}, + if self.timing_collector + else nullcontext() + ) + + with timing_context: + # Create a dictionary to store result lists for each grader + grader_results: RunnerResult = {name: [] for name in self.grader_configs.keys()} + + # Create coroutines for all evaluators and all samples + all_coroutines = [] + coroutine_info = [] # Track (grader_name, sample_index) for each coroutine + + # Use the executor from self + executor = self.executor + + # Execute executor lifecycle + for name, config in self.grader_configs.items(): + grader = config.grader + mapper = config.mapper + assert grader is not None + + # Create coroutines for the current evaluator on all samples + for i, case in enumerate(dataset): + all_coroutines.append( + self._arun( + data=case, + grader=grader, + mapper=mapper, + executor=executor, + timing_collector=self.timing_collector, + timing_metadata={ + "grader_config_name": name, + "sample_index": i, + }, + ), ) - return grader_results + coroutine_info.append( + (name, i), + ) # Record grader name and sample index + + # Execute all evaluator-sample coroutines concurrently + if self.show_progress: + all_results = await tqdm_asyncio.gather( + *all_coroutines, + desc="Evaluating a dataset", + total=len(all_coroutines), + ) + else: + all_results = await asyncio.gather(*all_coroutines) + + # Initialize lists for all graders + for name in self.grader_configs.keys(): + grader_results[name] = [None] * len(dataset) + + # Organize results by grader + for (grader_name, sample_index), result in zip(coroutine_info, all_results): + grader_results[grader_name][sample_index] = result + + # Aggregate results + if self.aggregators: + for aggregator in self.aggregators: + aggregator_name = aggregator.__name__ + grader_results[aggregator_name] = [None] * len(dataset) + for i in range(len(dataset)): + grader_results[aggregator_name][i] = aggregator( + { + grader_name: grader_results[grader_name][i] + for grader_name in self.grader_configs.keys() + }, + ) + return grader_results async def arun_multiple_datasets( self, @@ -468,26 +535,36 @@ async def arun_multiple_datasets( - When batch processing, individual arun() progress bars are disabled to avoid display conflicts with the batch-level progress bar. """ - # Temporarily disable show_progress for individual arun calls to avoid progress bar conflicts - original_show_progress = self.show_progress - self.show_progress = False - - try: - # Create tasks for each dataset - tasks = [self.arun(dataset, *args, **kwargs) for dataset in datasets] - - # Execute all dataset tasks concurrently with progress bar - if original_show_progress: - all_results = await tqdm_asyncio.gather( - *tasks, - desc=f"Evaluating {len(tasks)} datasets", - total=len(tasks), - ) - else: - all_results = await asyncio.gather(*tasks) - - # Return results as a list - return list(all_results) - finally: - # Restore original show_progress setting - self.show_progress = original_show_progress + timing_context = ( + self.timing_collector.measure( + "grading_runner.multi_dataset", + metadata={"dataset_count": len(datasets)}, + ) + if self.timing_collector + else nullcontext() + ) + + with timing_context: + # Temporarily disable show_progress for individual arun calls to avoid progress bar conflicts + original_show_progress = self.show_progress + self.show_progress = False + + try: + # Create tasks for each dataset + tasks = [self.arun(dataset, *args, **kwargs) for dataset in datasets] + + # Execute all dataset tasks concurrently with progress bar + if original_show_progress: + all_results = await tqdm_asyncio.gather( + *tasks, + desc=f"Evaluating {len(tasks)} datasets", + total=len(tasks), + ) + else: + all_results = await asyncio.gather(*tasks) + + # Return results as a list + return list(all_results) + finally: + # Restore original show_progress setting + self.show_progress = original_show_progress diff --git a/openjudge/utils/timer.py b/openjudge/utils/timer.py new file mode 100644 index 000000000..36b89776e --- /dev/null +++ b/openjudge/utils/timer.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +"""Lightweight timing utilities for performance instrumentation. + +This module provides a small reusable timing collector that can be used as a +context manager or decorator. Timing records are stored in memory, summarized +by operation name, and logged at DEBUG level by default to avoid cluttering +normal output. +""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import wraps +from time import perf_counter +from typing import Any, Callable, Iterator + +from loguru import logger + + +@dataclass(frozen=True) +class TimingRecord: + """A single timing measurement.""" + + name: str + duration_ms: float + metadata: dict[str, Any] = field(default_factory=dict) + + +class TimingCollector: + """Collect timing records and expose aggregated summaries.""" + + def __init__(self, log_level: str = "DEBUG") -> None: + self.log_level = log_level.upper() + self._records: list[TimingRecord] = [] + + @contextmanager + def measure(self, name: str, metadata: dict[str, Any] | None = None) -> Iterator[None]: + """Measure execution time for a code block.""" + start = perf_counter() + try: + yield + finally: + duration_ms = (perf_counter() - start) * 1000 + self.record(name=name, duration_ms=duration_ms, metadata=metadata) + + def record( + self, + name: str, + duration_ms: float, + metadata: dict[str, Any] | None = None, + ) -> TimingRecord: + """Add a timing record and emit a debug log entry.""" + record = TimingRecord( + name=name, + duration_ms=duration_ms, + metadata=dict(metadata or {}), + ) + self._records.append(record) + logger.log( + self.log_level, + "Timing | {name} took {duration_ms:.3f} ms | metadata={metadata}", + name=record.name, + duration_ms=record.duration_ms, + metadata=record.metadata, + ) + return record + + def get_records(self, name: str | None = None) -> list[TimingRecord]: + """Return collected records, optionally filtered by operation name.""" + if name is None: + return list(self._records) + return [record for record in self._records if record.name == name] + + def get_summary(self) -> dict[str, dict[str, float | int]]: + """Return aggregate timing statistics grouped by operation name.""" + grouped_records: dict[str, list[float]] = defaultdict(list) + for record in self._records: + grouped_records[record.name].append(record.duration_ms) + + summary: dict[str, dict[str, float | int]] = {} + for name, durations in grouped_records.items(): + summary[name] = { + "count": len(durations), + "total_ms": sum(durations), + "avg_ms": sum(durations) / len(durations), + "min_ms": min(durations), + "max_ms": max(durations), + } + return summary + + def clear(self) -> None: + """Clear all collected timing records.""" + self._records.clear() + + +def timed( + name: str, + collector: TimingCollector, + metadata: dict[str, Any] | None = None, +) -> Callable: + """Decorator for timing sync or async functions with a collector.""" + + def decorator(func: Callable) -> Callable: + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + with collector.measure(name, metadata=metadata): + return await func(*args, **kwargs) + + return async_wrapper + + @wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + with collector.measure(name, metadata=metadata): + return func(*args, **kwargs) + + return sync_wrapper + + return decorator + + +__all__ = ["TimingCollector", "TimingRecord", "timed"] diff --git a/tests/runner/test_grading_runner.py b/tests/runner/test_grading_runner.py index 66ef0924d..e27fb998b 100644 --- a/tests/runner/test_grading_runner.py +++ b/tests/runner/test_grading_runner.py @@ -14,6 +14,7 @@ from openjudge.graders.schema import GraderError, GraderScore from openjudge.runner.aggregator.weighted_sum_aggregator import WeightedSumAggregator from openjudge.runner.grading_runner import GradingRunner +from openjudge.utils.timer import TimingCollector class MockGrader(BaseGrader): @@ -281,6 +282,98 @@ async def test_grading_runner_with_real_components(self): assert isinstance(result, (GraderScore, GraderError)) assert result.metadata["call_count"] == 1 + @pytest.mark.asyncio + async def test_grading_runner_collects_timing_summary(self): + """Should expose timing summaries when timing is enabled.""" + runner = GradingRunner( + grader_configs={ + "accuracy": MockGrader(name="accuracy_grader", score_value=0.9), + "relevance": MockGrader(name="relevance_grader", score_value=0.8), + }, + show_progress=False, + enable_timing=True, + ) + + dataset = [ + {"query": "What is the capital of France?", "answer": "Paris"}, + {"query": "What is the capital of Germany?", "answer": "Berlin"}, + ] + + await runner.arun(dataset) + + summary = runner.get_timing_summary() + assert "grading_runner.dataset" in summary + assert "grading_runner.single_evaluation" in summary + assert summary["grading_runner.dataset"]["count"] == 1 + assert summary["grading_runner.single_evaluation"]["count"] == 4 + + records = runner.get_timing_records("grading_runner.single_evaluation") + assert len(records) == 4 + assert all(record.metadata["grader_name"] in {"accuracy_grader", "relevance_grader"} for record in records) + assert {record.metadata["sample_index"] for record in records} == {0, 1} + + @pytest.mark.asyncio + async def test_grading_runner_uses_injected_timing_collector(self): + """Should support externally managed timing collectors.""" + collector = TimingCollector() + runner = GradingRunner( + grader_configs={ + "accuracy": MockGrader(name="accuracy_grader", score_value=0.9), + }, + show_progress=False, + timing_collector=collector, + ) + + dataset = [{"query": "What is 2+2?", "answer": "4"}] + await runner.arun(dataset) + + assert runner.timing_collector is collector + assert len(collector.get_records("grading_runner.dataset")) == 1 + assert len(collector.get_records("grading_runner.single_evaluation")) == 1 + + @pytest.mark.asyncio + async def test_grading_runner_multiple_datasets_collects_batch_timing(self): + """Should record batch-level timing for multiple datasets.""" + runner = GradingRunner( + grader_configs={ + "accuracy": MockGrader(name="accuracy_grader", score_value=0.9), + }, + show_progress=False, + enable_timing=True, + ) + + datasets = [ + [{"query": "What is 2+2?", "answer": "4"}], + [{"query": "What is the sky color?", "answer": "blue"}], + ] + + await runner.arun_multiple_datasets(datasets) + + summary = runner.get_timing_summary() + assert summary["grading_runner.multi_dataset"]["count"] == 1 + assert summary["grading_runner.dataset"]["count"] == 2 + assert summary["grading_runner.single_evaluation"]["count"] == 2 + + @pytest.mark.asyncio + async def test_grading_runner_clear_timing_records(self): + """Should clear collected timing data on demand.""" + runner = GradingRunner( + grader_configs={ + "accuracy": MockGrader(name="accuracy_grader", score_value=0.9), + }, + show_progress=False, + enable_timing=True, + ) + + dataset = [{"query": "What is 2+2?", "answer": "4"}] + await runner.arun(dataset) + assert runner.get_timing_records() + + runner.clear_timing_records() + + assert runner.get_timing_records() == [] + assert runner.get_timing_summary() == {} + @pytest.mark.asyncio async def test_grading_runner_multiple_datasets(self): """Test the grading runner with multiple datasets using arun_multiple_datasets.""" diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py new file mode 100644 index 000000000..fc4065859 --- /dev/null +++ b/tests/utils/test_timer.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +"""Unit tests for timing utilities.""" + +import asyncio + +import pytest + +from openjudge.utils.timer import TimingCollector, timed + + +@pytest.mark.unit +class TestTimingCollector: + """Test suite for TimingCollector.""" + + def test_measure_records_context_timing(self): + """Should record a timing entry for measured code blocks.""" + collector = TimingCollector() + + with collector.measure("test.context", metadata={"phase": "setup"}): + sum(range(10)) + + records = collector.get_records("test.context") + assert len(records) == 1 + assert records[0].name == "test.context" + assert records[0].duration_ms >= 0 + assert records[0].metadata == {"phase": "setup"} + + def test_summary_aggregates_multiple_records(self): + """Should summarize repeated measurements by operation name.""" + collector = TimingCollector() + collector.record("test.summary", 5.0) + collector.record("test.summary", 15.0) + + summary = collector.get_summary() + assert "test.summary" in summary + assert summary["test.summary"]["count"] == 2 + assert summary["test.summary"]["total_ms"] == pytest.approx(20.0) + assert summary["test.summary"]["avg_ms"] == pytest.approx(10.0) + assert summary["test.summary"]["min_ms"] == pytest.approx(5.0) + assert summary["test.summary"]["max_ms"] == pytest.approx(15.0) + + def test_clear_removes_all_records(self): + """Should clear collected timing records.""" + collector = TimingCollector() + collector.record("test.clear", 1.0) + + collector.clear() + + assert collector.get_records() == [] + assert collector.get_summary() == {} + + def test_timed_decorator_supports_sync_functions(self): + """Should measure sync functions decorated with timed.""" + collector = TimingCollector() + + @timed("test.sync", collector, metadata={"kind": "sync"}) + def add(a: int, b: int) -> int: + return a + b + + assert add(1, 2) == 3 + + records = collector.get_records("test.sync") + assert len(records) == 1 + assert records[0].metadata == {"kind": "sync"} + + @pytest.mark.asyncio + async def test_timed_decorator_supports_async_functions(self): + """Should measure async functions decorated with timed.""" + collector = TimingCollector() + + @timed("test.async", collector, metadata={"kind": "async"}) + async def async_add(a: int, b: int) -> int: + await asyncio.sleep(0) + return a + b + + assert await async_add(1, 2) == 3 + + records = collector.get_records("test.async") + assert len(records) == 1 + assert records[0].metadata == {"kind": "async"}