Skip to content

Commit 825e208

Browse files
committed
Supporting code for 3d interpolation refactor
1 parent d7f7dea commit 825e208

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

parcels/_interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
230230
@register_3d_interpolator("bgrid_w_velocity")
231231
@register_3d_interpolator("partialslip")
232232
@register_3d_interpolator("freeslip")
233-
def _linear_3d(ctx: InterpolationContext3D) -> float:
233+
def _linear_3d_old(ctx: InterpolationContext3D) -> float:
234234
zeta = ctx.zeta
235235
eta = ctx.eta
236236
xsi = ctx.xsi

tests/test_interpolation.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class TestInterpolationMethods:
6464
zi, yi, xi = 1, 1, 1
6565

6666
@pytest.mark.parametrize(
67-
"func, eta, xi, expected",
67+
"func, eta, xsi, expected",
6868
[
6969
pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
7070
pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
@@ -75,12 +75,12 @@ class TestInterpolationMethods:
7575
# pytest.param(interpolation._linear_invdist_land_tracer_2d, ...),
7676
],
7777
)
78-
def test_2d(self, data_2d, func, eta, xi, expected):
79-
ctx = interpolation.InterpolationContext2D(data_2d, eta, xi, self.ti, self.yi, self.xi)
78+
def test_2d(self, data_2d, func, eta, xsi, expected):
79+
ctx = interpolation.InterpolationContext2D(data_2d, eta, xsi, self.ti, self.yi, self.xi)
8080
assert func(ctx) == expected
8181

8282
@pytest.mark.parametrize(
83-
"func, eta, xi, expected",
83+
"func, eta, xsi, expected",
8484
[
8585
# pytest.param(interpolation._nearest_3d, ...),
8686
# pytest.param(interpolation._cgrid_velocity_3d, ...),
@@ -89,6 +89,43 @@ def test_2d(self, data_2d, func, eta, xi, expected):
8989
# pytest.param(interpolation._tracer_3d, ...),
9090
],
9191
)
92-
def test_3d(self, data_3d, func, zeta, eta, xi, expected):
93-
ctx = interpolation.InterpolationContext3D(data_2d, zeta, eta, xi, self.ti, self.zi, self.yi, self.xi)
92+
def test_3d(self, data_3d, func, zeta, eta, xsi, expected):
93+
ctx = interpolation.InterpolationContext3D(data_3d, zeta, eta, xsi, self.ti, self.zi, self.yi, self.xi)
9494
assert func(ctx) == expected
95+
96+
97+
@pytest.mark.parametrize("zeta", np.linspace(0, 1, 5))
98+
@pytest.mark.parametrize("eta", np.linspace(0, 1, 5))
99+
@pytest.mark.parametrize("xsi", np.linspace(0, 1, 5))
100+
@pytest.mark.parametrize(
101+
"interp_method",
102+
[
103+
"linear",
104+
"bgrid_velocity",
105+
"bgrid_w_velocity",
106+
"partialslip",
107+
"freeslip",
108+
],
109+
)
110+
@pytest.mark.parametrize(
111+
"gridindexingtype",
112+
[
113+
"mom5",
114+
"pop",
115+
],
116+
)
117+
def test_interpolation_3d_refactor(data_3d, zeta, eta, xsi, interp_method, gridindexingtype):
118+
# fmt: off
119+
f_old, f_new = {
120+
"linear": (interpolation._linear_3d_old, interpolation._linear_3d_old),
121+
"bgrid_velocity": (interpolation._linear_3d_old, interpolation._linear_3d_old),
122+
"bgrid_w_velocity": (interpolation._linear_3d_old, interpolation._linear_3d_old),
123+
"partialslip": (interpolation._linear_3d_old, interpolation._linear_3d_old),
124+
"freeslip": (interpolation._linear_3d_old, interpolation._linear_3d_old),
125+
}[interp_method]
126+
# fmt: on
127+
128+
ti, zi, yi, xi = 1, 1, 1, 1
129+
130+
ctx = interpolation.InterpolationContext3D(data_3d, zeta, eta, xsi, ti, zi, yi, xi, interp_method, gridindexingtype)
131+
assert np.isclose(f_old(ctx), f_new(ctx))

0 commit comments

Comments
 (0)