Skip to content

Commit

Permalink
Merge input cubes only once when computing lazy multimodel statistics (
Browse files Browse the repository at this point in the history
…#2518)

Co-authored-by: Valeriu Predoi <[email protected]>
  • Loading branch information
bouweandela and valeriupredoi authored Oct 14, 2024
1 parent e4e6b9b commit bd36519
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions esmvalcore/preprocessor/_multimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ def _compute_eager(
input_slices = cubes # scalar cubes
else:
input_slices = [cube[chunk] for cube in cubes]
result_slice = _compute(input_slices, operator=operator, **kwargs)
combined_cube = _combine(input_slices)
result_slice = _compute(combined_cube, operator=operator, **kwargs)
result_slices.append(result_slice)

try:
Expand All @@ -503,10 +504,13 @@ def _compute_eager(
return result_cube


def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
def _compute(
cube: iris.cube.Cube,
*,
operator: iris.analysis.Aggregator,
**kwargs,
):
"""Compute statistic."""
cube = _combine(cubes)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand All @@ -531,8 +535,6 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):

# Remove concatenation dimension added by _combine
result_cube.remove_coord(CONCAT_DIM)
for cube in cubes:
cube.remove_coord(CONCAT_DIM)

# some iris aggregators modify dtype, see e.g.
# https://numpy.org/doc/stable/reference/generated/numpy.ma.average.html
Expand All @@ -545,7 +547,6 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
method=cell_method.method,
coords=cell_method.coord_names,
intervals=cell_method.intervals,
comments=f"input_cubes: {len(cubes)}",
)
result_cube.add_cell_method(updated_method)
return result_cube
Expand Down Expand Up @@ -602,27 +603,26 @@ def _multicube_statistics(
# Calculate statistics
statistics_cubes = {}
lazy_input = any(cube.has_lazy_data() for cube in cubes)
for stat in statistics:
(stat_id, result_cube) = _compute_statistic(cubes, lazy_input, stat)
combined_cube = None
for statistic in statistics:
stat_id = _get_stat_identifier(statistic)
logger.debug("Multicube statistics: computing: %s", stat_id)

(operator, kwargs) = _get_operator_and_kwargs(statistic)
(agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
if lazy_input and agg.lazy_func is not None:
if combined_cube is None:
# Merge input cubes only once as this is can be computationally
# expensive.
combined_cube = _combine(cubes)
result_cube = _compute(combined_cube, operator=agg, **agg_kwargs)
else:
result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
statistics_cubes[stat_id] = result_cube

return statistics_cubes


def _compute_statistic(cubes, lazy_input, statistic):
"""Compute a single statistic."""
stat_id = _get_stat_identifier(statistic)
logger.debug("Multicube statistics: computing: %s", stat_id)

(operator, kwargs) = _get_operator_and_kwargs(statistic)
(agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
if lazy_input and agg.lazy_func is not None:
result_cube = _compute(cubes, operator=agg, **agg_kwargs)
else:
result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
return (stat_id, result_cube)


def _multiproduct_statistics(
products,
statistics,
Expand Down

0 comments on commit bd36519

Please sign in to comment.