Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions test/core/test_isel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
import uxarray as ux
import numpy as np

@pytest.fixture()
def ds():
uxgrid = ux.Grid.from_healpix(zoom=1)
t_var = ux.UxDataArray(data=np.ones((3,)), dims=['time'], uxgrid=uxgrid)
fc_var = ux.UxDataArray(data=np.ones((3, uxgrid.n_face)), dims=['time', 'n_face'], uxgrid=uxgrid)
nc_var = ux.UxDataArray(data=np.ones((3, uxgrid.n_node)), dims=['time', 'n_node'], uxgrid=uxgrid)

uxds = ux.UxDataset({"fc": fc_var, "nc": nc_var, "t": t_var}, uxgrid=uxgrid)

uxds["fc"] = uxds["fc"].assign_coords(face_id=("n_face", np.arange(uxgrid.n_face)))
uxds["nc"] = uxds["nc"].assign_coords(node_id=("n_node", np.arange(uxgrid.n_node)))
uxds["t"] = uxds["t"].assign_coords(time_id=("time", np.arange(uxds.dims["time"])))


return uxds


class TestDataset:

def test_isel_face_dim(self, ds):
ds_f_single = ds.isel(n_face=0)

assert len(ds_f_single.coords) == 3

assert ds_f_single.uxgrid != ds.uxgrid
assert ds_f_single.sizes['n_face'] == 1
assert ds_f_single.sizes['n_node'] >= 4

ds_f_multi = ds.isel(n_face=[0, 1])

assert len(ds_f_multi.coords) == 3

assert ds_f_multi.uxgrid != ds.uxgrid
assert ds_f_multi.sizes['n_face'] == 2
assert ds_f_multi.sizes['n_node'] >= 4


def test_isel_node_dim(self, ds):
ds_n_single = ds.isel(n_node=0)

assert len(ds_n_single.coords) == 3

assert ds_n_single.uxgrid != ds.uxgrid
assert ds_n_single.sizes['n_face'] >= 1

ds_n_multi = ds.isel(n_node=[0, 1])

assert len(ds_n_multi.coords) == 3

assert ds_n_multi.uxgrid != ds.uxgrid
assert ds_n_multi.uxgrid.sizes['n_face'] >= 1

def test_isel_non_grid_dim(self, ds):
ds_t_single = ds.isel(time=0)

assert len(ds_t_single.coords) == 3

assert ds_t_single.uxgrid == ds.uxgrid
assert "time" not in ds_t_single.sizes

ds_t_multi = ds.isel(time=[0, 1])

assert len(ds_t_multi.coords) == 3

assert ds_t_multi.uxgrid == ds.uxgrid
assert ds_t_multi.sizes['time'] == 2
81 changes: 35 additions & 46 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ def isel(
**indexers_kwargs,
):
"""
Grid-aware index selection.
Return a new DataArray whose data is given by selecting indexes along the specified dimension(s).

Performs xarray-style integer-location indexing along specified dimensions.
If a single grid dimension ('n_node', 'n_edge', or 'n_face') is provided
Expand Down Expand Up @@ -1292,61 +1292,50 @@ def isel(

Raises
------
TypeError
If `indexers` is provided and is not a Mapping.
ValueError
If more than one grid dimension is selected and `ignore_grid=False`.
"""
from uxarray.constants import GRID_DIMS
from uxarray.core.dataarray import UxDataArray

# merge dict‐style + kw‐style indexers
idx_map = {}
if indexers is not None:
if not isinstance(indexers, dict):
raise TypeError("`indexers` must be a dict of dimension indexers")
idx_map.update(indexers)
idx_map.update(indexers_kwargs)

# detect grid dims
grid_dims = [
d
for d in GRID_DIMS
if d in idx_map
and not (isinstance(idx_map[d], slice) and idx_map[d] == slice(None))
]
from uxarray.core.utils import _validate_indexers

# Grid Branch
if not ignore_grid and len(grid_dims) == 1:
# pop off the one grid‐dim indexer
grid_dim = grid_dims[0]
grid_indexer = idx_map.pop(grid_dim)

# slice the grid
sliced_grid = self.uxgrid.isel(
**{grid_dim: grid_indexer}, inverse_indices=inverse_indices
)

da = self._slice_from_grid(sliced_grid)
indexers, grid_dims = _validate_indexers(
indexers, indexers_kwargs, "isel", ignore_grid
)

# if there are any remaining indexers, apply them
if idx_map:
xarr = super(UxDataArray, da).isel(
indexers=idx_map, drop=drop, missing_dims=missing_dims
# Grid Branch
if not ignore_grid:
if len(grid_dims) == 1:
# pop off the one grid‐dim indexer
grid_dim = grid_dims.pop()
grid_indexer = indexers.pop(grid_dim)

sliced_grid = self.uxgrid.isel(
**{grid_dim: grid_indexer}, inverse_indices=inverse_indices
)
# re‐wrap so the grid sticks around
return UxDataArray(xarr, uxgrid=sliced_grid)

# no other dims, return the grid‐sliced da
return da
da = self._slice_from_grid(sliced_grid)

# More than one grid dim provided
if not ignore_grid and len(grid_dims) > 1:
raise ValueError("Only one grid dimension can be sliced at a time")
# if there are any remaining indexers, apply them
if indexers:
xarr = super(UxDataArray, da).isel(
indexers=indexers, drop=drop, missing_dims=missing_dims
)
# re‐wrap so the grid sticks around
return type(self)(xarr, uxgrid=sliced_grid)

# no other dims, return the grid‐sliced da
return da
else:
return type(self)(
super().isel(
indexers=indexers or None,
drop=drop,
missing_dims=missing_dims,
),
uxgrid=self.uxgrid,
)

# Fallback to Xarray
return super().isel(
indexers=idx_map or None,
indexers=indexers or None,
drop=drop,
missing_dims=missing_dims,
)
Expand Down
126 changes: 125 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from html import escape
from typing import IO, Any, Optional, Union
from typing import IO, Any, Mapping, Optional, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -343,6 +343,130 @@ def from_healpix(

return cls.from_xarray(ds, uxgrid, {face_dim: "n_face"})

def _slice_dataset_from_grid(self, sliced_grid, grid_dim: str, grid_indexer):
data_vars = {}
for name, da in self.data_vars.items():
if grid_dim in da.dims:
if hasattr(da, "_slice_from_grid"):
data_vars[name] = da._slice_from_grid(sliced_grid)
else:
data_vars[name] = da.isel({grid_dim: grid_indexer})
else:
data_vars[name] = da

coords = {}
for cname, cda in self.coords.items():
if grid_dim in cda.dims:
# Prefer authoritative coords from the sliced grid if available
replacement = getattr(sliced_grid, cname, None)
coords[cname] = (
replacement
if replacement is not None
else cda.isel({grid_dim: grid_indexer})
)
else:
coords[cname] = cda

ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=self.attrs)

return ds

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
drop: bool = False,
missing_dims: str = "raise",
ignore_grid: bool = False,
inverse_indices: bool = False,
**indexers_kwargs,
):
"""Returns a new dataset with each array indexed along the specified
dimension(s).

Performs xarray-style integer-location indexing along specified dimensions.
If a single grid dimension ('n_node', 'n_edge', or 'n_face') is provided
and `ignore_grid=False`, the underlying grid is sliced accordingly,
and remaining indexers are applied to the resulting Dataset.

Parameters
----------
indexers : dict, optional
A dict with keys matching dimensions and values given
by integers, slice objects or arrays.
indexer can be a integer, slice, array-like or DataArray.
If DataArrays are passed as indexers, xarray-style indexing will be
carried out. See :ref:`indexing` for the details.
One of indexers or indexers_kwargs must be provided.
drop : bool, default: False
If ``drop=True``, drop coordinates variables indexed by integers
instead of making them scalar.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Dataset:
- "raise": raise an exception
- "warn": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
ignore_grid : bool, default=False
If False (default), allow slicing on one grid dimension to automatically
update the associated UXarray grid. If True, fall back to pure xarray behavior.
inverse_indices : bool, default=False
For grid-based slicing, pass this flag to `Grid.isel` to invert indices
when selecting (useful for staggering or reversing order).
**indexers_kwargs : dimension=indexer pairs, optional

**indexers_kwargs : {dim: indexer, ...}, optional
The keyword arguments form of ``indexers``.
One of indexers or indexers_kwargs must be provided.

Returns
-------
UxDataset
A new UxDataset indexed according to `indexers` and updated grid if applicable.
"""
from uxarray.core.utils import _validate_indexers

indexers, grid_dims = _validate_indexers(
indexers, indexers_kwargs, "isel", ignore_grid
)

if not ignore_grid:
if len(grid_dims) == 1:
grid_dim = grid_dims.pop()
grid_indexer = indexers.pop(grid_dim)

# slice the grid
sliced_grid = self.uxgrid.isel(
**{grid_dim: grid_indexer}, inverse_indices=inverse_indices
)

ds = self._slice_dataset_from_grid(
sliced_grid=sliced_grid,
grid_dim=grid_dim,
grid_indexer=grid_indexer,
)

if indexers:
ds = xr.Dataset.isel(
ds, indexers=indexers, drop=drop, missing_dims=missing_dims
)

return type(self)(ds, uxgrid=sliced_grid)
else:
return type(self)(
super().isel(
indexers=indexers or None,
drop=drop,
missing_dims=missing_dims,
),
uxgrid=self.uxgrid,
)

return super().isel(
indexers=indexers or None,
drop=drop,
missing_dims=missing_dims,
)

def __getattribute__(self, name):
"""Intercept accessor method calls to return Ux-aware accessors."""
# Lazy import to avoid circular imports
Expand Down
16 changes: 16 additions & 0 deletions uxarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,19 @@ def match_chunks_to_ugrid(grid_filename_or_obj, chunks):
chunks[original_grid_dim] = chunks[ugrid_grid_dim]

return chunks


def _validate_indexers(indexers, indexers_kwargs, func_name, ignore_grid):
from xarray.core.utils import either_dict_or_kwargs

from uxarray.constants import GRID_DIMS

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, func_name)
grid_dims = set(GRID_DIMS).intersection(indexers)

if not ignore_grid and len(grid_dims) > 1:
raise ValueError(
f"Only one grid dimension can be sliced at a time; got {sorted(grid_dims)}."
)

return indexers, grid_dims
Loading