Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 165 additions & 88 deletions openjudge/runner/grading_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import asyncio
import copy
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Union

Expand All @@ -29,6 +30,7 @@
SemaphoreResourceExecutor,
)
from openjudge.utils.mapping import parse_data_with_mapper
from openjudge.utils.timer import TimingCollector


@dataclass
Expand Down Expand Up @@ -155,6 +157,8 @@ def __init__(
aggregators: Union[BaseAggregator, Callable, List[Union[BaseAggregator, Callable]], None] = None,
show_progress: bool = True,
executor: BaseResourceExecutor | None = None,
enable_timing: bool = False,
timing_collector: TimingCollector | None = None,
) -> None:
"""Initialize the grading runner.

Expand All @@ -169,6 +173,10 @@ def __init__(
show_progress: Whether to display a progress bar during execution. Defaults to True.
executor: Optional execution resource to manage task execution.
Defaults to LocalController if not provided.
enable_timing: Whether to collect latency metrics for the grading workflow.
Defaults to False.
timing_collector: Optional collector for storing timing records. When not
provided and ``enable_timing=True``, a collector is created automatically.

Example:
>>> # Initialize with multiple graders
Expand All @@ -182,6 +190,10 @@ def __init__(
self.max_concurrency = max_concurrency
self.show_progress = show_progress
self.executor = executor or SemaphoreResourceExecutor(max_concurrency)
self.enable_timing = enable_timing or timing_collector is not None
self.timing_collector = timing_collector or (
TimingCollector() if self.enable_timing else None
)

# Handle aggregators
if not aggregators:
Expand All @@ -198,6 +210,8 @@ async def _arun(
grader: BaseGrader,
mapper: Dict[str, str] | Callable | None,
executor: BaseResourceExecutor,
timing_collector: TimingCollector | None = None,
timing_metadata: dict[str, Any] | None = None,
) -> GraderResult:
"""Run a single evaluation asynchronously.

Expand Down Expand Up @@ -236,21 +250,48 @@ async def _arun(
... }
>>> result = await GradingRunner._arun(data, ContextGrader(), custom_mapper)
"""
try:
data = parse_data_with_mapper(data, mapper)
# Create an isolated grader instance for this evaluation to prevent state sharing
isolated_grader = copy.deepcopy(grader)

# The grader itself handles the mapping internally
return await isolated_grader.aevaluate(executor=executor, **data)
except Exception as e:
error_msg = f"Error in {grader.name} during evaluation: {str(e)}"
logger.error(error_msg)
return GraderError(
name=grader.name,
reason=f"Error in {grader.name} during evaluation",
error=error_msg,
timing_context = (
timing_collector.measure(
"grading_runner.single_evaluation",
metadata={"grader_name": grader.name, **(timing_metadata or {})},
)
if timing_collector
else nullcontext()
)

with timing_context:
try:
data = parse_data_with_mapper(data, mapper)
# Create an isolated grader instance for this evaluation to prevent state sharing
isolated_grader = copy.deepcopy(grader)

# The grader itself handles the mapping internally
return await isolated_grader.aevaluate(executor=executor, **data)
except Exception as e:
error_msg = f"Error in {grader.name} during evaluation: {str(e)}"
logger.error(error_msg)
return GraderError(
name=grader.name,
reason=f"Error in {grader.name} during evaluation",
error=error_msg,
)

def get_timing_records(self, name: str | None = None) -> list:
"""Return collected timing records for the grading workflow."""
if self.timing_collector is None:
return []
return self.timing_collector.get_records(name=name)

def get_timing_summary(self) -> dict[str, dict[str, float | int]]:
"""Return aggregate timing metrics collected by the runner."""
if self.timing_collector is None:
return {}
return self.timing_collector.get_summary()

def clear_timing_records(self) -> None:
    """Discard all previously collected timing records.

    No-op when timing collection is disabled (no collector attached).
    """
    collector = self.timing_collector
    if collector is None:
        # Timing disabled — nothing to clear.
        return
    collector.clear()

async def arun(
self,
Expand Down Expand Up @@ -324,59 +365,85 @@ async def arun(
... else:
... print(f" Sample {i}: Error - {result.error}")
"""
# Create a dictionary to store result lists for each grader
grader_results: RunnerResult = {name: [] for name in self.grader_configs.keys()}

# Create coroutines for all evaluators and all samples
all_coroutines = []
coroutine_info = [] # Track (grader_name, sample_index) for each coroutine

# Use the executor from self
executor = self.executor

# Execute executor lifecycle
for name, config in self.grader_configs.items():
grader = config.grader
mapper = config.mapper
assert grader is not None

# Create coroutines for the current evaluator on all samples
for i, case in enumerate(dataset):
all_coroutines.append(
self._arun(data=case, grader=grader, mapper=mapper, executor=executor),
)
coroutine_info.append(
(name, i),
) # Record grader name and sample index

# Execute all evaluator-sample coroutines concurrently
if self.show_progress:
all_results = await tqdm_asyncio.gather(
*all_coroutines,
desc="Evaluating a dataset",
total=len(all_coroutines),
timing_context = (
self.timing_collector.measure(
"grading_runner.dataset",
metadata={
"dataset_size": len(dataset),
"grader_count": len(self.grader_configs),
},
)
else:
all_results = await asyncio.gather(*all_coroutines)

# Initialize lists for all graders
for name in self.grader_configs.keys():
grader_results[name] = [None] * len(dataset)

# Organize results by grader
for (grader_name, sample_index), result in zip(coroutine_info, all_results):
grader_results[grader_name][sample_index] = result

# Aggregate results
if self.aggregators:
for aggregator in self.aggregators:
aggregator_name = aggregator.__name__
grader_results[aggregator_name] = [None] * len(dataset)
for i in range(len(dataset)):
grader_results[aggregator_name][i] = aggregator(
{grader_name: grader_results[grader_name][i] for grader_name in self.grader_configs.keys()},
if self.timing_collector
else nullcontext()
)

with timing_context:
# Create a dictionary to store result lists for each grader
grader_results: RunnerResult = {name: [] for name in self.grader_configs.keys()}

# Create coroutines for all evaluators and all samples
all_coroutines = []
coroutine_info = [] # Track (grader_name, sample_index) for each coroutine

# Use the executor from self
executor = self.executor

# Execute executor lifecycle
for name, config in self.grader_configs.items():
grader = config.grader
mapper = config.mapper
assert grader is not None

# Create coroutines for the current evaluator on all samples
for i, case in enumerate(dataset):
all_coroutines.append(
self._arun(
data=case,
grader=grader,
mapper=mapper,
executor=executor,
timing_collector=self.timing_collector,
timing_metadata={
"grader_config_name": name,
"sample_index": i,
},
),
)
return grader_results
coroutine_info.append(
(name, i),
) # Record grader name and sample index

# Execute all evaluator-sample coroutines concurrently
if self.show_progress:
all_results = await tqdm_asyncio.gather(
*all_coroutines,
desc="Evaluating a dataset",
total=len(all_coroutines),
)
else:
all_results = await asyncio.gather(*all_coroutines)

# Initialize lists for all graders
for name in self.grader_configs.keys():
grader_results[name] = [None] * len(dataset)

# Organize results by grader
for (grader_name, sample_index), result in zip(coroutine_info, all_results):
grader_results[grader_name][sample_index] = result

# Aggregate results
if self.aggregators:
for aggregator in self.aggregators:
aggregator_name = aggregator.__name__
grader_results[aggregator_name] = [None] * len(dataset)
for i in range(len(dataset)):
grader_results[aggregator_name][i] = aggregator(
{
grader_name: grader_results[grader_name][i]
for grader_name in self.grader_configs.keys()
},
)
return grader_results

async def arun_multiple_datasets(
self,
Expand Down Expand Up @@ -468,26 +535,36 @@ async def arun_multiple_datasets(
- When batch processing, individual arun() progress bars are disabled to avoid
display conflicts with the batch-level progress bar.
"""
# Temporarily disable show_progress for individual arun calls to avoid progress bar conflicts
original_show_progress = self.show_progress
self.show_progress = False

try:
# Create tasks for each dataset
tasks = [self.arun(dataset, *args, **kwargs) for dataset in datasets]

# Execute all dataset tasks concurrently with progress bar
if original_show_progress:
all_results = await tqdm_asyncio.gather(
*tasks,
desc=f"Evaluating {len(tasks)} datasets",
total=len(tasks),
)
else:
all_results = await asyncio.gather(*tasks)

# Return results as a list
return list(all_results)
finally:
# Restore original show_progress setting
self.show_progress = original_show_progress
timing_context = (
self.timing_collector.measure(
"grading_runner.multi_dataset",
metadata={"dataset_count": len(datasets)},
)
if self.timing_collector
else nullcontext()
)

with timing_context:
# Temporarily disable show_progress for individual arun calls to avoid progress bar conflicts
original_show_progress = self.show_progress
self.show_progress = False

try:
# Create tasks for each dataset
tasks = [self.arun(dataset, *args, **kwargs) for dataset in datasets]

# Execute all dataset tasks concurrently with progress bar
if original_show_progress:
all_results = await tqdm_asyncio.gather(
*tasks,
desc=f"Evaluating {len(tasks)} datasets",
total=len(tasks),
)
else:
all_results = await asyncio.gather(*tasks)

# Return results as a list
return list(all_results)
finally:
# Restore original show_progress setting
self.show_progress = original_show_progress
Loading