Skip to content

Commit 31e2670

Browse files
committed
Add test_create_particle_data
1 parent 38709fb commit 31e2670

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

parcels/particle.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def create_particle_data(
194194
nparticles: int,
195195
ngrids: int,
196196
time_interval: TimeInterval,
197-
initial: dict[str, np.array] | None,
197+
initial: dict[str, np.array] | None = None,
198198
):
199199
if initial is None:
200200
initial = {}
@@ -231,18 +231,20 @@ def create_particle_data(
231231
name_to_copy = var.initial(_attrgetter_helper)
232232
data[var.name] = data[name_to_copy].copy()
233233
else:
234-
data[var.name] = _create_array_for_variable(var, nparticles)
234+
data[var.name] = _create_array_for_variable(var, nparticles, time_interval)
235235
return data
236236

237237

238-
def _create_array_for_variable(variable: Variable, nparticles: int):
238+
def _create_array_for_variable(variable: Variable, nparticles: int, time_interval: TimeInterval):
239239
assert not isinstance(variable.initial, operator.attrgetter), (
240240
"This function cannot handle attrgetter initial values."
241241
)
242+
if (dtype := variable.dtype) is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
243+
dtype = type(time_interval.left)
242244
return np.full(
243245
shape=(nparticles,),
244246
fill_value=variable.initial,
245-
dtype=variable.dtype,
247+
dtype=dtype,
246248
)
247249

248250

tests/v4/test_particle.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22
import pytest
33

4-
from parcels.particle import ParticleClass, Variable
4+
from parcels._core.utils.time import TimeInterval
5+
from parcels._datasets.structured.generic import TIME
6+
from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, Particle, ParticleClass, Variable, create_particle_data
57

68

79
def test_variable_init():
@@ -104,3 +106,39 @@ def test_particleclass_add_variable_collision():
104106

105107
with pytest.raises(ValueError, match="Variable name already exists: "):
106108
p_initial.add_variable([Variable("vara", dtype=np.float32, to_write=True)])
109+
110+
111+
@pytest.mark.parametrize(
112+
"particle",
113+
[
114+
ParticleClass(
115+
variables=[
116+
Variable("vara", dtype=np.float32, initial=1.0),
117+
Variable("varb", dtype=np.float32, initial=2.0),
118+
]
119+
),
120+
Particle,
121+
],
122+
)
123+
@pytest.mark.parametrize("nparticles", [5, 10])
124+
def test_create_particle_data(particle, nparticles):
125+
time_interval = TimeInterval(TIME[0], TIME[-1])
126+
ngrids = 4
127+
data = create_particle_data(pclass=particle, nparticles=nparticles, ngrids=ngrids, time_interval=time_interval)
128+
129+
assert isinstance(data, dict)
130+
assert len(data) == len(particle.variables) + 1 # ei variable is separate
131+
132+
variables = {var.name: var for var in particle.variables}
133+
134+
for variable_name in variables.keys():
135+
variable = variables[variable_name]
136+
variable_array = data[variable_name]
137+
138+
assert variable_array.shape[0] == nparticles
139+
140+
dtype = variable.dtype
141+
if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
142+
dtype = type(time_interval.left)
143+
144+
assert variable_array.dtype == dtype

0 commit comments

Comments
 (0)