Skip to content

Commit f92e054

Browse files
Merge pull request #1875 from OceanParcels/remove-jit-compat
Remove cstruct objects and remove c-contiguous requirement
2 parents cf3eb91 + 5742560 commit f92e054

File tree

4 files changed

+12
-169
lines changed

4 files changed

+12
-169
lines changed

parcels/field.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import math
33
import warnings
44
from collections.abc import Iterable
5-
from ctypes import POINTER, Structure, c_float, c_int, pointer
5+
from ctypes import c_int
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Literal
88

@@ -50,11 +50,9 @@
5050
DeferredNetcdfFileBuffer,
5151
NetcdfFileBuffer,
5252
)
53-
from .grid import CGrid, Grid, GridType, _calc_cell_areas
53+
from .grid import Grid, GridType, _calc_cell_areas
5454

5555
if TYPE_CHECKING:
56-
from ctypes import _Pointer as PointerType
57-
5856
from parcels.fieldset import FieldSet
5957

6058
__all__ = ["Field", "NestedField", "VectorField"]
@@ -341,7 +339,6 @@ def __init__(
341339
# since some datasets do not provide the deeper level of data (which is ignored by the interpolation).
342340
self.data_full_zdim = kwargs.pop("data_full_zdim", None)
343341
self._data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays
344-
self._c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array
345342
self.nchunks: tuple[int, ...] = ()
346343
self._chunk_set: bool = False
347344
self.filebuffers = [None] * 2
@@ -1016,10 +1013,7 @@ def _time_index(self, time):
10161013
periods = int(
10171014
math.floor((time - self.grid.time_full[0]) / (self.grid.time_full[-1] - self.grid.time_full[0]))
10181015
)
1019-
if isinstance(self.grid.periods, c_int):
1020-
self.grid.periods.value = periods
1021-
else:
1022-
self.grid.periods = periods
1016+
self.grid.periods = periods
10231017
time -= periods * (self.grid.time_full[-1] - self.grid.time_full[0])
10241018
time_index = self.grid.time <= time
10251019
ti = time_index.argmin() - 1 if time_index.any() else 0
@@ -1106,7 +1100,6 @@ def _chunk_setup(self):
11061100
return
11071101

11081102
self._data_chunks = [None] * npartitions
1109-
self._c_data_chunks = [None] * npartitions
11101103
self.grid._load_chunk = np.zeros(npartitions, dtype=c_int, order="C")
11111104
# self.grid.chunk_info format: number of dimensions (without tdim); number of chunks per dimensions;
11121105
# chunksizes (the 0th dim sizes for all chunk of dim[0], then so on for next dims
@@ -1136,62 +1129,14 @@ def _chunk_data(self):
11361129
self._data_chunks[block_id] = None
11371130
else:
11381131
self._data_chunks[block_id, :] = None
1139-
self._c_data_chunks[block_id] = None
11401132
else:
11411133
if isinstance(self._data_chunks, list):
11421134
self._data_chunks[0] = None
11431135
else:
11441136
self._data_chunks[0, :] = None
1145-
self._c_data_chunks[0] = None
11461137
self.grid._load_chunk[0] = g._chunk_loaded_touched
11471138
self._data_chunks[0] = np.array(self.data, order="C")
11481139

1149-
@property
1150-
def ctypes_struct(self):
1151-
"""Returns a ctypes struct object containing all relevant pointers and sizes for this field."""
1152-
1153-
# Ctypes struct corresponding to the type definition in parcels.h
1154-
class CField(Structure):
1155-
_fields_ = [
1156-
("xdim", c_int),
1157-
("ydim", c_int),
1158-
("zdim", c_int),
1159-
("tdim", c_int),
1160-
("igrid", c_int),
1161-
("allow_time_extrapolation", c_int),
1162-
("time_periodic", c_int),
1163-
("data_chunks", POINTER(POINTER(POINTER(c_float)))),
1164-
("grid", POINTER(CGrid)),
1165-
]
1166-
1167-
# Create and populate the c-struct object
1168-
allow_time_extrapolation = 1 if self.allow_time_extrapolation else 0
1169-
time_periodic = 1 if self.time_periodic else 0
1170-
for i in range(len(self.grid._load_chunk)):
1171-
if self.grid._load_chunk[i] == self.grid._chunk_loading_requested:
1172-
raise ValueError(
1173-
"data_chunks should have been loaded by now if requested. grid._load_chunk[bid] cannot be 1"
1174-
)
1175-
if self.grid._load_chunk[i] in self.grid._chunk_loaded:
1176-
if not self._data_chunks[i].flags["C_CONTIGUOUS"]:
1177-
self._data_chunks[i] = np.array(self._data_chunks[i], order="C")
1178-
self._c_data_chunks[i] = self._data_chunks[i].ctypes.data_as(POINTER(POINTER(c_float)))
1179-
else:
1180-
self._c_data_chunks[i] = None
1181-
1182-
cstruct = CField(
1183-
self.grid.xdim,
1184-
self.grid.ydim,
1185-
self.grid.zdim,
1186-
self.grid.tdim,
1187-
self.igrid,
1188-
allow_time_extrapolation,
1189-
time_periodic,
1190-
(POINTER(POINTER(c_float)) * len(self._c_data_chunks))(*self._c_data_chunks),
1191-
pointer(self.grid.ctypes_struct),
1192-
)
1193-
return cstruct
1194-
11951140
def add_periodic_halo(self, zonal, meridional, halosize=5, data=None):
11961141
"""Add a 'halo' to all Fields in a FieldSet.
11971142

parcels/grid.py

Lines changed: 8 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import warnings
3-
from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer
43
from enum import IntEnum
54

65
import numpy as np
@@ -11,7 +10,6 @@
1110
from parcels.tools.warnings import FieldSetWarning
1211

1312
__all__ = [
14-
"CGrid",
1513
"CurvilinearSGrid",
1614
"CurvilinearZGrid",
1715
"Grid",
@@ -34,10 +32,6 @@ class GridType(IntEnum):
3432
GridCode = GridType
3533

3634

37-
class CGrid(Structure):
38-
_fields_ = [("gtype", c_int), ("grid", c_void_p)]
39-
40-
4135
class Grid:
4236
"""Grid class that defines a (spatial and temporal) grid on which Fields are defined."""
4337

@@ -51,13 +45,10 @@ def __init__(
5145
):
5246
self._ti = -1
5347
self._update_status: UpdateStatus | None = None
54-
if not lon.flags["C_CONTIGUOUS"]:
55-
lon = np.array(lon, order="C")
56-
if not lat.flags["C_CONTIGUOUS"]:
57-
lat = np.array(lat, order="C")
48+
lon = np.array(lon)
49+
lat = np.array(lat)
5850
time = np.zeros(1, dtype=np.float64) if time is None else time
59-
if not time.flags["C_CONTIGUOUS"]:
60-
time = np.array(time, order="C")
51+
time = np.array(time)
6152
if not lon.dtype == np.float32:
6253
lon = lon.astype(np.float32)
6354
if not lat.dtype == np.float32:
@@ -76,7 +67,6 @@ def __init__(
7667
assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object"
7768
assert_valid_mesh(mesh)
7869
self._mesh = mesh
79-
self._cstruct = None
8070
self._cell_edge_sizes: dict[str, npt.NDArray] = {}
8171
self._zonal_periodic = False
8272
self._zonal_halo = 0
@@ -179,67 +169,6 @@ def create_grid(
179169
else:
180170
return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
181171

182-
@property
183-
def ctypes_struct(self):
184-
# This is unnecessary for the moment, but it could be useful when going will fully unstructured grids
185-
self._cgrid = cast(pointer(self._child_ctypes_struct), c_void_p)
186-
cstruct = CGrid(self._gtype, self._cgrid.value)
187-
return cstruct
188-
189-
@property
190-
def _child_ctypes_struct(self):
191-
"""Returns a ctypes struct object containing all relevant
192-
pointers and sizes for this grid.
193-
"""
194-
195-
class CStructuredGrid(Structure):
196-
# z4d is only to have same cstruct as RectilinearSGrid
197-
_fields_ = [
198-
("xdim", c_int),
199-
("ydim", c_int),
200-
("zdim", c_int),
201-
("tdim", c_int),
202-
("z4d", c_int),
203-
("mesh_spherical", c_int),
204-
("zonal_periodic", c_int),
205-
("chunk_info", POINTER(c_int)),
206-
("load_chunk", POINTER(c_int)),
207-
("tfull_min", c_double),
208-
("tfull_max", c_double),
209-
("periods", POINTER(c_int)),
210-
("lonlat_minmax", POINTER(c_float)),
211-
("lon", POINTER(c_float)),
212-
("lat", POINTER(c_float)),
213-
("depth", POINTER(c_float)),
214-
("time", POINTER(c_double)),
215-
]
216-
217-
# Create and populate the c-struct object
218-
if not self._cstruct: # Not to point to the same grid various times if grid in various fields
219-
if not isinstance(self.periods, c_int):
220-
self.periods = c_int()
221-
self.periods.value = 0
222-
self._cstruct = CStructuredGrid(
223-
self.xdim,
224-
self.ydim,
225-
self.zdim,
226-
self.tdim,
227-
self._z4d,
228-
int(self.mesh == "spherical"),
229-
int(self.zonal_periodic),
230-
(c_int * len(self.chunk_info))(*self.chunk_info),
231-
self._load_chunk.ctypes.data_as(POINTER(c_int)),
232-
self.time_full[0],
233-
self.time_full[-1],
234-
pointer(self.periods),
235-
self.lonlat_minmax.ctypes.data_as(POINTER(c_float)),
236-
self.lon.ctypes.data_as(POINTER(c_float)),
237-
self.lat.ctypes.data_as(POINTER(c_float)),
238-
self.depth.ctypes.data_as(POINTER(c_float)),
239-
self.time.ctypes.data_as(POINTER(c_double)),
240-
)
241-
return self._cstruct
242-
243172
def _check_zonal_periodic(self):
244173
if self.zonal_periodic or self.mesh == "flat" or self.lon.size == 1:
245174
return
@@ -278,7 +207,7 @@ def _add_Sdepth_periodic_halo(self, zonal, meridional, halosize):
278207

279208
def _computeTimeChunk(self, f, time, signdt):
280209
nextTime_loc = np.inf if signdt >= 0 else -np.inf
281-
periods = self.periods.value if isinstance(self.periods, c_int) else self.periods
210+
periods = self.periods
282211
prev_time_indices = self.time
283212
if self._update_status == "not_updated":
284213
if self._ti >= 0:
@@ -316,7 +245,7 @@ def _computeTimeChunk(self, f, time, signdt):
316245
if self._ti == -1:
317246
self.time = self.time_full
318247
self._ti, _ = f._time_index(time)
319-
periods = self.periods.value if isinstance(self.periods, c_int) else self.periods
248+
periods = self.periods
320249
if (
321250
signdt == -1
322251
and self._ti == 0
@@ -483,8 +412,7 @@ def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh: Mesh
483412

484413
self._gtype = GridType.RectilinearZGrid
485414
self._depth = np.zeros(1, dtype=np.float32) if depth is None else depth
486-
if not self.depth.flags["C_CONTIGUOUS"]:
487-
self._depth = np.array(self.depth, order="C")
415+
self._depth = np.array(self.depth)
488416
self._z4d = -1 # only used in RectilinearSGrid
489417
if not self.depth.dtype == np.float32:
490418
self._depth = self.depth.astype(np.float32)
@@ -539,8 +467,7 @@ def __init__(
539467

540468
self._gtype = GridType.RectilinearSGrid
541469
self._depth = depth
542-
if not self.depth.flags["C_CONTIGUOUS"]:
543-
self._depth = np.array(self.depth, order="C")
470+
self._depth = np.array(self.depth)
544471
self._z4d = 1 if len(self.depth.shape) == 4 else 0
545472
if self._z4d:
546473
# self.depth.shape[0] is 0 for S grids loaded from netcdf file
@@ -656,8 +583,6 @@ def __init__(
656583

657584
self._gtype = GridType.CurvilinearZGrid
658585
self._depth = np.zeros(1, dtype=np.float32) if depth is None else depth
659-
if not self.depth.flags["C_CONTIGUOUS"]:
660-
self._depth = np.array(self.depth, order="C")
661586
self._z4d = -1 # only for SGrid
662587
if not self.depth.dtype == np.float32:
663588
self._depth = self.depth.astype(np.float32)
@@ -710,9 +635,7 @@ def __init__(
710635
assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 4D numpy array"
711636

712637
self._gtype = GridType.CurvilinearSGrid
713-
self._depth = depth # should be a C-contiguous array of floats
714-
if not self.depth.flags["C_CONTIGUOUS"]:
715-
self._depth = np.array(self.depth, order="C")
638+
self._depth = depth
716639
self._z4d = 1 if len(self.depth.shape) == 4 else 0
717640
if self._z4d:
718641
# self.depth.shape[0] is 0 for S grids loaded from netcdf file

parcels/particledata.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import warnings
2-
from ctypes import POINTER, Structure
32
from operator import attrgetter
43

54
import numpy as np
@@ -42,8 +41,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
4241
Parameters
4342
----------
4443
ngrid :
45-
number of grids in the fieldset of the overarching ParticleSet - required for initialising the
46-
field references of the ctypes-link of particles that are allocated
44+
number of grids in the fieldset of the overarching ParticleSet.
4745
"""
4846
self._ncount = -1
4947
self._pu_indicators = None
@@ -328,21 +326,6 @@ def remove_multi_by_indices(self, indices):
328326

329327
self._ncount -= len(indices)
330328

331-
def cstruct(self):
332-
"""Return the ctypes mapping of the particle data."""
333-
334-
class CParticles(Structure):
335-
_fields_ = [(v.name, POINTER(np.ctypeslib.as_ctypes_type(v.dtype))) for v in self._ptype.variables]
336-
337-
def flatten_dense_data_array(vname):
338-
data_flat = self._data[vname].view()
339-
data_flat.shape = -1
340-
return np.ctypeslib.as_ctypes(data_flat)
341-
342-
cdata = [flatten_dense_data_array(v.name) for v in self._ptype.variables]
343-
cstruct = CParticles(*cdata)
344-
return cstruct
345-
346329
def _to_write_particles(self, time):
347330
"""Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)"""
348331
pd = self._data

parcels/particleset.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -321,14 +321,6 @@ def lonlatdepth_dtype_from_field_interp_method(field):
321321
return np.float64
322322
return np.float32
323323

324-
def cstruct(self):
325-
cstruct = self.particledata.cstruct()
326-
return cstruct
327-
328-
@property
329-
def ctypes_struct(self):
330-
return self.cstruct()
331-
332324
@property
333325
def size(self):
334326
# ==== to change at some point - len and size are different things ==== #

0 commit comments

Comments
 (0)