Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions parcels/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,39 @@ def get_axis_dim(self, axis: _UXGRID_AXES) -> int:
return self.uxgrid.n_face

def search(self, z, y, x, ei=None):
tol = 1e-10
tol = 1e-6

def try_face(fid):
bcoords, err = self.uxgrid._get_barycentric_coordinates(y, x, fid)
bcoords, err = self._get_barycentric_coordinates(y, x, fid)
if (bcoords >= 0).all() and (bcoords <= 1).all() and err < tol:
return bcoords, fid
return None, None
return bcoords
return None

zi, zeta = _search_1d_array(self.z.values, z)

if ei is not None:
_, fi = self.unravel_index(ei)
bcoords, fi_new = try_face(fi)
bcoords = try_face(fi)
if bcoords is not None:
return bcoords, self.ravel_index(zi, fi_new)
return bcoords, self.ravel_index(zi, fi)
# Try neighbors of current face
for neighbor in self.uxgrid.face_face_connectivity[fi, :]:
if neighbor == -1:
continue
bcoords, fi_new = try_face(neighbor)
bcoords = try_face(neighbor)
if bcoords is not None:
return bcoords, self.ravel_index(zi, fi_new)
return bcoords, self.ravel_index(zi, neighbor)

# Global fallback using spatial hash
fi, bcoords = self.uxgrid.get_spatial_hash().query([[x, y]])
# Global fallback as last ditch effort
face_ids = self.uxgrid.get_faces_containing_point([x, y], return_counts=False)[0]
fi = face_ids[0] if len(face_ids) > 0 else -1
if fi == -1:
raise FieldOutOfBoundError(z, y, x)
return {"Z": (zi, zeta), "FACE": (fi[0], bcoords[0])}
bcoords = try_face(fi)
if bcoords is None:
raise FieldOutOfBoundError(z, y, x)

return {"Z": (zi, zeta), "FACE": (fi, bcoords)}

def _get_barycentric_coordinates(self, y, x, fi):
"""Checks if a point is inside a given face id on a UxGrid."""
Expand All @@ -104,8 +109,8 @@ def _get_barycentric_coordinates(self, y, x, fi):
node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes]
nodes = np.column_stack(
(
np.deg2rad(self.uxgrid.grid.node_lon[node_ids].to_numpy()),
np.deg2rad(self.uxgrid.grid.node_lat[node_ids].to_numpy()),
np.deg2rad(self.uxgrid.node_lon[node_ids].to_numpy()),
np.deg2rad(self.uxgrid.node_lat[node_ids].to_numpy()),
)
)

Expand Down
Loading