Skip to content

Commit 829feaa

Browse files
committed
Fix type annotations
1 parent 2414796 commit 829feaa

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

src/parcels/_core/fieldset.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,17 @@ def from_fesom2(ds: ux.UxDataset):
284284

285285
fields: dict[str, Field | VectorField] = {}
286286
if "U" in ds.data_vars and "V" in ds.data_vars:
287-
fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
288-
fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
287+
field_U = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
288+
field_V = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
289+
fields["U"] = field_U
290+
fields["V"] = field_V
289291

290292
if "W" in ds.data_vars:
291-
fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))
292-
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
293+
field_W = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))
294+
fields["W"] = field_W
295+
fields["UVW"] = VectorField("UVW", field_U, field_V, field_W)
293296
else:
294-
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
297+
fields["UV"] = VectorField("UV", field_U, field_V)
295298

296299
for varname in set(ds.data_vars) - set(fields.keys()):
297300
fields[varname] = Field(varname, ds[varname], grid, _select_uxinterpolator(ds[varname]))

src/parcels/_core/index_search.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from datetime import datetime
43
from typing import TYPE_CHECKING
54

65
import numpy as np
@@ -63,14 +62,14 @@ def _search_1d_array(
6362
return np.atleast_1d(index), np.atleast_1d(bcoord)
6463

6564

66-
def _search_time_index(field: Field, time: datetime):
65+
def _search_time_index(field: Field, time: float):
6766
"""Find and return the index and relative coordinate in the time array associated with a given time.
6867
6968
Parameters
7069
----------
7170
field: Field
7271
73-
time: datetime
72+
time: float
7473
This is the amount of time, in seconds (time_delta), in unix epoch
7574
Note that we normalize to either the first or the last index
7675
if the sampled value is outside the time value range.
@@ -172,6 +171,8 @@ def _search_indices_curvilinear_2d(
172171
"""
173172
if np.any(xi):
174173
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
174+
assert xi is not None
175+
assert yi is not None
175176
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi)
176177
y_check = y[is_in_cell == 0]
177178
x_check = x[is_in_cell == 0]

src/parcels/_core/xgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
380380
return result
381381

382382

383-
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
383+
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: Hashable) -> _XGCM_AXIS_DIRECTION | None:
384384
"""For a given dimension name in a grid, returns the direction axis it is on."""
385385
for axis_name, axis in axes.items():
386386
if dim in axis.coords.values():

0 commit comments

Comments
 (0)