@@ -26,21 +26,21 @@ def get_axis_dim(self, axis: str) -> int:
2626@pytest .mark .parametrize (
2727 "grid" ,
2828 [
29- TestGrid ({"Z" : [ 10 ] , "Y" : [ 20 ] , "X" : [ 30 ] }),
30- TestGrid ({"Z" : [ 5 ] , "Y" : [ 15 ] }),
31- TestGrid ({"Z" : [ 8 ] }),
32- TestGrid ({"Z" : [ 12 ] , "FACE" : [ 25 ] }),
29+ TestGrid ({"Z" : 10 , "Y" : 20 , "X" : 30 }),
30+ TestGrid ({"Z" : 5 , "Y" : 15 }),
31+ TestGrid ({"Z" : 8 }),
32+ TestGrid ({"Z" : 12 , "FACE" : 25 }),
3333 ],
3434)
3535def test_basegrid_ravel_unravel_index (grid ):
3636 axes = grid .axes
3737 dimensionalities = (grid .get_axis_dim (axis ) for axis in axes )
38- all_possible_axis_indices = itertools .product (* [range (dim [ 0 ]) for dim in dimensionalities ])
38+ all_possible_axis_indices = itertools .product (* [np . arange (dim )[:, np . newaxis ] for dim in dimensionalities ])
3939
4040 encountered_eis = []
4141
4242 for axis_indices_numeric in all_possible_axis_indices :
43- axis_indices = { axis : [ index ] for axis , index in zip (axes , axis_indices_numeric , strict = True )}
43+ axis_indices = dict ( zip (axes , axis_indices_numeric , strict = True ))
4444
4545 ei = grid .ravel_index (axis_indices )
4646 axis_indices_test = grid .unravel_index (ei )
0 commit comments