Skip to content

Commit 070e916

Browse files
committed
Update test_rename_dims
1 parent 9e69e9f commit 070e916

File tree

1 file changed

+42
-61
lines changed

1 file changed

+42
-61
lines changed

tests/utils/test_sgrid.py

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import xarray as xr
44
import xgcm
55
from hypothesis import assume, example, given
6-
from hypothesis import strategies as st
76

87
from parcels._core.utils import sgrid
98
from tests.strategies import sgrid as sgrid_strategies
@@ -238,76 +237,58 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
238237
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
239238

240239

241-
@example(
242-
grids_and_dims_dict=(
243-
sgrid.Grid2DMetadata(
244-
cf_role="grid_topology",
245-
topology_dimension=2,
246-
node_dimensions=("node_dimension1", "node_dimension2"),
247-
face_dimensions=(
248-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
249-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
250-
),
251-
vertical_dimensions=(
252-
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
253-
),
254-
),
255-
sgrid.Grid2DMetadata(
256-
cf_role="grid_topology",
257-
topology_dimension=2,
258-
node_dimensions=("new_node_dimension1", "new_node_dimension2"),
259-
face_dimensions=(
260-
sgrid.DimDimPadding("new_face_dimension1", "new_node_dimension1", sgrid.Padding.LOW),
261-
sgrid.DimDimPadding("new_face_dimension2", "new_node_dimension2", sgrid.Padding.LOW),
262-
),
263-
vertical_dimensions=(
264-
sgrid.DimDimPadding("new_vertical_dimensions_dim1", "new_vertical_dimensions_dim2", sgrid.Padding.LOW),
265-
),
266-
),
267-
{
268-
"node_dimension1": "new_node_dimension1",
269-
"node_dimension2": "new_node_dimension2",
270-
"face_dimension1": "new_face_dimension1",
271-
"face_dimension2": "new_face_dimension2",
272-
"vertical_dimensions_dim1": "new_vertical_dimensions_dim1",
273-
"vertical_dimensions_dim2": "new_vertical_dimensions_dim2",
274-
},
275-
)
276-
)
277-
@given(
278-
grids_and_dims_dict=st.lists(sgrid_strategies.dimension_name, min_size=12, max_size=12, unique=True).map(
279-
lambda dims: (
240+
@pytest.mark.parametrize(
241+
"grid",
242+
[
243+
(
280244
sgrid.Grid2DMetadata(
281245
cf_role="grid_topology",
282246
topology_dimension=2,
283-
node_dimensions=(dims[0], dims[1]),
247+
node_dimensions=("node_dimension1", "node_dimension2"),
284248
face_dimensions=(
285-
sgrid.DimDimPadding(dims[2], dims[0], sgrid.Padding.LOW),
286-
sgrid.DimDimPadding(dims[3], dims[1], sgrid.Padding.LOW),
249+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
250+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
251+
),
252+
vertical_dimensions=(
253+
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
287254
),
288-
vertical_dimensions=(sgrid.DimDimPadding(dims[4], dims[5], sgrid.Padding.LOW),),
289-
),
255+
)
256+
),
257+
(
290258
sgrid.Grid2DMetadata(
291259
cf_role="grid_topology",
292260
topology_dimension=2,
293-
node_dimensions=(dims[6 + 0], dims[6 + 1]),
261+
node_dimensions=("node_dimension1", "node_dimension2"),
294262
face_dimensions=(
295-
sgrid.DimDimPadding(dims[6 + 2], dims[6 + 0], sgrid.Padding.LOW),
296-
sgrid.DimDimPadding(dims[6 + 3], dims[6 + 1], sgrid.Padding.LOW),
263+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
264+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
297265
),
298-
vertical_dimensions=(sgrid.DimDimPadding(dims[6 + 4], dims[6 + 5], sgrid.Padding.LOW),),
299-
),
300-
{k: v for k, v in zip(dims[:6], dims[6:], strict=True)},
301-
)
302-
)
266+
vertical_dimensions=None,
267+
)
268+
),
269+
(
270+
sgrid.Grid3DMetadata(
271+
cf_role="grid_topology",
272+
topology_dimension=3,
273+
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
274+
volume_dimensions=(
275+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
276+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
277+
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
278+
),
279+
)
280+
),
281+
],
303282
)
304-
def test_rename_dims(grids_and_dims_dict):
305-
"""Creates two SGrid 2D metadata objects with disjoint dimension names, and a mapping between the dimension names.
306-
Renames the dimensions of the first grid according to the mapping, and checks that the result
307-
is equal to the second grid.
308-
"""
309-
grid_old, grid_new, dims_dict = grids_and_dims_dict
310-
assert grid_old.rename_dims(dims_dict).to_attrs() == grid_new.to_attrs()
283+
def test_rename_dims(grid):
284+
dims = sgrid.get_unique_dim_names(grid)
285+
dims_dict = {dim: f"new_{dim}" for dim in dims}
286+
dims_dict_inv = {v: k for k, v in dims_dict.items()}
287+
288+
grid_new = grid.rename_dims(dims_dict)
289+
assert dims & set(sgrid.get_unique_dim_names(grid_new)) == set()
290+
291+
assert grid == grid_new.rename_dims(dims_dict_inv)
311292

312293

313294
def test_rename_dims_errors(grid2dmetadata):
@@ -318,7 +299,7 @@ def test_rename_dims_errors(grid2dmetadata):
318299
"node_dimension1": "new_node_dimension",
319300
"node_dimension2": "new_node_dimension",
320301
}
321-
with pytest.raises(AssertionError, match="dims_dict contains non-unique target dimension names"):
302+
with pytest.raises(AssertionError, match="dims_dict contains duplicate target dimension names"):
322303
grid.rename_dims(dims_dict)
323304
# Unexpected attribute in dims_dict
324305
dims_dict = {

0 commit comments

Comments
 (0)