Skip to content

Commit e3b00cd

Browse files
Reordering dimensions to time,depth,lat,lon in Grid class
1 parent 9ad4246 commit e3b00cd

File tree

11 files changed

+71
-58
lines changed

11 files changed

+71
-58
lines changed

parcels/field.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,12 @@ class Field:
130130
A numpy array containing the timestamps for each of the files in filenames, for loading
131131
from netCDF files only. Default is None if the netCDF dimensions dictionary includes time.
132132
grid : parcels.grid.Grid
133-
:class:`parcels.grid.Grid` object containing all the lon, lat depth, time
133+
:class:`parcels.grid.Grid` object containing all the time, depth, lat, lon,
134134
mesh and time_origin information. Can be constructed from any of the Grid objects
135135
fieldtype : str
136136
Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None)
137137
transpose : bool
138-
Transpose data to required (lon, lat) layout
138+
Transpose data to required (time, depth, lat, lon) layout
139139
vmin : float
140140
Minimum allowed value on the field. Data below this value are set to zero
141141
vmax : float
@@ -211,7 +211,7 @@ def __init__(
211211
if grid:
212212
if grid.defer_load and isinstance(data, np.ndarray):
213213
raise ValueError(
214-
"Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately"
214+
"Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify time, depth, lat and lon dimensions separately"
215215
)
216216
self._grid = grid
217217
else:
@@ -220,7 +220,7 @@ def __init__(
220220
time = np.array([time_origin.reltime(t) for t in time])
221221
else:
222222
time_origin = TimeConverter(0)
223-
self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
223+
self._grid = Grid.create_grid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh)
224224
self.igrid = -1
225225
self.fieldtype = self.name if fieldtype is None else fieldtype
226226
self.to_write = to_write
@@ -687,7 +687,7 @@ def from_netcdf(
687687
time, time_origin, timeslices, dataFiles = cls._collect_timeslices(
688688
timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
689689
)
690-
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
690+
grid = Grid.create_grid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh)
691691
grid.timeslices = timeslices
692692
kwargs["dataFiles"] = dataFiles
693693
elif grid is not None and ("dataFiles" not in kwargs or kwargs["dataFiles"] is None):
@@ -838,7 +838,7 @@ def from_xarray(
838838
time_origin = TimeConverter(time[0])
839839
time = time_origin.reltime(time) # type: ignore[assignment]
840840

841-
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
841+
grid = Grid.create_grid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh)
842842
kwargs["time_periodic"] = time_periodic
843843
return cls(
844844
name,

parcels/fieldset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def from_data(
155155
time = np.array([time_origin.reltime(t) for t in time])
156156
else:
157157
time_origin = kwargs.pop("time_origin", TimeConverter(0))
158-
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
158+
grid = Grid.create_grid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh)
159159
if "creation_log" not in kwargs.keys():
160160
kwargs["creation_log"] = "from_data"
161161

parcels/grid.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ class Grid:
4444

4545
def __init__(
4646
self,
47-
lon: npt.NDArray,
47+
time: npt.NDArray,
48+
depth: npt.NDArray,
4849
lat: npt.NDArray,
49-
time: npt.NDArray | None,
50+
lon: npt.NDArray,
5051
time_origin: TimeConverter | None,
5152
mesh: Mesh,
5253
):
@@ -71,6 +72,7 @@ def __init__(
7172

7273
self._lon = lon
7374
self._lat = lat
75+
self._depth = depth
7476
self.time = time
7577
self.time_full = self.time # needed for deferred_loaded Fields
7678
self._time_origin = TimeConverter() if time_origin is None else time_origin
@@ -98,7 +100,7 @@ def __repr__(self):
98100
with np.printoptions(threshold=5, suppress=True, linewidth=120, formatter={"float": "{: 0.2f}".format}):
99101
return (
100102
f"{type(self).__name__}("
101-
f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, "
103+
f"lon={self.lon!r}, lat={self.lat!r}, time={self.time!r}, depth={self.depth!r}, "
102104
f"time_origin={self.time_origin!r}, mesh={self.mesh!r})"
103105
)
104106

@@ -195,10 +197,10 @@ def load_chunk(self):
195197

196198
@staticmethod
197199
def create_grid(
198-
lon: npt.ArrayLike,
200+
time: npt.ArrayLike,
201+
depth: npt.ArrayLike,
199202
lat: npt.ArrayLike,
200-
depth,
201-
time,
203+
lon: npt.ArrayLike,
202204
time_origin,
203205
mesh: Mesh,
204206
**kwargs,
@@ -211,14 +213,14 @@ def create_grid(
211213

212214
if len(lon.shape) <= 1:
213215
if depth is None or len(depth.shape) <= 1:
214-
return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
216+
return RectilinearZGrid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh, **kwargs)
215217
else:
216-
return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
218+
return RectilinearSGrid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh, **kwargs)
217219
else:
218220
if depth is None or len(depth.shape) <= 1:
219-
return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
221+
return CurvilinearZGrid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh, **kwargs)
220222
else:
221-
return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
223+
return CurvilinearSGrid(time, depth, lat, lon, time_origin=time_origin, mesh=mesh, **kwargs)
222224

223225
@property
224226
def ctypes_struct(self):
@@ -461,14 +463,14 @@ class RectilinearGrid(Grid):
461463
462464
"""
463465

464-
def __init__(self, lon, lat, time, time_origin, mesh: Mesh):
466+
def __init__(self, time, depth, lat, lon, time_origin, mesh: Mesh):
465467
assert isinstance(lon, np.ndarray) and len(lon.shape) <= 1, "lon is not a numpy vector"
466468
assert isinstance(lat, np.ndarray) and len(lat.shape) <= 1, "lat is not a numpy vector"
467469
assert isinstance(time, np.ndarray) or not time, "time is not a numpy array"
468470
if isinstance(time, np.ndarray):
469471
assert len(time.shape) == 1, "time is not a vector"
470472

471-
super().__init__(lon, lat, time, time_origin, mesh)
473+
super().__init__(time, depth, lat, lon, time_origin, mesh)
472474
self.tdim = self.time.size
473475

474476
if self.ydim > 1 and self.lat[-1] < self.lat[0]:
@@ -559,8 +561,8 @@ class RectilinearZGrid(RectilinearGrid):
559561
2. flat: No conversion, lat/lon are assumed to be in m.
560562
"""
561563

562-
def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh: Mesh = "flat"):
563-
super().__init__(lon, lat, time, time_origin, mesh)
564+
def __init__(self, time, depth, lat, lon, time_origin=None, mesh: Mesh = "flat"):
565+
super().__init__(time, depth, lat, lon, time_origin, mesh)
564566
if isinstance(depth, np.ndarray):
565567
assert len(depth.shape) <= 1, "depth is not a vector"
566568

@@ -592,7 +594,7 @@ class RectilinearSGrid(RectilinearGrid):
592594
which are s-coordinates.
593595
s-coordinates can be terrain-following (sigma) or iso-density (rho) layers,
594596
or any generalised vertical discretisation.
595-
The depth of each node depends then on the horizontal position (lon, lat),
597+
The depth of each node depends then on the horizontal position (lat, lon),
596598
the number of the layer and the time is depth is a 4D array.
597599
depth array is either a 4D array[xdim][ydim][zdim][tdim] or a 3D array[xdim][ydim[zdim].
598600
time :
@@ -610,14 +612,14 @@ class RectilinearSGrid(RectilinearGrid):
610612

611613
def __init__(
612614
self,
613-
lon: npt.NDArray,
614-
lat: npt.NDArray,
615+
time: npt.NDArray,
615616
depth: npt.NDArray,
616-
time: npt.NDArray | None = None,
617+
lat: npt.NDArray,
618+
lon: npt.NDArray,
617619
time_origin: TimeConverter | None = None,
618620
mesh: Mesh = "flat",
619621
):
620-
super().__init__(lon, lat, time, time_origin, mesh)
622+
super().__init__(time, depth, lat, lon, time_origin, mesh)
621623
assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 3D or 4D numpy array"
622624

623625
self._gtype = GridType.RectilinearSGrid
@@ -656,9 +658,10 @@ def zdim(self):
656658
class CurvilinearGrid(Grid):
657659
def __init__(
658660
self,
659-
lon: npt.NDArray,
661+
time: npt.NDArray,
662+
depth: npt.NDArray,
660663
lat: npt.NDArray,
661-
time: npt.NDArray | None = None,
664+
lon: npt.NDArray,
662665
time_origin: TimeConverter | None = None,
663666
mesh: Mesh = "flat",
664667
):
@@ -670,7 +673,7 @@ def __init__(
670673

671674
lon = lon.squeeze()
672675
lat = lat.squeeze()
673-
super().__init__(lon, lat, time, time_origin, mesh)
676+
super().__init__(time, None, lat, lon, time_origin, mesh)
674677
self.tdim = self.time.size
675678

676679
@property
@@ -770,14 +773,14 @@ class CurvilinearZGrid(CurvilinearGrid):
770773

771774
def __init__(
772775
self,
773-
lon: npt.NDArray,
776+
time: npt.NDArray,
777+
depth: npt.NDArray,
774778
lat: npt.NDArray,
775-
depth: npt.NDArray | None = None,
776-
time: npt.NDArray | None = None,
779+
lon: npt.NDArray,
777780
time_origin: TimeConverter | None = None,
778781
mesh: Mesh = "flat",
779782
):
780-
super().__init__(lon, lat, time, time_origin, mesh)
783+
super().__init__(time, depth, lat, lon, time_origin, mesh)
781784
if isinstance(depth, np.ndarray):
782785
assert len(depth.shape) == 1, "depth is not a vector"
783786

@@ -808,7 +811,7 @@ class CurvilinearSGrid(CurvilinearGrid):
808811
which are s-coordinates.
809812
s-coordinates can be terrain-following (sigma) or iso-density (rho) layers,
810813
or any generalised vertical discretisation.
811-
The depth of each node depends then on the horizontal position (lon, lat),
814+
The depth of each node depends then on the horizontal position (lat, lon),
812815
the number of the layer and the time is depth is a 4D array.
813816
depth array is either a 4D array[xdim][ydim][zdim][tdim] or a 3D array[xdim][ydim[zdim].
814817
time :
@@ -826,14 +829,14 @@ class CurvilinearSGrid(CurvilinearGrid):
826829

827830
def __init__(
828831
self,
829-
lon: npt.NDArray,
830-
lat: npt.NDArray,
832+
time: npt.NDArray,
831833
depth: npt.NDArray,
832-
time: npt.NDArray | None = None,
834+
lat: npt.NDArray,
835+
lon: npt.NDArray,
833836
time_origin: TimeConverter | None = None,
834837
mesh: Mesh = "flat",
835838
):
836-
super().__init__(lon, lat, time, time_origin, mesh)
839+
super().__init__(time, depth, lat, lon, time_origin, mesh)
837840
assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 4D numpy array"
838841

839842
self._gtype = GridType.CurvilinearSGrid

parcels/gridset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def add_grid(self, field):
4444
field.igrid = self.grids.index(field.grid)
4545

4646
def dimrange(self, dim):
47-
"""Returns maximum value of a dimension (lon, lat, depth or time)
47+
"""Returns maximum value of a dimension (time, depth, lat or lon)
4848
on 'left' side and minimum value on 'right' side for all grids
4949
in a gridset. Useful for finding e.g. longitude range that
5050
overlaps on all grids in a gridset.

parcels/particledata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
6464
), "particle's initial depth is None - incompatible with the ParticleData class. Invalid state."
6565
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts."
6666

67-
assert lon.size == time.size, "time and positions (lon, lat, depth) don't have the same lengths."
67+
assert lon.size == time.size, "time and positions (depth, lat, lon) don't have the same lengths."
6868

6969
# If a partitioning function for MPI runs has been passed into the
7070
# particle creation with the "partition_function" kwarg, retrieve it here.
@@ -74,7 +74,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
7474
for kwvar in kwargs:
7575
assert (
7676
lon.size == kwargs[kwvar].size
77-
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
77+
), f"{kwvar} and positions (depth, lat, lon) don't have the same lengths."
7878

7979
offset = np.max(pid) if (pid is not None) and len(pid) > 0 else -1
8080
if MPI:

tests/test_deprecations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def create_test_data():
285285
lon_g0 = np.linspace(0, 1000, 11, dtype=np.float32)
286286
lat_g0 = np.linspace(0, 1000, 11, dtype=np.float32)
287287
time_g0 = np.linspace(0, 1000, 2, dtype=np.float64)
288-
grid = RectilinearZGrid(lon_g0, lat_g0, time=time_g0)
288+
grid = RectilinearZGrid(time_g0, None, lat_g0, lon_g0)
289289

290290
pfile = ParticleFile("test.zarr", pset, outputdt=1)
291291

tests/test_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh, mode, kernel):
6868
for x in range(xdim):
6969
Kh[:, x] = np.tanh(fieldset.U.lon[x] / fieldset.U.lon[-1] * 10.0) * xdim / 2.0 + xdim / 2.0 + 100.0
7070

71-
grid = RectilinearZGrid(lon=fieldset.U.lon, lat=fieldset.U.lat, mesh=mesh)
71+
grid = RectilinearZGrid(time=None, depth=None, lat=fieldset.U.lat, lon=fieldset.U.lon, mesh=mesh)
7272
fieldset.add_field(Field("Kh_zonal", Kh, grid=grid))
7373
fieldset.add_field(Field("Kh_meridional", Kh, grid=grid))
7474
fieldset.add_constant("dres", fieldset.U.lon[1] - fieldset.U.lon[0])

tests/test_fieldset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def test_fieldset_defer_loading_with_diff_time_origin(tmpdir, fail):
944944
Wtime_origin = TimeConverter(np.datetime64("2018-04-22"))
945945
else:
946946
Wtime_origin = TimeConverter(np.datetime64("2018-04-18"))
947-
gridW = RectilinearZGrid(dims1["lon"], dims1["lat"], dims1["depth"], dims1["time"], time_origin=Wtime_origin)
947+
gridW = RectilinearZGrid(dims1["time"], dims1["depth"], dims1["lat"], dims1["lon"], time_origin=Wtime_origin)
948948
fieldW = Field("W", np.zeros(data1["U"].shape), grid=gridW)
949949
fieldset_out.add_field(fieldW)
950950
fieldset_out.write(filepath)

tests/test_grids.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def temp_func(lon, lat):
4040
lon_g0 = np.linspace(0, a, xdim_g0, dtype=np.float32)
4141
lat_g0 = np.linspace(0, b, ydim_g0, dtype=np.float32)
4242
time_g0 = np.linspace(0.0, 1000.0, 2, dtype=np.float64)
43-
grid_0 = RectilinearZGrid(lon_g0, lat_g0, time=time_g0)
43+
grid_0 = RectilinearZGrid(time_g0, None, lat_g0, lon_g0)
4444

4545
# Grid 1
4646
xdim_g1 = 51
@@ -49,7 +49,7 @@ def temp_func(lon, lat):
4949
lon_g1 = np.linspace(0, a, xdim_g1, dtype=np.float32)
5050
lat_g1 = np.linspace(0, b, ydim_g1, dtype=np.float32)
5151
time_g1 = np.linspace(0.0, 1000.0, 2, dtype=np.float64)
52-
grid_1 = RectilinearZGrid(lon_g1, lat_g1, time=time_g1)
52+
grid_1 = RectilinearZGrid(time_g1, None, lat_g1, lon_g1)
5353

5454
u_data = np.ones((lon_g0.size, lat_g0.size, time_g0.size), dtype=np.float32)
5555
u_data = 2 * u_data
@@ -109,7 +109,7 @@ def test_time_format_in_grid():
109109
lat = np.linspace(0, 1, 2, dtype=np.float32)
110110
time = np.array([np.datetime64("2000-01-01")] * 2)
111111
with pytest.raises(AssertionError, match="Time vector"):
112-
RectilinearZGrid(lon, lat, time=time)
112+
RectilinearZGrid(time, None, lat, lon)
113113

114114

115115
def test_negate_depth():
@@ -126,12 +126,12 @@ def test_avoid_repeated_grids():
126126
lon_g0 = np.linspace(0, 1000, 11, dtype=np.float32)
127127
lat_g0 = np.linspace(0, 1000, 11, dtype=np.float32)
128128
time_g0 = np.linspace(0, 1000, 2, dtype=np.float64)
129-
grid_0 = RectilinearZGrid(lon_g0, lat_g0, time=time_g0)
129+
grid_0 = RectilinearZGrid(time_g0, None, lat_g0, lon_g0)
130130

131131
lon_g1 = np.linspace(0, 1000, 21, dtype=np.float32)
132132
lat_g1 = np.linspace(0, 1000, 21, dtype=np.float32)
133133
time_g1 = np.linspace(0, 1000, 2, dtype=np.float64)
134-
grid_1 = RectilinearZGrid(lon_g1, lat_g1, time=time_g1)
134+
grid_1 = RectilinearZGrid(time_g1, None, lat_g1, lon_g1)
135135

136136
u_data = np.zeros((lon_g0.size, lat_g0.size, time_g0.size), dtype=np.float32)
137137
u_field = Field("U", u_data, grid=grid_0, transpose=True)
@@ -166,8 +166,8 @@ def bath_func(lon):
166166
for k in range(zdim):
167167
depth_g0[k, :, i] = bath[i] * k / (zdim - 1)
168168

169-
grid_0 = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0)
170-
grid_1 = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0)
169+
grid_0 = RectilinearSGrid(None, depth_g0, lat_g0, lon_g0)
170+
grid_1 = RectilinearSGrid(None, depth_g0, lat_g0, lon_g0)
171171

172172
u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32)
173173
v_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32)
@@ -216,7 +216,7 @@ def bath_func(lon):
216216
else:
217217
depth_g0[k, :, i] = bath[i] * k / (zdim - 1)
218218

219-
grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0, time=time_g0)
219+
grid = RectilinearSGrid(time_g0, depth_g0, lat_g0, lon_g0)
220220

221221
u_data = np.zeros((grid.tdim, grid.zdim, grid.ydim, grid.xdim), dtype=np.float32)
222222
v_data = np.zeros((grid.tdim, grid.zdim, grid.ydim, grid.xdim), dtype=np.float32)
@@ -262,7 +262,7 @@ def bath_func(lon):
262262
depth_g0[i, :, k] = bath[i] * k / (depth_g0.shape[2] - 1)
263263
depth_g0 = depth_g0.transpose() # we don't change it on purpose, to check if the transpose op if fixed in jit
264264

265-
grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0)
265+
grid = RectilinearSGrid(None, depth_g0, lat_g0, lon_g0)
266266

267267
zdim = depth_g0.shape[0]
268268
u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32)
@@ -306,7 +306,7 @@ def bath_func(lon):
306306
for k in range(zdim):
307307
depth_g0[k, :, i] = bath[i] * k / (zdim - 1)
308308

309-
grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0)
309+
grid = RectilinearSGrid(None, depth_g0, lat_g0, lon_g0)
310310

311311
u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32)
312312
v_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32)
@@ -347,7 +347,7 @@ def test_curvilinear_grids(mode):
347347
lon = r * np.cos(theta)
348348
lat = r * np.sin(theta)
349349
time = np.array([0, 86400], dtype=np.float64)
350-
grid = CurvilinearZGrid(lon, lat, time=time)
350+
grid = CurvilinearZGrid(time, None, lat, lon)
351351

352352
u_data = np.ones((2, y.size, x.size), dtype=np.float32)
353353
v_data = np.zeros((2, y.size, x.size), dtype=np.float32)

tests/test_particlesets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_pset_create_field_curvi():
139139

140140
x = -1 + r * np.cos(theta)
141141
y = -1 + r * np.sin(theta)
142-
grid = CurvilinearZGrid(x, y)
142+
grid = CurvilinearZGrid(time=None, depth=None, lat=y, lon=x)
143143

144144
u = np.ones(x.shape)
145145
v = np.where(np.logical_and(theta > np.pi / 4, theta < np.pi / 3), 1, 0)

0 commit comments

Comments
 (0)