
Commit 59c2d2f
Merge branch 'v4-dev' into updates_from_virtualship_dev
2 parents dc41fe7 + aa90e1c

File tree: 15 files changed, +190 -76 lines changed

.github/workflows/cache-pixi-lock.yml
Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ jobs:
       - uses: actions/checkout@v5
         with:
           fetch-depth: 0
+          submodules: recursive
       - name: Get current date
         id: date
         run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

pixi.toml
Lines changed: 4 additions & 2 deletions

@@ -26,9 +26,8 @@ docs = { features = ["docs"], solve-group = "docs" }
 typing = { features = ["typing"], solve-group = "typing" }
 pre-commit = { features = ["pre-commit"], no-default-feature = true }

-[dependencies] # keep section in sync with pyproject.toml dependencies
+[package.run-dependencies] # keep section in sync with pyproject.toml dependencies
 python = ">=3.11"
-parcels = { path = "." }
 netcdf4 = ">=1.7.2"
 numpy = ">=2.1.0"
 tqdm = ">=4.50.0"
@@ -41,6 +40,9 @@ cf_xarray = ">=0.8.6"
 cftime = ">=1.6.3"
 pooch = ">=1.8.0"

+[dependencies]
+parcels = { path = "." }
+
 [feature.minimum.dependencies]
 python = "==3.11"
 netcdf4 = "==1.7.2"

src/parcels/_core/field.py
Lines changed: 6 additions & 2 deletions

@@ -18,6 +18,7 @@
     AllParcelsErrorCodes,
     StatusCode,
 )
+from parcels._core.utils.string import _assert_str_and_python_varname
 from parcels._core.utils.time import TimeInterval
 from parcels._core.uxgrid import UxGrid
 from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx
@@ -101,8 +102,9 @@ def __init__(
             raise ValueError(
                 f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}."
             )
-        if not isinstance(name, str):
-            raise ValueError(f"Expected `name` to be a string, got {type(name)}.")
+
+        _assert_str_and_python_varname(name)
+
         if not isinstance(grid, (UxGrid, XGrid)):
             raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")

@@ -246,6 +248,8 @@ class VectorField:
     def __init__(
         self, name: str, U: Field, V: Field, W: Field | None = None, vector_interp_method: Callable | None = None
     ):
+        _assert_str_and_python_varname(name)
+
         self.name = name
         self.U = U
         self.V = V

src/parcels/_core/fieldset.py
Lines changed: 7 additions & 1 deletion

@@ -11,6 +11,7 @@

 from parcels._core.converters import Geographic, GeographicPolar
 from parcels._core.field import Field, VectorField
+from parcels._core.utils.string import _assert_str_and_python_varname
 from parcels._core.utils.time import get_datetime_type_calendar
 from parcels._core.utils.time import is_compatible as datetime_is_compatible
 from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
@@ -163,6 +164,8 @@ def add_constant(self, name, value):
         `Diffusion <../examples/tutorial_diffusion.ipynb>`__
         `Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__
         """
+        _assert_str_and_python_varname(name)
+
         if name in self.constants:
             raise ValueError(f"FieldSet already has a constant with name '{name}'")
         if not isinstance(value, (float, np.floating, int, np.integer)):
@@ -204,7 +207,10 @@ def from_copernicusmarine(ds: xr.Dataset):
         expected_axes = set("XYZT") # TODO: Update after we have support for 2D spatial fields
         if missing_axes := (expected_axes - set(ds.cf.axes)):
             raise ValueError(
-                f"Dataset missing axes {missing_axes} to have coordinates for all {expected_axes} axes according to CF conventions."
+                f"Dataset missing CF compliant metadata for axes "
+                f"{missing_axes}. Expected 'axis' attribute to be set "
+                f"on all dimension axes {expected_axes}. "
+                "HINT: Add xarray metadata attribute 'axis' to dimension - e.g., ds['lat'].attrs['axis'] = 'Y'"
             )

         ds = _rename_coords_copernicusmarine(ds)
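
The reworded error in from_copernicusmarine now tells users how to repair their dataset. A minimal sketch of that fix, using a hypothetical dataset (the names and values below are illustrative, not taken from this commit): tag each dimension coordinate with its CF 'axis' attribute before building the FieldSet.

import numpy as np
import xarray as xr

# Hypothetical dataset whose dimension coordinates lack CF 'axis' metadata
ds = xr.Dataset(
    coords={
        "time": ("time", np.array(["2000-01-01"], dtype="datetime64[ns]")),
        "depth": ("depth", np.array([0.5])),
        "lat": ("lat", np.arange(-80.0, 81.0, 10.0)),
        "lon": ("lon", np.arange(-180.0, 181.0, 10.0)),
    }
)

# The HINT from the new error message: set the 'axis' attribute on each dimension
ds["time"].attrs["axis"] = "T"
ds["depth"].attrs["axis"] = "Z"
ds["lat"].attrs["axis"] = "Y"
ds["lon"].attrs["axis"] = "X"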

src/parcels/_core/particle.py
Lines changed: 2 additions & 10 deletions

@@ -2,13 +2,13 @@

 import enum
 import operator
-from keyword import iskeyword
 from typing import Literal

 import numpy as np

 from parcels._compat import _attrgetter_helper
 from parcels._core.statuscodes import StatusCode
+from parcels._core.utils.string import _assert_str_and_python_varname
 from parcels._core.utils.time import TimeInterval
 from parcels._reprs import _format_list_items_multiline

@@ -45,9 +45,7 @@ def __init__(
         to_write: bool | Literal["once"] = True,
         attrs: dict | None = None,
     ):
-        if not isinstance(name, str):
-            raise TypeError(f"Variable name must be a string. Got {name=!r}")
-        _assert_valid_python_varname(name)
+        _assert_str_and_python_varname(name)

         try:
             dtype = np.dtype(dtype)
@@ -153,12 +151,6 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
         raise ValueError(f"Variable name already exists: {var.name}")


-def _assert_valid_python_varname(name):
-    if name.isidentifier() and not iskeyword(name):
-        return
-    raise ValueError(f"Particle variable has to be a valid Python variable name. Got {name=!r}")
-
-
 def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
     if spatial_dtype not in [np.float32, np.float64]:
         raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")

src/parcels/_core/utils/string.py
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+from keyword import iskeyword, kwlist
+
+
+def _assert_str_and_python_varname(name):
+    if not isinstance(name, str):
+        raise TypeError(f"Expected a string for variable name, got {type(name).__name__} instead.")
+
+    msg = f"Received invalid Python variable name {name!r}: "
+
+    if not name.isidentifier():
+        msg += "not a valid identifier. HINT: avoid using spaces, special characters, and starting with a number."
+        raise ValueError(msg)
+
+    if iskeyword(name):
+        msg += f"it is a reserved keyword. HINT: avoid using the following names: {', '.join(kwlist)}"
+        raise ValueError(msg)
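
For reference, a minimal sketch of how the new helper behaves, mirroring the cases exercised in tests/test_field.py further down:

import pytest

from parcels._core.utils.string import _assert_str_and_python_varname

_assert_str_and_python_varname("U_A_grid")  # valid identifier, not a keyword: returns None

with pytest.raises(TypeError):
    _assert_str_and_python_varname(123)  # "Expected a string for variable name, got int instead."

with pytest.raises(ValueError):
    _assert_str_and_python_varname("U (A grid)")  # "not a valid identifier" (spaces, special characters)

with pytest.raises(ValueError):
    _assert_str_and_python_varname("while")  # "it is a reserved keyword"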

src/parcels/_datasets/structured/generic.py
Lines changed: 20 additions & 17 deletions

@@ -23,10 +23,10 @@ def _rotated_curvilinear_grid():
         {
             "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
             "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
-            "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
+            "U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
         },
         coords={
             "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -92,16 +92,19 @@ def _unrolled_cone_curvilinear_grid():
         new_lon_lat.append((lon + pivot[0], lat + pivot[1]))

     new_lon, new_lat = zip(*new_lon_lat, strict=True)
-    LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape)
+    LON, LAT = (
+        np.array(new_lon).reshape(LON.shape),
+        np.array(new_lat).reshape(LAT.shape),
+    )

     return xr.Dataset(
         {
             "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
             "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
-            "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
+            "U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
         },
         coords={
             "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -140,10 +143,10 @@ def _unrolled_cone_curvilinear_grid():
         {
             "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
             "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
-            "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
+            "U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
         },
         coords={
             "XG": (
@@ -182,10 +185,10 @@ def _unrolled_cone_curvilinear_grid():
         {
             "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
             "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
-            "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
-            "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
-            "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
+            "U_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_A_grid": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
+            "U_C_grid": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
+            "V_C_grid": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
         },
         coords={
             "XG": (

tests/test_field.py
Lines changed: 70 additions & 13 deletions

@@ -15,10 +15,27 @@
 def test_field_init_param_types():
     data = datasets_structured["ds_2d_left"]
     grid = XGrid.from_dataset(data)
-    with pytest.raises(ValueError, match="Expected `name` to be a string"):
+
+    with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
         Field(name=123, data=data["data_g"], grid=grid)

-    with pytest.raises(ValueError, match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray"):
+    for name in ["a b", "123"]:
+        with pytest.raises(
+            ValueError,
+            match=r"Received invalid Python variable name.*: not a valid identifier. HINT: avoid using spaces, special characters, and starting with a number.",
+        ):
+            Field(name=name, data=data["data_g"], grid=grid)
+
+    with pytest.raises(
+        ValueError,
+        match=r"Received invalid Python variable name.*: it is a reserved keyword. HINT: avoid using the following names:.*",
+    ):
+        Field(name="while", data=data["data_g"], grid=grid)
+
+    with pytest.raises(
+        ValueError,
+        match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray",
+    ):
         Field(name="test", data=123, grid=grid)

     with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"):
@@ -28,7 +45,11 @@ def test_field_init_param_types():
 @pytest.mark.parametrize(
     "data,grid",
     [
-        pytest.param(ux.UxDataArray(), XGrid.from_dataset(datasets_structured["ds_2d_left"]), id="uxdata-grid"),
+        pytest.param(
+            ux.UxDataArray(),
+            XGrid.from_dataset(datasets_structured["ds_2d_left"]),
+            id="uxdata-grid",
+        ),
         pytest.param(
             xr.DataArray(),
             UxGrid(
@@ -76,7 +97,11 @@ def test_field_init_fail_on_float_time_dim():
     (users are expected to use timedelta64 or datetime).
     """
     ds = datasets_structured["ds_2d_left"].copy()
-    ds["time"] = (ds["time"].dims, np.arange(0, T_structured, dtype="float64"), ds["time"].attrs)
+    ds["time"] = (
+        ds["time"].dims,
+        np.arange(0, T_structured, dtype="float64"),
+        ds["time"].attrs,
+    )

     data = ds["data_g"]
     grid = XGrid.from_dataset(ds)
@@ -122,7 +147,12 @@ def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, inval

     # Test invalid interpolator with wrong signature
     with pytest.raises(ValueError, match=".*incorrect name.*"):
-        Field(name="test", data=ds["data_g"], grid=grid, interp_method=invalid_interpolator_wrong_signature)
+        Field(
+            name="test",
+            data=ds["data_g"],
+            grid=grid,
+            interp_method=invalid_interpolator_wrong_signature,
+        )


 def test_vectorfield_invalid_interpolator():
@@ -138,7 +168,12 @@ def invalid_interpolator_wrong_signature(self, ti, position, tau, t, z, y, apply

     # Test invalid interpolator with wrong signature
     with pytest.raises(ValueError, match=".*incorrect name.*"):
-        VectorField(name="UV", U=U, V=V, vector_interp_method=invalid_interpolator_wrong_signature)
+        VectorField(
+            name="UV",
+            U=U,
+            V=V,
+            vector_interp_method=invalid_interpolator_wrong_signature,
+        )


 def test_field_unstructured_z_linear():
@@ -161,18 +196,34 @@ def test_field_unstructured_z_linear():
     P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)

     # Test above first cell center - for piecewise constant, should return the depth of the first cell center
-    assert np.isclose(P.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False), 55.555557)
+    assert np.isclose(
+        P.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False),
+        55.555557,
+    )
     # Test below first cell center, but in the first layer - for piecewise constant, should return the depth of the first cell center
-    assert np.isclose(P.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False), 55.555557)
+    assert np.isclose(
+        P.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False),
+        55.555557,
+    )
     # Test bottom layer - for piecewise constant, should return the depth of the of the bottom layer cell center
     assert np.isclose(
-        P.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False), 944.44445801
+        P.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False),
+        944.44445801,
     )

     W = Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode)
-    assert np.isclose(W.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False), 10.0)
-    assert np.isclose(W.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False), 65.0)
-    assert np.isclose(W.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False), 900.0)
+    assert np.isclose(
+        W.eval(time=ds.time[0].values, z=[10.0], y=[30.0], x=[30.0], applyConversion=False),
+        10.0,
+    )
+    assert np.isclose(
+        W.eval(time=ds.time[0].values, z=[65.0], y=[30.0], x=[30.0], applyConversion=False),
+        65.0,
+    )
+    assert np.isclose(
+        W.eval(time=ds.time[0].values, z=[900.0], y=[30.0], x=[30.0], applyConversion=False),
+        900.0,
+    )


 def test_field_constant_in_time():
@@ -185,7 +236,13 @@ def test_field_constant_in_time():
     # Assert that the field can be evaluated at any time, and returns the same value
     time = np.datetime64("2000-01-01T00:00:00")
     P1 = P.eval(time=time, z=[10.0], y=[30.0], x=[30.0], applyConversion=False)
-    P2 = P.eval(time=time + np.timedelta64(1, "D"), z=[10.0], y=[30.0], x=[30.0], applyConversion=False)
+    P2 = P.eval(
+        time=time + np.timedelta64(1, "D"),
+        z=[10.0],
+        y=[30.0],
+        x=[30.0],
+        applyConversion=False,
+    )
     assert np.isclose(P1, P2)
0 commit comments
