1- import importlib .util
2- import os
3- import sys
4- import warnings
5- from copy import deepcopy
6- from glob import glob
7-
8- import dask .array as da
1+ import cftime
92import numpy as np
103import uxarray as ux
114from uxarray .neighbors import _barycentric_coordinates
12- import cftime
13-
14- from parcels ._compat import MPI
15- from parcels ._typing import GridIndexingType , InterpMethodOption , Mesh
16- from parcels .field import DeferredArray , Field , NestedField , VectorField
17- from parcels .grid import Grid
18- from parcels .gridset import GridSet
19- from parcels .particlefile import ParticleFile
20- from parcels .tools ._helpers import fieldset_repr
21- from parcels .tools .converters import TimeConverter , convert_xarray_time_units
22- from parcels .tools .loggers import logger
23- from parcels .tools .statuscodes import TimeExtrapolationError
24- from parcels .tools .warnings import FieldSetWarning
255
266__all__ = ["UXFieldSet" ]
277
288_inside_tol = 1e-6
9+
10+
2911class UXFieldSet :
3012 """A FieldSet class that holds hydrodynamic data needed to execute particles
31- in a UXArray.Dataset"""
13+ in a UXArray.Dataset
14+ """
3215
3316 def __init__ (self , uxds : ux .UxDataset , time_origin : float | np .datetime64 | np .timedelta64 | cftime .datetime = 0 ):
34-
35- # Ensure that dataset provides a grid, and the u and v velocity
17+ # Ensure that dataset provides a grid, and the u and v velocity
3618 # components at a minimum
3719 if not hasattr (uxds , "uxgrid" ):
3820 raise ValueError ("The UXArray dataset does not provide a grid" )
3921 if not hasattr (uxds , "u" ):
4022 raise ValueError ("The UXArray dataset does not provide u velocity data" )
4123 if not hasattr (uxds , "v" ):
4224 raise ValueError ("The UXArray dataset does not provide v velocity data" )
43-
25+
4426 self .time_origin = time_origin
4527 self .uxds = uxds
4628 self ._spatialhash = self .uxds .get_spatialhash ()
@@ -52,16 +34,15 @@ def _check_complete(self):
5234 assert self .uxds .uxgrid is not None , "UXFieldSet does not provide a grid"
5335
5436 def _face_interp (self , field , time , z , y , x , particle = None ):
55-
56- #ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
37+ # ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
5738 ti = 0
5839 zi = 0
5940 fi = particle .ei
60- return field [ti ,zi ,fi ]
41+ return field [ti , zi , fi ]
6142
6243 def _node_interp (self , field , time , z , y , x , particle = None ):
6344 """Performs barycentric interpolation of a field at a given location."""
64- #ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
45+ # ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
6546 ti = 0
6647 zi = 0
6748 fi = particle .ei
@@ -77,13 +58,12 @@ def _node_interp(self, field, time, z, y, x, particle=None):
7758
7859 coord = np .deg2rad ([x , y ])
7960 bcoord = _barycentric_coordinates (nodes , coord )
80- return np .sum (bcoord * field [ti ,zi ,node_ids ].flatten (), axis = 0 )
61+ return np .sum (bcoord * field [ti , zi , node_ids ].flatten (), axis = 0 )
8162
8263 def eval (self , field_names : list (str ), time , z , y , x , particle = None , applyConversion = True ):
83-
8464 res = {}
8565 if particle :
86- #ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
66+ # ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle
8767 fi = particle .ei
8868 # Check if particle is in the same face, otherwise search again.
8969 n_nodes = self .uxds .uxgrid .n_nodes_per_face [fi ].to_numpy ()
@@ -104,12 +84,12 @@ def eval(self, field_names: list(str), time, z, y, x, particle=None, applyConver
10484 # To do : Get the vertical and time indices for the particle
10585
10686 if (not is_inside ) or (err > _inside_tol ):
107- fi = self ._spatialhash .query ([particle .x ,particle .y ]) # Get the face id for the particle
108- particle .ei = fi
87+ fi = self ._spatialhash .query ([particle .x , particle .y ]) # Get the face id for the particle
88+ particle .ei = fi
10989
11090 for f in field_names :
11191 field = getattr (self , f )
112- face_registered = ( "n_face" in field .dims )
92+ face_registered = "n_face" in field .dims
11393 if face_registered :
11494 if particle :
11595 r = self ._face_interp (field , particle .time , particle .z , particle .y , particle .x , particle )
@@ -125,9 +105,9 @@ def eval(self, field_names: list(str), time, z, y, x, particle=None, applyConver
125105 res [f ] = self .units .to_target (r , z , y , x )
126106 else :
127107 res [f ] = r
128-
108+
129109 return res
130-
110+
131111 # if self.U.interp_method not in ["cgrid_velocity", "partialslip", "freeslip"]:
132112 # u = self.U.eval(time, z, y, x, particle=particle, applyConversion=False)
133113 # v = self.V.eval(time, z, y, x, particle=particle, applyConversion=False)
0 commit comments