Skip to content

Commit

Permalink
Add reshuffle to optimize WeatherBench2 evaluation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697039829
  • Loading branch information
shoyer authored and Weatherbench2 authors committed Nov 16, 2024
1 parent c638c1f commit 04e3359
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
12 changes: 11 additions & 1 deletion scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,22 @@
FANOUT = flags.DEFINE_integer(
'fanout',
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
help='Beam CombineFn fanout. Recommended when evaluating large datasets.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write Zarr in parallel per worker.',
)
SHUFFLE_BEFORE_TEMPORAL_MEAN = flags.DEFINE_bool(
'shuffle_before_temporal_mean',
False,
help=(
'Shuffle before computing the temporal mean. This is a good idea when'
' evaluation metric outputs are small compared to the size of the'
' input data, such as when aggregating over space or a large ensemble.'
),
)


def _wind_vector_error(err_type: str):
Expand Down Expand Up @@ -661,6 +670,7 @@ def main(argv: list[str]) -> None:
skipna=SKIPNA.value,
fanout=FANOUT.value,
num_threads=NUM_THREADS.value,
shuffle_before_temporal_mean=SHUFFLE_BEFORE_TEMPORAL_MEAN.value,
argv=argv,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion weatherbench2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Eval:
by-valid convention. For by-init, specify analysis dataset as obs.
derived_variables: dict of DerivedVariable instances to compute on the fly.
temporal_mean: Compute temporal mean (over time/init_time) for metrics.
output_format: Wether to save to 'netcdf' or 'zarr'.
output_format: whether to save to 'netcdf' or 'zarr'.
"""

metrics: t.Dict[str, Metric]
Expand Down
25 changes: 22 additions & 3 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ def _metric_and_region_loop(
"""Compute metric results looping over metrics and regions in eval config."""
# Compute derived variables
logging.info('Starting _metric_and_region_loop')
logging.info(
f'{len(forecast)} variables, {forecast.sizes=}, {truth.sizes=}, '
f'({forecast.nbytes + truth.nbytes} bytes)'
)
for name, dv in eval_config.derived_variables.items():
logging.info(f'Logging: derived_variable {name!r}: {dv}')
forecast[name] = dv.compute(forecast)
Expand Down Expand Up @@ -559,7 +563,11 @@ class _EvaluateAllMetrics(beam.PTransform):
input_chunks: Chunks to use for input files.
skipna: Whether to skip NaN values in both forecasts and observations during
evaluation.
fanout: Fanout parameter for Beam combiners.
fanout: Fanout parameter for Beam combiners in the temporal mean.
shuffle_before_temporal_mean: If True, shuffle before computing the temporal
mean. This is a good idea when evaluation metric outputs are small
compared to the size of the input data, such as when aggregating over
space or a large ensemble.
num_threads: Number of threads for reading/writing files.
"""

Expand All @@ -569,6 +577,7 @@ class _EvaluateAllMetrics(beam.PTransform):
input_chunks: abc.Mapping[str, int]
skipna: bool
fanout: Optional[int] = None
shuffle_before_temporal_mean: bool = False
num_threads: Optional[int] = None

def _evaluate_chunk(
Expand Down Expand Up @@ -724,6 +733,10 @@ def _evaluate(
forecast_pipeline |= 'EvaluateChunk' >> beam.MapTuple(self._evaluate_chunk)

if self.eval_config.temporal_mean:
if self.shuffle_before_temporal_mean:
# Reshuffle to avoid fusing evaluation of chunks with the temporal mean.
forecast_pipeline |= beam.Reshuffle()

forecast_pipeline |= 'TemporalMean' >> xbeam.Mean(
dim='init_time' if self.data_config.by_init else 'time',
fanout=self.fanout,
Expand All @@ -749,6 +762,7 @@ def evaluate_with_beam(
input_chunks: abc.Mapping[str, int],
runner: str,
fanout: Optional[int] = None,
shuffle_before_temporal_mean: bool = False,
num_threads: Optional[int] = None,
argv: Optional[list[str]] = None,
skipna: bool = False,
Expand Down Expand Up @@ -777,7 +791,11 @@ def evaluate_with_beam(
eval_configs: Dictionary of config.Eval instances.
input_chunks: Chunking of input datasets.
runner: Beam runner.
fanout: Beam CombineFn fanout.
fanout: Fanout parameter for Beam combiners in the temporal mean.
shuffle_before_temporal_mean: If True, shuffle before computing the temporal
mean. This is a good idea when evaluation metric outputs are small
compared to the size of the input data, such as when aggregating over
space or a large ensemble.
num_threads: Number of threads to use for reading/writing data.
argv: Other arguments to pass into the Beam pipeline.
skipna: Whether to skip NaN values in both forecasts and observations during
Expand All @@ -795,9 +813,10 @@ def evaluate_with_beam(
eval_config,
data_config,
input_chunks,
skipna=skipna,
fanout=fanout,
shuffle_before_temporal_mean=shuffle_before_temporal_mean,
num_threads=num_threads,
skipna=skipna,
)
| f'save_{eval_name}'
>> _SaveOutputs(
Expand Down

0 comments on commit 04e3359

Please sign in to comment.