Skip to content

Commit 4caf74c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f1802ce commit 4caf74c

File tree

5 files changed

+44
-65
lines changed

5 files changed

+44
-65
lines changed

parcels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from parcels.application_kernels import *
66
from parcels.field import *
77
from parcels.fieldset import *
8-
from parcels.uxfieldset import *
98
from parcels.grid import *
109
from parcels.gridset import *
1110
from parcels.interaction import *
@@ -14,3 +13,4 @@
1413
from parcels.particlefile import *
1514
from parcels.particleset import *
1615
from parcels.tools import *
16+
from parcels.uxfieldset import *

parcels/application_kernels/advection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
"AdvectionRK45",
1515
]
1616

17-
def UxAdvectionEuler(particle,fieldset:UXFieldSet,time):
17+
18+
def UxAdvectionEuler(particle, fieldset: UXFieldSet, time):
1819
"""Advection of particles using Explicit Euler (aka Euler Forward) integration.
19-
on an unstructured grid."""
20-
vel = fieldset.eval(["u","v"],time,particle.depth,particle.lat,particle.lon,particle)
20+
on an unstructured grid.
21+
"""
22+
vel = fieldset.eval(["u", "v"], time, particle.depth, particle.lat, particle.lon, particle)
2123
particle.lon += vel["u"] * particle.dt
2224
particle.lat += vel["v"] * particle.dt
2325

26+
2427
def AdvectionRK4(particle, fieldset, time): # pragma: no cover
2528
"""Advection of particles using fourth-order Runge-Kutta integration."""
2629
(u1, v1) = fieldset.UV[particle]

parcels/particleset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import xarray as xr
1010
from scipy.spatial import KDTree
1111
from tqdm import tqdm
12-
import uxarray as ux
1312

1413
from parcels._compat import MPI
1514
from parcels.application_kernels.advection import AdvectionRK4
@@ -32,6 +31,7 @@
3231
from parcels.tools.statuscodes import StatusCode
3332
from parcels.tools.warnings import ParticleSetWarning
3433
from parcels.uxfieldset import UXFieldSet
34+
3535
__all__ = ["ParticleSet"]
3636

3737

@@ -169,7 +169,7 @@ def ArrayClass_init(self, *args, **kwargs):
169169
assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
170170

171171
if type(fieldset) == UXFieldSet:
172-
lonlatdepth_dtype = np.float32 # To do : get precision from fieldset
172+
lonlatdepth_dtype = np.float32 # To do : get precision from fieldset
173173
else:
174174
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation):
175175
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full)

parcels/uxfieldset.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,28 @@
1-
import importlib.util
2-
import os
3-
import sys
4-
import warnings
5-
from copy import deepcopy
6-
from glob import glob
7-
8-
import dask.array as da
1+
import cftime
92
import numpy as np
103
import uxarray as ux
114
from uxarray.neighbors import _barycentric_coordinates
12-
import cftime
13-
14-
from parcels._compat import MPI
15-
from parcels._typing import GridIndexingType, InterpMethodOption, Mesh
16-
from parcels.field import DeferredArray, Field, NestedField, VectorField
17-
from parcels.grid import Grid
18-
from parcels.gridset import GridSet
19-
from parcels.particlefile import ParticleFile
20-
from parcels.tools._helpers import fieldset_repr
21-
from parcels.tools.converters import TimeConverter, convert_xarray_time_units
22-
from parcels.tools.loggers import logger
23-
from parcels.tools.statuscodes import TimeExtrapolationError
24-
from parcels.tools.warnings import FieldSetWarning
255

266
__all__ = ["UXFieldSet"]
277

288
_inside_tol = 1e-6
9+
10+
2911
class UXFieldSet:
3012
"""A FieldSet class that holds hydrodynamic data needed to execute particles
31-
in a UXArray.Dataset"""
13+
in a UXArray.Dataset
14+
"""
3215

3316
def __init__(self, uxds: ux.UxDataset, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0):
34-
35-
# Ensure that dataset provides a grid, and the u and v velocity
17+
# Ensure that dataset provides a grid, and the u and v velocity
3618
# components at a minimum
3719
if not hasattr(uxds, "uxgrid"):
3820
raise ValueError("The UXArray dataset does not provide a grid")
3921
if not hasattr(uxds, "u"):
4022
raise ValueError("The UXArray dataset does not provide u velocity data")
4123
if not hasattr(uxds, "v"):
4224
raise ValueError("The UXArray dataset does not provide v velocity data")
43-
25+
4426
self.time_origin = time_origin
4527
self.uxds = uxds
4628
self._spatialhash = self.uxds.get_spatialhash()
@@ -52,16 +34,15 @@ def _check_complete(self):
5234
assert self.uxds.uxgrid is not None, "UXFieldSet does not provide a grid"
5335

5436
def _face_interp(self, field, time, z, y, x, particle=None):
55-
56-
#ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
37+
# ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
5738
ti = 0
5839
zi = 0
5940
fi = particle.ei
60-
return field[ti,zi,fi]
41+
return field[ti, zi, fi]
6142

6243
def _node_interp(self, field, time, z, y, x, particle=None):
6344
"""Performs barycentric interpolation of a field at a given location."""
64-
#ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
45+
# ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
6546
ti = 0
6647
zi = 0
6748
fi = particle.ei
@@ -77,13 +58,12 @@ def _node_interp(self, field, time, z, y, x, particle=None):
7758

7859
coord = np.deg2rad([x, y])
7960
bcoord = _barycentric_coordinates(nodes, coord)
80-
return np.sum(bcoord * field[ti,zi,node_ids].flatten(), axis=0)
61+
return np.sum(bcoord * field[ti, zi, node_ids].flatten(), axis=0)
8162

8263
def eval(self, field_names: list(str), time, z, y, x, particle=None, applyConversion=True):
83-
8464
res = {}
8565
if particle:
86-
#ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
66+
# ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
8767
fi = particle.ei
8868
# Check if particle is in the same face, otherwise search again.
8969
n_nodes = self.uxds.uxgrid.n_nodes_per_face[fi].to_numpy()
@@ -104,12 +84,12 @@ def eval(self, field_names: list(str), time, z, y, x, particle=None, applyConver
10484
# To do : Get the vertical and time indices for the particle
10585

10686
if (not is_inside) or (err > _inside_tol):
107-
fi = self._spatialhash.query([particle.x,particle.y]) # Get the face id for the particle
108-
particle.ei = fi
87+
fi = self._spatialhash.query([particle.x, particle.y]) # Get the face id for the particle
88+
particle.ei = fi
10989

11090
for f in field_names:
11191
field = getattr(self, f)
112-
face_registered = ("n_face" in field.dims)
92+
face_registered = "n_face" in field.dims
11393
if face_registered:
11494
if particle:
11595
r = self._face_interp(field, particle.time, particle.z, particle.y, particle.x, particle)
@@ -125,9 +105,9 @@ def eval(self, field_names: list(str), time, z, y, x, particle=None, applyConver
125105
res[f] = self.units.to_target(r, z, y, x)
126106
else:
127107
res[f] = r
128-
108+
129109
return res
130-
110+
131111
# if self.U.interp_method not in ["cgrid_velocity", "partialslip", "freeslip"]:
132112
# u = self.U.eval(time, z, y, x, particle=particle, applyConversion=False)
133113
# v = self.V.eval(time, z, y, x, particle=particle, applyConversion=False)

tests/test_uxfieldset.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,33 @@
1-
from datetime import timedelta
2-
3-
import numpy as np
4-
import pytest
51
import uxarray as ux
62

7-
from parcels import (
8-
UXFieldSet,
9-
ParticleSet,
10-
Particle
11-
)
12-
3+
from parcels import Particle, ParticleSet, UXFieldSet
134
from tests.utils import TEST_DATA
145

156

167
def test_fesom_fieldset():
178
# Load a FESOM dataset
18-
grid_path=f"{TEST_DATA}/fesom_channel.nc"
19-
data_path=[f"{TEST_DATA}/u.fesom_channel.nc",
20-
f"{TEST_DATA}/v.fesom_channel.nc",
21-
f"{TEST_DATA}/w.fesom_channel.nc"]
22-
ds = ux.open_mfdataset(grid_path,data_path)
9+
grid_path = f"{TEST_DATA}/fesom_channel.nc"
10+
data_path = [
11+
f"{TEST_DATA}/u.fesom_channel.nc",
12+
f"{TEST_DATA}/v.fesom_channel.nc",
13+
f"{TEST_DATA}/w.fesom_channel.nc",
14+
]
15+
ds = ux.open_mfdataset(grid_path, data_path)
2316
fieldset = UXFieldSet(ds)
2417
fieldset._check_complete()
2518
# Check that the fieldset has the expected properties
2619
assert fieldset.uxds == ds
2720

21+
2822
def test_fesom_in_particleset():
29-
# Load a FESOM dataset
30-
grid_path=f"{TEST_DATA}/fesom_channel.nc"
31-
data_path=[f"{TEST_DATA}/u.fesom_channel.nc",
32-
f"{TEST_DATA}/v.fesom_channel.nc",
33-
f"{TEST_DATA}/w.fesom_channel.nc"]
34-
ds = ux.open_mfdataset(grid_path,data_path)
23+
# Load a FESOM dataset
24+
grid_path = f"{TEST_DATA}/fesom_channel.nc"
25+
data_path = [
26+
f"{TEST_DATA}/u.fesom_channel.nc",
27+
f"{TEST_DATA}/v.fesom_channel.nc",
28+
f"{TEST_DATA}/w.fesom_channel.nc",
29+
]
30+
ds = ux.open_mfdataset(grid_path, data_path)
3531
fieldset = UXFieldSet(ds)
3632
print(type(fieldset))
37-
pset = ParticleSet(fieldset, pclass=Particle)
33+
pset = ParticleSet(fieldset, pclass=Particle)

0 commit comments

Comments
 (0)