Lazy mask_fillvalues preprocessor function (#2340)
bouweandela authored Jun 4, 2024
1 parent 8e2ef15 commit af5490f
Showing 5 changed files with 122 additions and 69 deletions.
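The gist of the change: mask_fillvalues and its helpers now pick between numpy and dask based on whether the cube data is lazy, so the missing-value masks can be computed without loading the data into memory. A minimal sketch of that dispatch pattern (illustrative only, not part of the commit; the helper name threshold_mask is made up):

    import dask.array as da
    import numpy as np

    def threshold_mask(data, threshold):
        # data may be an eager numpy array or a lazy dask array; pick the
        # matching array module so the computation stays lazy if the input is.
        array_module = da if isinstance(data, da.Array) else np
        mask = data < threshold
        # Points that are already masked in the input count as missing (True).
        return array_module.ma.filled(mask, True)

    values = da.arange(6, chunks=3)
    missing = da.from_array(np.array([0, 0, 1, 0, 0, 0], dtype=bool), chunks=3)
    lazy = da.ma.masked_array(values, missing)
    print(threshold_mask(lazy, 3))            # still a lazy dask array
    print(threshold_mask(lazy, 3).compute())  # [ True  True  True False False False]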
2 changes: 1 addition & 1 deletion environment.yml
@@ -19,7 +19,7 @@ dependencies:
- geopy
- humanfriendly
- importlib_metadata # required for Python < 3.10
- iris >3.8.0
- iris >=3.9.0
- iris-esmf-regrid >=0.10.0 # github.com/SciTools-incubator/iris-esmf-regrid/pull/342
- isodate
- jinja2
149 changes: 87 additions & 62 deletions esmvalcore/preprocessor/_mask.py
@@ -4,6 +4,7 @@
masking with ancillary variables, masking with Natural Earth shapefiles
(land or ocean), masking on thresholds, missing values masking.
"""
from __future__ import annotations

import logging
import os
@@ -323,7 +324,14 @@ def _mask_with_shp(cube, shapefilename, region_indices=None):
return cube


def count_spells(data, threshold, axis, spell_length):
def count_spells(
data: np.ndarray | da.Array,
threshold: float | None,
axis: int,
spell_length,
) -> np.ndarray | da.Array:
# Copied from:
# https://scitools-iris.readthedocs.io/en/stable/generated/gallery/general/plot_custom_aggregation.html
"""Count data occurrences.
Define a function to perform the custom statistical operation.
@@ -338,10 +346,10 @@ def count_spells(data, threshold, axis, spell_length):
Parameters
----------
data: ndarray
data:
raw data to be compared with value threshold.
threshold: float
threshold:
threshold point for 'significant' datapoints.
axis: int
@@ -353,15 +361,17 @@ def count_spells(data, threshold, axis, spell_length):
Returns
-------
int
:obj:`numpy.ndarray` or :obj:`dask.array.Array`
Number of counts.
"""
if axis < 0:
# just cope with negative axis numbers
axis += data.ndim
# Threshold the data to find the 'significant' points.
array_module = da if isinstance(data, da.Array) else np
if not threshold:
data_hits = np.ones_like(data, dtype=bool)
# Keeps the mask of the input data.
data_hits = array_module.ma.ones_like(data, dtype=bool)
else:
data_hits = data > float(threshold)

@@ -371,17 +381,16 @@ def count_spells(data, threshold, axis, spell_length):
# if you want overlapping windows set the step to be m*spell_length
# where m is a float
###############################################################
hit_windows = rolling_window(data_hits,
window=spell_length,
step=spell_length,
axis=axis)

hit_windows = rolling_window(
data_hits,
window=spell_length,
step=spell_length,
axis=axis,
)
# Find the windows "full of True-s" (along the added 'window axis').
full_windows = np.all(hit_windows, axis=axis + 1)

full_windows = array_module.all(hit_windows, axis=axis + 1)
# Count points fulfilling the condition (along the time axis).
spell_point_counts = np.sum(full_windows, axis=axis, dtype=int)

spell_point_counts = array_module.sum(full_windows, axis=axis, dtype=int)
return spell_point_counts
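count_spells now accepts either a numpy or a dask array and returns the same kind of array it was given, which is what lets it double as the aggregator's lazy implementation further down. A small worked example of the non-overlapping window counting (assuming ESMValCore with this change, and hence iris >= 3.9, is installed; the numbers are invented):

    import numpy as np
    from esmvalcore.preprocessor._mask import count_spells

    data = np.array([0.0, 5.0, 6.0, 1.0, 7.0, 8.0])  # six time steps
    # Non-overlapping windows of length 2: [0, 5], [6, 1], [7, 8].
    # Only the last window lies entirely above the threshold of 2.
    print(count_spells(data, threshold=2.0, axis=0, spell_length=2))  # -> 1

Passing a dask array instead of a numpy array should keep the result lazy.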


@@ -572,10 +581,12 @@ def mask_multimodel(products):
f"got {product_types}")


def mask_fillvalues(products,
threshold_fraction,
min_value=None,
time_window=1):
def mask_fillvalues(
products,
threshold_fraction: float,
min_value: float | None = None,
time_window: int = 1,
):
"""Compute and apply a multi-dataset fillvalues mask.
Construct the mask that fills a certain time window with missing values
@@ -590,15 +601,15 @@ def mask_fillvalues(products,
products: iris.cube.Cube
data products to be masked.
threshold_fraction: float
threshold_fraction:
fractional threshold to be used as argument for Aggregator.
Must be between 0 and 1.
min_value: float
min_value:
minimum value threshold; default None
If default, no thresholding applied so the full mask will be selected.
time_window: float
time_window:
time window to compute missing data counts; default set to 1.
Returns
@@ -611,48 +622,58 @@ def mask_fillvalues(products,
NotImplementedError
Implementation missing for data with higher dimensionality than 4.
"""
combined_mask = None
array_module = da if any(c.has_lazy_data() for p in products
for c in p.cubes) else np

logger.debug("Creating fillvalues mask")
used = set()
combined_mask = None
for product in products:
for cube in product.cubes:
cube.data = np.ma.fix_invalid(cube.data, copy=False)
mask = _get_fillvalues_mask(cube, threshold_fraction, min_value,
time_window)
for i, cube in enumerate(product.cubes):
cube = cube.copy()
product.cubes[i] = cube
cube.data = array_module.ma.fix_invalid(cube.core_data())
mask = _get_fillvalues_mask(
cube,
threshold_fraction,
min_value,
time_window,
)
if combined_mask is None:
combined_mask = np.zeros_like(mask)
combined_mask = array_module.zeros_like(mask)
# Select only valid (not all masked) pressure levels
n_dims = len(mask.shape)
if n_dims == 2:
valid = ~np.all(mask)
if valid:
combined_mask |= mask
used.add(product)
elif n_dims == 3:
valid = ~np.all(mask, axis=(1, 2))
combined_mask[valid] |= mask[valid]
if np.any(valid):
used.add(product)
if mask.ndim in (2, 3):
valid = ~mask.all(axis=(-2, -1), keepdims=True)
else:
raise NotImplementedError(
f"Unable to handle {n_dims} dimensional data"
f"Unable to handle {mask.ndim} dimensional data"
)
combined_mask = array_module.where(
valid,
combined_mask | mask,
combined_mask,
)

if np.any(combined_mask):
logger.debug("Applying fillvalues mask")
used = {p.copy_provenance() for p in used}
for product in products:
for cube in product.cubes:
cube.data.mask |= combined_mask
for other in used:
if other.filename != product.filename:
product.wasderivedfrom(other)
for product in products:
for cube in product.cubes:
array = cube.core_data()
data = array_module.ma.getdata(array)
mask = array_module.ma.getmaskarray(array) | combined_mask
cube.data = array_module.ma.masked_array(data, mask)

# Record provenance
input_products = {p.copy_provenance() for p in products}
for other in input_products:
if other.filename != product.filename:
product.wasderivedfrom(other)

return products
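The per-dataset masks are now combined with array_module.where instead of in-place boolean indexing, which works for eager and lazy arrays alike: a pressure level that is masked everywhere in one dataset is skipped rather than poisoning the combined mask. A small eager illustration of that combining step (values invented):

    import numpy as np

    combined = np.zeros((2, 2, 2), dtype=bool)
    mask = np.array(
        [[[True, True], [True, True]],        # level 0: fully masked
         [[True, False], [False, False]]],    # level 1: partially masked
    )
    # Keep only levels that are not masked everywhere (shape (2, 1, 1)).
    valid = ~mask.all(axis=(-2, -1), keepdims=True)
    combined = np.where(valid, combined | mask, combined)
    # Level 0 is ignored, level 1 contributes its mask.
    print(combined[0].any())  # False
    print(combined[1])        # [[ True False]
                              #  [False False]]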


def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
def _get_fillvalues_mask(
cube: iris.cube.Cube,
threshold_fraction: float,
min_value: float | None,
time_window: int,
) -> np.ndarray | da.Array:
"""Compute the per-model missing values mask.
Construct the mask that fills a certain time window with missing
@@ -662,7 +683,6 @@ def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
counts the number of valid (unmasked) data points within that
window; a simple value thresholding is also applied if needed.
"""
# basic checks
if threshold_fraction < 0 or threshold_fraction > 1.0:
raise ValueError(
f"Fraction of missing values {threshold_fraction} should be "
@@ -678,19 +698,24 @@ def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
counts_threshold = int(max_counts_per_time_window * threshold_fraction)

# Make an aggregator
spell_count = Aggregator('spell_count',
count_spells,
units_func=lambda units: 1)
spell_count = Aggregator(
'spell_count',
count_spells,
lazy_func=count_spells,
units_func=lambda units: 1,
)

# Calculate the statistic.
counts_windowed_cube = cube.collapsed('time',
spell_count,
threshold=min_value,
spell_length=time_window)
counts_windowed_cube = cube.collapsed(
'time',
spell_count,
threshold=min_value,
spell_length=time_window,
)

# Create mask
mask = counts_windowed_cube.data < counts_threshold
if np.ma.isMaskedArray(mask):
mask = mask.data | mask.mask
mask = counts_windowed_cube.core_data() < counts_threshold
array_module = da if isinstance(mask, da.Array) else np
mask = array_module.ma.filled(mask, True)

return mask
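Registering count_spells as both call_func and lazy_func of the iris Aggregator, together with thresholding core_data() instead of realised data, is what keeps the whole mask computation lazy. A rough usage sketch (not from the commit; the cube and values are made up, and it assumes the iris >= 3.9 pin introduced by this change):

    import dask.array as da
    import iris.cube
    from iris.analysis import Aggregator
    from iris.coords import DimCoord
    from esmvalcore.preprocessor._mask import count_spells

    spell_count = Aggregator(
        'spell_count',
        count_spells,
        lazy_func=count_spells,
        units_func=lambda units: 1,
    )
    time = DimCoord([0.0, 1.0, 2.0, 3.0], standard_name='time',
                    units='days since 2000-01-01')
    cube = iris.cube.Cube(
        da.from_array([0.0, 3.0, 4.0, 5.0], chunks=2),
        dim_coords_and_dims=[(time, 0)],
    )
    # Because the aggregator has a lazy_func and the cube data is lazy, the
    # collapsed result should stay lazy as well.
    result = cube.collapsed('time', spell_count, threshold=2.0, spell_length=1)
    print(cube.has_lazy_data(), result.has_lazy_data())  # True True
    print(result.data)  # 3 time steps exceed the threshold of 2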
4 changes: 1 addition & 3 deletions setup.py
@@ -56,9 +56,7 @@
'pyyaml',
'requests',
'scipy>=1.6',
# See the following issue for info on the iris pin below:
# https://github.com/ESMValGroup/ESMValCore/issues/2407
'scitools-iris>3.8.0',
'scitools-iris>=3.9.0',
'shapely>=2.0.0',
'stratify>=0.3',
'yamale',
35 changes: 32 additions & 3 deletions tests/integration/preprocessor/_mask/test_mask.py
@@ -182,13 +182,17 @@ def test_mask_landseaice(self):
np.ma.set_fill_value(expected, 1e+20)
assert_array_equal(result_ice.data, expected)

def test_mask_fillvalues(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues(self, mocker, lazy):
"""Test the fillvalues mask: func mask_fillvalues."""
data_1 = data_2 = self.mock_data
data_2.mask = np.ones((4, 3, 3), bool)
coords_spec = [(self.times, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()
filename_1 = 'file1.nc'
filename_2 = 'file2.nc'
product_1 = mocker.create_autospec(
@@ -215,10 +219,17 @@ def test_mask_fillvalues(self, mocker):
result_1 = product.cubes[0]
if product.filename == filename_2:
result_2 = product.cubes[0]

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert result_1.has_lazy_data() == lazy
assert result_2.has_lazy_data() == lazy

assert_array_equal(result_2.data.mask, data_2.mask)
assert_array_equal(result_1.data, data_1)

def test_mask_fillvalues_zero_threshold(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues_zero_threshold(self, mocker, lazy):
"""Test the fillvalues mask: func mask_fillvalues for 0-threshold."""
data_1 = self.mock_data
data_2 = self.mock_data[0:3]
@@ -232,6 +243,10 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
coords_spec2 = [(self.time2, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec2)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()

filename_1 = Path('file1.nc')
filename_2 = Path('file2.nc')
product_1 = mocker.create_autospec(
@@ -255,6 +270,12 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
result_1 = product.cubes[0]
if product.filename == filename_2:
result_2 = product.cubes[0]

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert result_1.has_lazy_data() == lazy
assert result_2.has_lazy_data() == lazy

# identical masks
assert_array_equal(
result_2.data[0, ...].mask,
@@ -265,7 +286,8 @@ def test_mask_fillvalues_min_value_none(self, mocker):
assert_array_equal(result_1[1:2].data.mask, cumulative_mask)
assert_array_equal(result_2[2:3].data.mask, cumulative_mask)

def test_mask_fillvalues_min_value_none(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues_min_value_none(self, mocker, lazy):
"""Test ``mask_fillvalues`` for min_value=None."""
# We use non-masked data here and explicitly set some values to 0 here
# since this caused problems in the past, see
@@ -278,6 +300,10 @@ def test_mask_fillvalues_min_value_none(self, mocker):
coords_spec2 = [(self.time2, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec2)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()

filename_1 = Path('file1.nc')
filename_2 = Path('file2.nc')

@@ -303,10 +329,13 @@ def test_mask_fillvalues_min_value_none(self, mocker):
min_value=None,
)

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert len(results) == 2
for product in results:
if product.filename in (filename_1, filename_2):
assert len(product.cubes) == 1
assert product.cubes[0].has_lazy_data() == lazy
assert not np.ma.is_masked(product.cubes[0].data)
else:
assert False, f"Invalid filename: {product.filename}"
1 change: 1 addition & 0 deletions tests/unit/preprocessor/_mask/test_mask.py
@@ -5,6 +5,7 @@
import numpy as np

import iris
import iris.fileformats
import tests
from cf_units import Unit
from esmvalcore.preprocessor._mask import (_apply_fx_mask,
