Skip to content

Commit 7f770d5

Browse files
authored
Merge pull request #46 from ecmwf/43-split-_reduce_dataarray-into-2-functions
43 split reduce dataarray into 2 functions
2 parents f68cb77 + 4264359 commit 7f770d5

3 files changed

Lines changed: 197 additions & 58 deletions

File tree

src/earthkit/transforms/aggregate/general.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def _rolling_reduce_dataarray(
230230
data_rolling = dataarray.rolling(**rolling_kwargs)
231231

232232
reduce_kwargs.setdefault("how", how_reduce)
233+
# TODO: remove type ignore when xarray puts types in stable location
233234
data_windowed = _reduce_dataarray(data_rolling, **reduce_kwargs) # type: ignore
234235

235236
data_windowed = _dropna(data_windowed, window_dims, how_dropna)

src/earthkit/transforms/aggregate/spatial.py

Lines changed: 98 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def mask(
365365
if union_geometries:
366366
out = masked_arrays[0]
367367
else:
368+
# TODO: remove type ignore if xarray concat typing is updated
368369
out = xr.concat(masked_arrays, dim=mask_dim_index.name) # type: ignore
369370
if chunk:
370371
out = out.chunk({mask_dim_index.name: 1})
@@ -418,6 +419,10 @@ def reduce(
418419
They must be on the same spatial grid as the dataarray.
419420
return_as :
420421
Format in which to return the data object: `pandas` or `xarray`. Work in progress.
422+
compact :
423+
If True, return a compact pandas.DataFrame with the reduced data as a new column.
424+
If False, return a fully expanded pandas.DataFrame.
425+
Only valid if return_as is `pandas`
421426
how_label :
422427
label to append to variable name in returned object, default is not to append
423428
kwargs :
@@ -443,37 +448,39 @@ def reduce(
443448
if return_as in ["xarray"]:
444449
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
445450
for var in dataarray.data_vars:
446-
out_da = _reduce_dataarray(
451+
out_da = _reduce_dataarray_as_xarray(
447452
dataarray[var], geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs
448453
)
449454
out_ds[out_da.name] = out_da
450455
return out_ds
451456
elif "pandas" in return_as:
452457
logger.warning(
453-
"Returning reduced data in pandas format is considered experimental and may change in future"
458+
"Returning reduced data in pandas format is considered "
459+
"experimental and may change in future"
454460
"versions of earthkit"
455461
)
456462
if geodataframe is not None:
457463
out = geodataframe
458464
for var in dataarray.data_vars:
459-
out = _reduce_dataarray(dataarray[var], geodataframe=out, **kwargs)
465+
out = _reduce_dataarray_as_pandas(dataarray[var], geodataframe=out, **kwargs)
460466
else:
461467
out = None
462468
for var in dataarray.data_vars:
463-
_out = _reduce_dataarray(dataarray[var], mask_arrays=_mask_arrays, **kwargs)
469+
_out = _reduce_dataarray_as_pandas(dataarray[var], mask_arrays=_mask_arrays, **kwargs)
464470
if out is None:
465471
out = _out
466472
else:
467-
out = pd.merge(out, _out) # type: ignore
473+
out = pd.merge(out, _out)
468474
return out
469475
else:
470476
raise TypeError("Return as type not recognised or incompatible with inputs")
471477
else:
472-
return _reduce_dataarray(dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs) # type: ignore
478+
return _reduce_dataarray_as_xarray(
479+
dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs
480+
)
473481

474482

475-
# TODO: split into two functions, one for xarray and one for pandas
476-
def _reduce_dataarray(
483+
def _reduce_dataarray_as_xarray(
477484
dataarray: xr.DataArray,
478485
geodataframe: gpd.GeoDataFrame | None = None,
479486
mask_arrays: list[xr.DataArray] | None = None,
@@ -483,14 +490,13 @@ def _reduce_dataarray(
483490
lon_key: str | None = None,
484491
extra_reduce_dims: list | str = [],
485492
mask_dim: str | None = None,
486-
return_as: str = "xarray",
487493
how_label: str | None = None,
488494
squeeze: bool = True,
489495
all_touched: bool = False,
490496
mask_kwargs: dict[str, T.Any] = dict(),
491497
return_geometry_as_coord: bool = False,
492498
**reduce_kwargs,
493-
) -> xr.DataArray | pd.DataFrame:
499+
) -> xr.DataArray:
494500
"""Reduce an xarray.DataArray object over its geospatial dimensions using the specified 'how' method.
495501
496502
If a geodataframe is provided the DataArray is reduced over each feature in the geodataframe.
@@ -538,7 +544,7 @@ def _reduce_dataarray(
538544
539545
Returns
540546
-------
541-
xr.Dataset | xr.DataArray | pd.DataFrame
547+
xr.DataArray
542548
A data array with dimensions [features] + [data.dims not in ['lat','lon']].
543549
Each slice of layer corresponds to a feature in layer
544550
@@ -576,6 +582,7 @@ def _reduce_dataarray(
576582
comp for comp in [how_str, dataarray.attrs.get("long_name", dataarray.name)] if comp is not None
577583
]
578584
new_long_name = " ".join(new_long_name_components)
585+
extra_out_attrs.update({"long_name": new_long_name})
579586
new_short_name_components = [f"{comp}" for comp in [dataarray.name, how_label] if comp is not None]
580587
new_short_name = "_".join(new_short_name_components)
581588

@@ -596,7 +603,8 @@ def _reduce_dataarray(
596603
reduce_dims = spatial_dims + extra_reduce_dims
597604
extra_out_attrs.update({"reduce_dims": reduce_dims})
598605
reduce_kwargs.update({"dim": reduce_dims})
599-
# If using a pre-computed mask arrays, then iterator is just dataarray*mask_array
606+
# If using pre-computed mask arrays,
607+
# then iterator is just dataarray*mask_array
600608
if mask_arrays is not None:
601609
masked_data_list = _array_mask_iterator(mask_arrays)
602610
else:
@@ -610,7 +618,8 @@ def _reduce_dataarray(
610618
for masked_data in masked_data_list:
611619
this = dataarray.where(masked_data, other=np.nan)
612620

613-
# If weighted, use xarray weighted arrays which correctly handle missing values etc.
621+
# If weighted, use xarray weighted arrays which
622+
# correctly handle missing values etc.
614623
if weights is not None:
615624
this_weighted = this.weighted(_weights)
616625
reduced_list.append(this_weighted.__getattribute__(weighted_how)(**reduce_kwargs))
@@ -626,59 +635,91 @@ def _reduce_dataarray(
626635
if geodataframe is not None:
627636
mask_dim_index = get_mask_dim_index(mask_dim, geodataframe)
628637
out_xr = xr.concat(reduced_list, dim=mask_dim_index)
629-
elif len(reduced_list) == 1:
638+
elif mask_dim is None and len(reduced_list) == 1:
630639
out_xr = reduced_list[0]
631640
else:
632641
_concat_dim_name = mask_dim or "index"
633642
out_xr = xr.concat(reduced_list, dim=_concat_dim_name)
634643

635644
out_xr = out_xr.rename(new_short_name)
645+
if geodataframe is not None:
646+
if return_geometry_as_coord:
647+
out_xr = out_xr.assign_coords(
648+
**{"geometry": (mask_dim_index.name, [_g for _g in geodataframe["geometry"]])}
649+
)
650+
out_xr = out_xr.assign_attrs({**geodataframe.attrs, **extra_out_attrs})
636651

637-
if "pandas" in return_as:
638-
reduce_attrs = {
639-
f"{dataarray.name}": dataarray.attrs,
640-
f"{new_short_name}": {
641-
"long_name": new_long_name,
642-
"units": dataarray.attrs.get("units", "No units found"),
643-
**extra_out_attrs,
644-
},
645-
}
652+
return out_xr
646653

647-
if geodataframe is None:
648-
# If no geodataframe, then just convert xarray to dataframe
649-
out = out_xr.to_dataframe()
650-
else:
651-
# Otherwise, splice the geodataframe and reduced xarray
652-
reduce_attrs = {
653-
**geodataframe.attrs.get("reduce_attrs", {}),
654-
**reduce_attrs,
655-
}
656-
out = geodataframe.set_index(mask_dim_index)
657-
if return_as in ["pandas"]: # Return as a fully expanded pandas.DataFrame
658-
# Convert to DataFrame
659-
out = out.join(out_xr.to_dataframe())
660-
elif return_as in ["pandas_compact"]:
661-
# add the reduced data into a new column as a numpy array,
662-
# store the dim information in the attributes
663-
664-
# TODO: fix typing
665-
out_dims = {
666-
dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None # type: ignore
667-
for dim in reduced_list[0].dims
668-
}
669-
reduce_attrs[f"{new_short_name}"].update({"dims": out_dims})
670-
reduced_list = [red.values for red in reduced_list]
671-
out = out.assign(**{new_short_name: reduced_list})
654+
655+
def _reduce_dataarray_as_pandas(
656+
dataarray: xr.DataArray, geodataframe: gpd.GeoDataFrame | None = None, compact: bool = False, **kwargs
657+
) -> pd.DataFrame:
658+
"""Reduce an xarray.DataArray object over its geospatial dimensions using the specified 'how' method.
659+
660+
If a geodataframe is provided the DataArray is reduced over each feature in the geodataframe.
661+
Geospatial coordinates are reduced to a dimension representing the list of features in the shape object.
662+
663+
Parameters
664+
----------
665+
dataarray :
666+
Xarray data object (must have geospatial coordinates).
667+
geodataframe :
668+
Geopandas Dataframe containing the polygons for aggregations
669+
compact :
670+
If True, return a compact pandas.DataFrame with the reduced data as a new column.
671+
If False, return a fully expanded pandas.DataFrame.
672+
kwargs :
673+
kwargs accepted by the :func:`_reduce_dataarray_as_xarray` function
674+
675+
Returns
676+
-------
677+
pd.DataFrame
678+
A pandas.DataFrame similar to the geopandas dataframe, with the reduced data
679+
added as a new column.
680+
681+
"""
682+
out_xr = _reduce_dataarray_as_xarray(dataarray, **kwargs)
683+
684+
reduce_attrs = {f"{dataarray.name}": dataarray.attrs, f"{out_xr.name}": out_xr.attrs}
685+
686+
if geodataframe is None:
687+
mask_dim = kwargs.get("mask_dim", "index")
688+
if mask_dim not in out_xr.dims:
689+
out_xr = xr.concat([out_xr], dim=mask_dim)
690+
# If no geodataframe, then just convert xarray to dataframe
691+
out = out_xr.to_dataframe()
692+
# Add attributes to the dataframe
672693
out.attrs.update({"reduce_attrs": reduce_attrs})
694+
return out
695+
696+
# Otherwise, splice the geodataframe and reduced xarray
697+
reduce_attrs = {
698+
**geodataframe.attrs.get("reduce_attrs", {}),
699+
**reduce_attrs,
700+
}
701+
702+
# TODO: somehow remove repeated call of get_mask_dim_index (see _reduce_dataarray_as_xarray)
703+
mask_dim_index = get_mask_dim_index(kwargs.get("mask_dim"), geodataframe)
704+
mask_dim_name = mask_dim_index.name
705+
out = geodataframe.set_index(mask_dim_index)
706+
if mask_dim_name not in out_xr.dims:
707+
out_xr = xr.concat([out_xr], dim=mask_dim_name)
708+
if not compact: # Return as a fully expanded pandas.DataFrame
709+
# Convert to DataFrame
710+
out = out.join(out_xr.to_dataframe())
673711
else:
674-
if geodataframe is not None:
675-
if return_geometry_as_coord:
676-
out_xr = out_xr.assign_coords(
677-
**{
678-
"geometry": (mask_dim, [geom for geom in geodataframe["geometry"]]),
679-
}
680-
)
681-
out_xr = out_xr.assign_attrs({**geodataframe.attrs, **extra_out_attrs})
682-
out = out_xr
712+
# add the reduced data into a new column as a numpy array,
713+
# store the dim information in the attributes
714+
_out_dims = [str(dim) for dim in dataarray.coords if dim in out_xr.dims]
715+
out_dims = {dim: dataarray[dim].values for dim in _out_dims}
716+
reduce_attrs[f"{out_xr.name}"].update({"dims": out_dims})
717+
reduced_list = [
718+
out_xr.sel(**{mask_dim_name: mask_dim_value}).values
719+
for mask_dim_value in out_xr[mask_dim_name].values
720+
]
721+
out = out.assign(**{f"{out_xr.name}": reduced_list})
722+
723+
out.attrs.update({"reduce_attrs": reduce_attrs})
683724

684725
return out

tests/test_30_spatial.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import geopandas as gpd
12
import numpy as np
23
import pandas as pd
34
import pytest
5+
import xarray as xr
46

57
# from earthkit.data.core.temporary import temp_directory
6-
import xarray as xr
78
from earthkit import data as ek_data
89
from earthkit.data.testing import earthkit_remote_test_data_file
910
from earthkit.transforms.aggregate import spatial
11+
from shapely.geometry import Polygon
1012

1113
try:
1214
import rasterio # noqa: F401
@@ -198,3 +200,98 @@ def test_mask_kwargs():
198200
reduced_data_nested_2 = spatial.reduce(era5_xr, nuts_DK, mask_kwargs=dict(all_touched=False))
199201
xr.testing.assert_equal(reduced_data_2, reduced_data_nested_2)
200202
np.testing.assert_allclose(reduced_data_2["2t"].mean(), 279.54733)
203+
204+
205+
def create_test_dataarray():
206+
lat = np.linspace(-90, 90, 10)
207+
lon = np.linspace(-180, 180, 20)
208+
data = np.random.rand(10, 20)
209+
return xr.DataArray(
210+
data,
211+
coords={"lat": lat, "lon": lon},
212+
dims=("lat", "lon"),
213+
name="test_var",
214+
)
215+
216+
217+
def create_test_geodataframe():
218+
polygons = [Polygon([(-180, -90), (-180, 90), (180, 90), (180, -90)])]
219+
return gpd.GeoDataFrame(geometry=polygons, index=[1])
220+
221+
222+
def test_reduce_mean():
223+
dataarray = create_test_dataarray()
224+
result = spatial._reduce_dataarray_as_xarray(dataarray, how="mean")
225+
assert isinstance(result, xr.DataArray)
226+
assert "lat" not in result.dims
227+
assert "lon" not in result.dims
228+
229+
230+
def test_reduce_with_geodataframe():
231+
dataarray = create_test_dataarray()
232+
geodataframe = create_test_geodataframe()
233+
result = spatial._reduce_dataarray_as_xarray(dataarray, geodataframe=geodataframe, how="mean")
234+
assert isinstance(result, xr.DataArray)
235+
assert "index" in result.dims # Default mask_dim is "index"
236+
237+
238+
def test_reduce_with_weights():
239+
dataarray = create_test_dataarray()
240+
result = spatial._reduce_dataarray_as_xarray(dataarray, how="mean", weights="latitude")
241+
assert isinstance(result, xr.DataArray)
242+
243+
244+
def test_reduce_invalid_how():
245+
dataarray = create_test_dataarray()
246+
with pytest.raises(ValueError):
247+
spatial._reduce_dataarray_as_xarray(dataarray, how="invalid_method")
248+
249+
250+
def test_reduce_with_mask():
251+
dataarray = create_test_dataarray()
252+
mask = xr.DataArray(
253+
np.random.randint(0, 2, size=dataarray.shape), coords=dataarray.coords, dims=dataarray.dims
254+
)
255+
result = spatial._reduce_dataarray_as_xarray(dataarray, mask_arrays=[mask], how="sum")
256+
assert isinstance(result, xr.DataArray)
257+
258+
259+
def test_return_geometry_as_coord():
260+
dataarray = create_test_dataarray()
261+
geodataframe = create_test_geodataframe()
262+
result = spatial._reduce_dataarray_as_xarray(
263+
dataarray, geodataframe=geodataframe, return_geometry_as_coord=True
264+
)
265+
assert "geometry" in result.coords
266+
assert len(result.coords["geometry"].values) == len(geodataframe)
267+
268+
269+
def test_reduce_as_pandas():
270+
dataarray = create_test_dataarray()
271+
result = spatial._reduce_dataarray_as_pandas(dataarray, how="mean")
272+
assert isinstance(result, pd.DataFrame)
273+
274+
275+
def test_reduce_as_pandas_with_geodataframe():
276+
dataarray = create_test_dataarray()
277+
geodataframe = create_test_geodataframe()
278+
result = spatial._reduce_dataarray_as_pandas(dataarray, geodataframe=geodataframe, how="mean")
279+
assert isinstance(result, pd.DataFrame)
280+
assert not result.empty
281+
282+
283+
def test_reduce_as_pandas_compact():
284+
dataarray = create_test_dataarray()
285+
geodataframe = create_test_geodataframe()
286+
result = spatial._reduce_dataarray_as_pandas(
287+
dataarray, geodataframe=geodataframe, compact=True, how="mean"
288+
)
289+
assert isinstance(result, pd.DataFrame)
290+
assert f"{dataarray.name}" in result.columns
291+
292+
293+
def test_reduce_as_pandas_without_geodataframe():
294+
dataarray = create_test_dataarray()
295+
result = spatial._reduce_dataarray_as_pandas(dataarray, how="sum")
296+
assert isinstance(result, pd.DataFrame)
297+
assert not result.empty

0 commit comments

Comments
 (0)