
Commit 9fa88b6

Merge pull request #1860 from OceanParcels/JIT_cleanup
JIT cleanup
2 parents 866dae8 + d49c800 commit 9fa88b6

File tree: 15 files changed (+31, −116 lines)


.github/ci/min-core-deps.yml

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@ dependencies:
   # Run ci/min_deps_check.py to verify that this file respects the policy.
   - python=3.10
   - cftime=1.6
-  - cgen=2020.1
   - dask=2022.8
   - matplotlib-base=3.5
   # netcdf follows a 1.major.minor[.patch] convention

docs/examples/tutorial_parcels_structure.ipynb

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
     "source": [
      "1. [**FieldSet**](#1.-FieldSet). Load and set up the fields. These can be velocity fields that are used to advect the particles, but it can also be e.g. temperature.\n",
      "2. [**ParticleSet**](#2.-ParticleSet). Define the type of particles. Also additional `Variables` can be added to the particles (e.g. temperature, to keep track of the temperature that particles experience).\n",
-     "3. [**Kernels**](#3.-Kernels). Define and compile kernels. Kernels perform some specific operation on the particles every time step (e.g. interpolate the temperature from the temperature field to the particle location).\n",
+     "3. [**Kernels**](#3.-Kernels). Kernels perform some specific operation on the particles every time step (e.g. interpolate the temperature from the temperature field to the particle location).\n",
      "4. [**Execution and output**](#4.-Execution-and-Output). Execute the simulation and write and store the output in a Zarr file.\n",
      "5. [**Optimising and parallelising**](#5.-Optimising-and-parallelising). Optimise and parallelise the code to run faster.\n",
      "\n",

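The notebook cell in the diff above lists the building blocks of a Parcels simulation (FieldSet, ParticleSet, Kernels, Execution and output). For orientation, they map onto a minimal script roughly as in the sketch below; this is a hedged illustration only, and the file name, variable/dimension mappings, release position and run length are placeholder assumptions rather than anything in this commit.

from datetime import timedelta

import parcels

# 1. FieldSet: load the hydrodynamic data (placeholder file and variable names)
fieldset = parcels.FieldSet.from_netcdf(
    filenames="ocean_currents.nc",
    variables={"U": "uo", "V": "vo"},
    dimensions={"lon": "longitude", "lat": "latitude", "time": "time"},
)

# 2. ParticleSet: define the particles and their release positions
pset = parcels.ParticleSet(fieldset=fieldset, pclass=parcels.ScipyParticle, lon=[25.0], lat=[-35.0])

# 3. Kernels: a built-in advection kernel (custom kernels use the same signature)
kernel = parcels.AdvectionRK4

# 4. Execution and output: run the time loop and write trajectories to a Zarr store
output_file = pset.ParticleFile(name="output.zarr", outputdt=timedelta(hours=1))
pset.execute(kernel, runtime=timedelta(days=5), dt=timedelta(minutes=10), output_file=output_file)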
environment.yml

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ channels:
   - conda-forge
 dependencies:
   - python>=3.10
-  - cgen
   - ffmpeg>=3.2.3
   - jupyter
   - matplotlib-base>=2.0.2

parcels/_typing.py

Lines changed: 0 additions & 6 deletions
@@ -6,17 +6,11 @@

 """

-import ast
 import datetime
 import os
 from collections.abc import Callable
 from typing import Any, Literal, get_args

-
-class ParcelsAST(ast.AST):
-    ccode: str
-
-
 InterpMethodOption = Literal[
     "linear",
     "nearest",

parcels/field.py

Lines changed: 0 additions & 30 deletions
@@ -149,7 +149,6 @@ class Field:
         Maximum allowed value on the field. Data above this value are set to zero
     cast_data_dtype : str
         Cast Field data to dtype. Supported dtypes are "float32" (np.float32 (default)) and "float64 (np.float64).
-        Note that dtype can only be "float32" in JIT mode
     time_origin : parcels.tools.converters.TimeConverter
         Time origin of the time axis (only if grid is None)
     interp_method : str
@@ -320,7 +319,6 @@ def __init__(
 
         self._scaling_factor = None
 
-        # Variable names in JIT code
         self._dimensions = kwargs.pop("dimensions", None)
         self.indices = kwargs.pop("indices", None)
         self._dataFiles = kwargs.pop("dataFiles", None)
@@ -1084,18 +1082,6 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
         else:
             return value
 
-    def _ccode_eval(self, var, t, z, y, x):
-        self._check_velocitysampling()
-        ccode_str = (
-            f"temporal_interpolation({t}, {z}, {y}, {x}, {self.ccode_name}, "
-            + "&particles->ti[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->xi[pnum*ngrid], "
-            + f"&{var}, {self.interp_method.upper()}, {self.gridindexingtype.upper()})"
-        )
-        return ccode_str
-
-    def _ccode_convert(self, _, z, y, x):
-        return self.units.ccode_to_target(z, y, x)
-
     def _get_block_id(self, block):
         return np.ravel_multi_index(block, self.nchunks)
 
@@ -1923,22 +1909,6 @@ def __getitem__(self, key):
         except tuple(AllParcelsErrorCodes.keys()) as error:
             return _deal_with_errors(error, key, vector_type=self.vector_type)
 
-    def _ccode_eval(self, varU, varV, varW, U, V, W, t, z, y, x):
-        ccode_str = ""
-        if "3D" in self.vector_type:
-            ccode_str = (
-                f"temporal_interpolationUVW({t}, {z}, {y}, {x}, {U.ccode_name}, {V.ccode_name}, {W.ccode_name}, "
-                + "&particles->ti[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->xi[pnum*ngrid],"
-                + f"&{varU}, &{varV}, &{varW}, {U.interp_method.upper()}, {U.gridindexingtype.upper()})"
-            )
-        else:
-            ccode_str = (
-                f"temporal_interpolationUV({t}, {z}, {y}, {x}, {U.ccode_name}, {V.ccode_name}, "
-                + "&particles->ti[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->xi[pnum*ngrid],"
-                + f" &{varU}, &{varV}, {U.interp_method.upper()}, {U.gridindexingtype.upper()})"
-            )
-        return ccode_str
-
 
 class DeferredArray:
     """Class used for throwing error when Field.data is not read in deferred loading mode."""

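With _ccode_eval and _ccode_convert gone, field sampling in kernels runs entirely through the Python interpolation path. For reference, the usual square-bracket sampling syntax looks like the sketch below; the field name T and the custom variable temp are assumptions for illustration, not part of this commit.

def SampleTemperature(particle, fieldset, time):
    # Interpolate fieldset.T at the particle's current space-time position
    particle.temp = fieldset.T[time, particle.depth, particle.lat, particle.lon]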
parcels/fieldset.py

Lines changed: 1 addition & 12 deletions
@@ -309,16 +309,6 @@ def check_velocityfields(U, V, W):
             g._time_origin = self.time_origin
         self._add_UVfield()
 
-        ccode_fieldnames = []
-        counter = 1
-        for fld in self.get_fields():
-            if fld.name not in ccode_fieldnames:
-                fld.ccode_name = fld.name
-            else:
-                fld.ccode_name = fld.name + str(counter)
-                counter += 1
-            ccode_fieldnames.append(fld.ccode_name)
-
         for f in self.get_fields():
             if isinstance(f, (VectorField, NestedField)) or f._dataFiles is None:
                 continue
@@ -1447,8 +1437,7 @@ def get_fields(self) -> list[Field | VectorField]:
 
     def add_constant(self, name, value):
         """Add a constant to the FieldSet. Note that all constants are
-        stored as 32-bit floats. While constants can be updated during
-        execution in SciPy mode, they can not be updated in JIT mode.
+        stored as 32-bit floats.
 
         Parameters
         ----------
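The shortened add_constant docstring still describes the same usage pattern: a constant is registered once on the FieldSet and read back inside kernels as an attribute. A hedged sketch (the constant name, its value, and the age variable are illustrative assumptions):

fieldset.add_constant("max_age", 30 * 86400.0)  # stored as a 32-bit float

def KillOldParticles(particle, fieldset, time):
    # Constants added via add_constant are available as fieldset.<name>
    if particle.age > fieldset.max_age:
        particle.delete()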

parcels/kernel.py

Lines changed: 20 additions & 23 deletions
@@ -55,7 +55,7 @@ def __init__(
         self.funcvars = funcvars
         self.funccode = funccode
         self.py_ast = py_ast  # TODO v4: check if this is needed
-        self.scipy_positionupdate_kernels_added = False
+        self._positionupdate_kernels_added = False
 
     @property
     def ptype(self):
@@ -94,7 +94,7 @@ class Kernel(BaseKernel):
 
     Notes
     -----
-    A Kernel is either created from a compiled <function ...> object
+    A Kernel is either created from a <function ...> object
     or the necessary information (funcname, funccode, funcvars) is provided.
     The py_ast argument may be derived from the code string, but for
     concatenation, the merged AST plus the new header definition is required.
@@ -159,7 +159,7 @@ def __init__(
             user_ctx = globals()
         finally:
             del stack  # Remove cyclic references
-        # Compile and generate Python function from AST
+        # Generate Python function from AST
         py_mod = ast.parse("")
         py_mod.body = [self.py_ast]
         exec(compile(py_mod, "<ast>", "exec"), user_ctx)
@@ -181,7 +181,7 @@ def pyfunc(self):
     def fieldset(self):
         return self._fieldset
 
-    def add_scipy_positionupdate_kernels(self):
+    def add_positionupdate_kernels(self):
         # Adding kernels that set and update the coordinate changes
         def Setcoords(particle, fieldset, time):  # pragma: no cover
             particle_dlon = 0  # noqa
@@ -324,23 +324,6 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
         pyfunc_list[0] = cls(fieldset, ptype, pyfunc_list[0], *args, **kwargs)
         return functools.reduce(lambda x, y: x + y, pyfunc_list)
 
-    def execute_python(self, pset, endtime, dt):
-        """Performs the core update loop via Python."""
-        if self.fieldset is not None:
-            for f in self.fieldset.get_fields():
-                if isinstance(f, (VectorField, NestedField)):
-                    continue
-                f.data = np.array(f.data)
-
-        if not self.scipy_positionupdate_kernels_added:
-            self.add_scipy_positionupdate_kernels()
-            self.scipy_positionupdate_kernels_added = True
-
-        for p in pset:
-            self.evaluate_particle(p, endtime)
-            if p.state == StatusCode.StopAllExecution:
-                return StatusCode.StopAllExecution
-
     def execute(self, pset, endtime, dt):
         """Execute this Kernel over a ParticleSet for several timesteps."""
         pset.particledata.state[:] = StatusCode.Evaluate
@@ -359,7 +342,19 @@ def execute(self, pset, endtime, dt):
                 g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk
             )
 
-        self.execute_python(pset, endtime, dt)
+        for f in self.fieldset.get_fields():
+            if isinstance(f, (VectorField, NestedField)):
+                continue
+            f.data = np.array(f.data)
+
+        if not self._positionupdate_kernels_added:
+            self.add_positionupdate_kernels()
+            self._positionupdate_kernels_added = True
+
+        for p in pset:
+            self.evaluate_particle(p, endtime)
+            if p.state == StatusCode.StopAllExecution:
+                return StatusCode.StopAllExecution
 
         # Remove all particles that signalled deletion
         self.remove_deleted(pset)
@@ -398,7 +393,9 @@ def execute(self, pset, endtime, dt):
         # Remove all particles that signalled deletion
         self.remove_deleted(pset)  # Generalizable version!
 
-        self.execute_python(pset, endtime, dt)
+        # Re-execute Kernels to retry particles with StatusCode.Repeat
+        for p in pset:
+            self.evaluate_particle(p, endtime)
 
         n_error = pset._num_error_particles

parcels/particle.py

Lines changed: 3 additions & 1 deletion
@@ -231,4 +231,6 @@ def setLastID(cls, offset):  # TODO v4: check if we can implement this in anothe
 
 class JITParticle(ScipyParticle):
     def __init__(self, *args, **kwargs):
-        raise NotImplementedError("JITParticle has been deprecated in Parcels v4. Use ScipyParticle instead.")
+        raise NotImplementedError(
+            "JITParticle has been deprecated in Parcels v4. Use ScipyParticle instead."
+        )  # TODO v4: link to migration guide
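Since JITParticle now only raises NotImplementedError, existing v3 scripts migrate by switching the particle class; a hedged before/after sketch (fieldset and the release location are placeholders from the earlier example):

# Before (v3 style): pclass=parcels.JITParticle  -> now raises NotImplementedError
# After: ScipyParticle is the particle class to use
pset = parcels.ParticleSet(fieldset=fieldset, pclass=parcels.ScipyParticle, lon=[0.0], lat=[0.0])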

parcels/particledata.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
         self._data["id"][:] = pid
         self._data["obs_written"][:] = 0
 
-        # special case for exceptions which can only be handled from scipy
+        # special case for exceptions which can only be handled from scipy  # TODO v4: check if this can be removed now that JIT is dropped
         self._data["exception"] = np.empty(self._ncount, dtype=object)
 
         initialised |= {

parcels/particleset.py

Lines changed: 2 additions & 2 deletions
@@ -932,13 +932,13 @@ def execute(
 
         Notes
         -----
-        ``ParticleSet.execute()`` acts as the main entrypoint for simulations, and provides the simulation time-loop. This method encapsulates the logic controlling the switching between kernel execution (where control in handed to C in JIT mode), output file writing, reading in fields for new timesteps, adding new particles to the simulation domain, stopping the simulation, and executing custom functions (``postIterationCallbacks`` provided by the user).
+        ``ParticleSet.execute()`` acts as the main entrypoint for simulations, and provides the simulation time-loop. This method encapsulates the logic controlling the switching between kernel execution, output file writing, reading in fields for new timesteps, adding new particles to the simulation domain, stopping the simulation, and executing custom functions (``postIterationCallbacks`` provided by the user).
         """
         # check if particleset is empty. If so, return immediately
         if len(self) == 0:
             return
 
-        # check if pyfunc has changed since last compile. If so, recompile
+        # check if pyfunc has changed since last generation. If so, regenerate
         if self._kernel is None or (self._kernel.pyfunc is not pyfunc and self._kernel is not pyfunc):
             # Generate and store Kernel
             if isinstance(pyfunc, Kernel):
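The updated Notes section still mentions postIterationCallbacks: plain Python functions that execute() calls at a fixed interval during the time loop. A hedged sketch, assuming the v3-style postIterationCallbacks/callbackdt keywords and reusing pset from the earlier example; the callback body is illustrative only.

from datetime import timedelta

def report_progress():
    # Called every callbackdt of simulated time during pset.execute
    print(f"{len(pset)} particles still being integrated")

pset.execute(
    parcels.AdvectionRK4,
    runtime=timedelta(days=5),
    dt=timedelta(minutes=10),
    postIterationCallbacks=[report_progress],
    callbackdt=timedelta(hours=12),
)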
