Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/earthkit/data/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from earthkit.data.core.order import Patch, Remapping, build_remapping
from earthkit.data.decorators import normalise
from earthkit.data.utils.args import metadata_argument_new
from earthkit.data.utils.array import flatten_array, reshape_array, target_shape
from earthkit.data.utils.array import flatten_array, outer_indexing, reshape_array, target_shape
from earthkit.data.utils.compute import wrap_maths

GRIB = "grib"
Expand Down Expand Up @@ -485,7 +485,7 @@ def to_numpy(self, flatten=False, dtype=None, copy=True, index=None):
v = flatten_array(v) if flatten else reshape_array(v, self.shape)

if index is not None:
v = v[index]
v = outer_indexing(v, index)

return v

Expand Down Expand Up @@ -531,7 +531,7 @@ def to_array(

v = flatten_array(v) if flatten else reshape_array(v, self.shape)
if index is not None:
v = v[index]
v = outer_indexing(v, index)

return v

Expand Down Expand Up @@ -615,7 +615,7 @@ def _reshape(v, flatten):
raise ValueError(f"data: {k} not available")
v = _reshape(v, flatten)
if index is not None:
v = v[index]
v = outer_indexing(v, index)
r[k] = v

# convert latlon to array format
Expand Down
1 change: 1 addition & 0 deletions src/earthkit/data/data/wrappers/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# nor does it submit to any jurisdiction.
#


from . import ObjectWrapperData


Expand Down
34 changes: 18 additions & 16 deletions src/earthkit/data/indexing/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,13 @@ def _prepare_tensor_data(self, source_to_array_func, index=None):
# * `field_shape` does lose the dimension `dim` if `dim` is a field dimension
current_field_shape = []
for n, _idx in zip(self.field_shape, index):
if isinstance(_idx, int):
# simply, ignore this index
_sizes = np.arange(n)[_idx].shape
if len(_sizes) == 0:
# _idx is a scalar indexer, and thus we ignore it
continue
if isinstance(_idx, slice):
_size = len(range(n)[_idx])
else:
# _idx must be an iterable of integers
_size = len(np.arange(n)[_idx])
# _idx is a slice, an array of int's or a boolean mask
# get the size of the selection made by the indexer _idx
(_size,) = _sizes
current_field_shape.append(_size)
current_field_shape = tuple(current_field_shape)

Expand Down Expand Up @@ -461,13 +460,16 @@ def field_indexes(self, indexes):
return indexes[len(self._user_shape) :]

def is_full_field(self, indexes):
assert len(indexes) == len(self._field_shape)
for i, s in enumerate(indexes):
if not (
s is None
or isinstance(s, slice)
and (s == slice(None, None, None) or s == slice(0, self._field_shape[i], 1))
):
def is_full_indexer(index, size):
if index is None:
return True
if isinstance(index, slice):
return index.indices(size) == (0, size, 1)
full_indexer = np.arange(size)
return np.array_equal(full_indexer[index], full_indexer)

for index, size in zip(indexes, self._field_shape, strict=True):
if not is_full_indexer(index, size):
return False
return True

Expand All @@ -481,7 +483,7 @@ def _subset(self, indexes):
user_indexes = []

for s, c in zip(indexes, self._user_shape):
lst = np.array(list(range(c)))[s].tolist()
lst = np.arange(c)[s].tolist()
if not isinstance(lst, list):
lst = [lst]
user_coords.append(lst)
Expand Down Expand Up @@ -683,7 +685,7 @@ def _subset(self, indexes):
user_indexes = []

for s, c in zip(indexes, self._user_shape):
lst = np.array(list(range(c)))[s].tolist()
lst = np.arange(c)[s].tolist()
if not isinstance(lst, list):
lst = [lst]
user_icoords.append(lst)
Expand Down
39 changes: 39 additions & 0 deletions src/earthkit/data/utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,42 @@ def adjust_array(v, flatten=False, dtype=None):
v = target_xp.astype(v, target_dtype, copy=False)

return v


def outer_indexing(v, indices):
    """Perform an outer (orthogonal) indexing of an array ``v``.

    The parameter ``indices`` is a tuple of indices (a single non-tuple index
    is treated as a one-element tuple). Each index is either an int, a slice
    or an array of int's. Each index from ``indices`` restricts/sub-selects
    the corresponding dimension of the array ``v``, independently
    (orthogonally) to the other indices (hence the name: "outer indexing").

    A single ``Ellipsis`` may appear in ``indices``; it stands for as many
    full slices as are needed to cover the dimensions not addressed by the
    remaining indices.

    This function mimics the behaviour of ``xarray.DataArray(v)[indices].values``,
    and in general it is different from the behaviour of e.g.
    the numpy indexing, i.e. ``v[indices]``
    (see https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing).

    Parameters
    ----------
    v: array-like
        The array to be indexed. It must provide ``ndim`` and support
        ``__getitem__`` with a tuple of indexers.
    indices: tuple of items, each being an int, a slice or an array-like of int's

    Returns
    -------
    array-like
        Sub-selection of the array ``v`` according to ``indices``

    Raises
    ------
    IndexError
        If ``indices`` contains more than one ``Ellipsis``.
    """
    if not isinstance(indices, tuple):
        indices = (indices,)

    # Locate Ellipsis entries by identity: comparing with ``==`` would be
    # evaluated element-wise when an index is an array.
    ellipsis_positions = [i for i, idx in enumerate(indices) if idx is Ellipsis]
    if len(ellipsis_positions) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    if ellipsis_positions:
        pos = ellipsis_positions[0]
        # The Ellipsis covers every dimension not addressed by another index.
        n_missing = max(v.ndim - (len(indices) - 1), 0)
        indices = indices[:pos] + (slice(None),) * n_missing + indices[pos + 1 :]

    # ``axis`` is the position, in the *current* ``v``, of the dimension the
    # next index applies to. Indices are applied one at a time so that they
    # never interact the way numpy's advanced indexing makes them.
    axis = 0
    for idx in indices:
        ndim_before = v.ndim
        v = v[(slice(None),) * axis + (idx,)]
        if v.ndim == ndim_before:
            # the dimension survived, so the next index targets the next axis
            axis += 1
        # otherwise the current dimension has collapsed (scalar index) and
        # the next dimension has moved into position ``axis``
    return v
118 changes: 68 additions & 50 deletions tests/xr_engine/test_xr_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,35 +216,44 @@ def test_xr_engine_detailed_check_1(allow_holes, lazy_load, api):
vals_ref = np.array([[269.00918579, 268.78610229], [268.57771301, 268.08932495]])
assert np.allclose(r.values, vals_ref)

r = da[:, 0, :, 2, 9:12, :2]
assert r.shape == (2, 2, 3, 2)
vals_ref = np.array([
[
[
[269.00918579, 269.31680298],
[269.70254517, 269.81387329],
[267.50527954, 266.83828735],
],
[
[268.78610229, 268.80758667],
[269.52731323, 269.75680542],
[266.61813354, 267.12106323],
],
],
[
r = da[:, 0, :, 2, [9], [True] + [False] * 35]
assert r.shape == (2, 2, 1, 1)
vals_ref = np.array([[[[269.00918579]], [[268.78610229]]], [[[268.57771301]], [[268.08932495]]]])
assert np.allclose(r.values, vals_ref)

for r in (
da[:, 0, :, 2, 9:12, :2],
da[:, 0, :, 2, [9, 10, 11], [0, 1]],
da[:, 0, :, 2, [9, 10, 11], [True, True] + [False] * 34],
):
assert r.shape == (2, 2, 3, 2)
vals_ref = np.array([
[
[268.57771301, 269.03767395],
[269.33357239, 269.56111145],
[264.75154114, 266.55036926],
[
[269.00918579, 269.31680298],
[269.70254517, 269.81387329],
[267.50527954, 266.83828735],
],
[
[268.78610229, 268.80758667],
[269.52731323, 269.75680542],
[266.61813354, 267.12106323],
],
],
[
[268.08932495, 268.35983276],
[269.01803589, 269.02389526],
[264.29733276, 266.08248901],
[
[268.57771301, 269.03767395],
[269.33357239, 269.56111145],
[264.75154114, 266.55036926],
],
[
[268.08932495, 268.35983276],
[269.01803589, 269.02389526],
[264.29733276, 266.08248901],
],
],
],
])
assert np.allclose(r.values, vals_ref)
])
assert np.allclose(r.values, vals_ref)

r = da.loc[:, 0, :, 500, 0, 0]
assert r.shape == (2, 2)
Expand Down Expand Up @@ -408,35 +417,44 @@ def test_xr_engine_detailed_check_2(allow_holes, lazy_load, api):
vals_ref = np.array([[269.00918579, 268.78610229], [268.57771301, 268.08932495]])
assert np.allclose(r.values, vals_ref)

r = da[:, 0, :, 2, 9:12, :2]
assert r.shape == (2, 2, 3, 2)
vals_ref = np.array([
[
[
[269.00918579, 269.31680298],
[269.70254517, 269.81387329],
[267.50527954, 266.83828735],
],
[
[268.78610229, 268.80758667],
[269.52731323, 269.75680542],
[266.61813354, 267.12106323],
],
],
[
r = da[:, 0, :, 2, [9], [True] + [False] * 35]
assert r.shape == (2, 2, 1, 1)
vals_ref = np.array([[[[269.00918579]], [[268.78610229]]], [[[268.57771301]], [[268.08932495]]]])
assert np.allclose(r.values, vals_ref)

for r in (
da[:, 0, :, 2, 9:12, :2],
da[:, 0, :, 2, [9, 10, 11], [0, 1]],
da[:, 0, :, 2, [9, 10, 11], [True, True] + [False] * 34],
):
assert r.shape == (2, 2, 3, 2)
vals_ref = np.array([
[
[268.57771301, 269.03767395],
[269.33357239, 269.56111145],
[264.75154114, 266.55036926],
[
[269.00918579, 269.31680298],
[269.70254517, 269.81387329],
[267.50527954, 266.83828735],
],
[
[268.78610229, 268.80758667],
[269.52731323, 269.75680542],
[266.61813354, 267.12106323],
],
],
[
[268.08932495, 268.35983276],
[269.01803589, 269.02389526],
[264.29733276, 266.08248901],
[
[268.57771301, 269.03767395],
[269.33357239, 269.56111145],
[264.75154114, 266.55036926],
],
[
[268.08932495, 268.35983276],
[269.01803589, 269.02389526],
[264.29733276, 266.08248901],
],
],
],
])
assert np.allclose(r.values, vals_ref)
])
assert np.allclose(r.values, vals_ref)

r = da.loc[:, datetime.time(0, 0), :, 500, 0, 0]
assert r.shape == (2, 2)
Expand Down
23 changes: 23 additions & 0 deletions tests/xr_engine/test_xr_engine_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,26 @@ def test_xr_engine_daily_mean(allow_holes, kwargs):
ds.resample({"valid_time": "24h"}).mean().groupby("valid_time.month") - monthly_mean_ds
)
assert np.allclose(np.abs(daily_anomaly_from_monthly_mean_ds).max()["2t"].values, 28.98466590143022)


@pytest.mark.cache
@pytest.mark.long_test
@pytest.mark.timeout(60)
def test_xr_engine_groupby_forecast_valid_time():
    """Group a forecast-mode dataset by ``valid_time.day``, with and without a lat/lon sub-selection.

    The GRIB file is fetched through the ``url`` source, converted to xarray with
    ``time_dim_mode="forecast"`` and an added ``valid_time`` coordinate, and the
    ``2t`` variable is renamed to ``t2m`` before the checks.
    """
    url = "https://sites.ecmwf.int/repository/earthkit-data/test-data/seas5_2m_temperature_201501-201503_europe_1deg.grib"
    fieldlist = from_source("url", url).to_fieldlist()
    ds = fieldlist.to_xarray(time_dim_mode="forecast", add_valid_time_coord=True).rename({"2t": "t2m"})

    # daily mean over the whole domain
    daily_mean = ds.groupby("valid_time.day").mean()
    assert np.allclose(daily_mean.max()["t2m"].values, 288.32264709)

    # daily mean after selecting a few latitudes/longitudes with list indexers
    subset = ds.sel(latitude=[50, 55, 60], longitude=[-10, 0, 10, 20])
    subset_daily_mean = subset.groupby("valid_time.day").mean()
    assert np.allclose(subset_daily_mean["latitude"].values, [50, 55, 60])
    assert np.allclose(subset_daily_mean["longitude"].values, [-10, 0, 10, 20])
    assert np.allclose(subset_daily_mean.max()["t2m"].values, 283.53151703)
Loading