@@ -64,7 +64,7 @@ class TestInterpolationMethods:
6464 zi , yi , xi = 1 , 1 , 1
6565
6666 @pytest .mark .parametrize (
67- "func, eta, xi , expected" ,
67+ "func, eta, xsi , expected" ,
6868 [
6969 pytest .param (interpolation ._nearest_2d , 0.49 , 0.49 , 3.0 , id = "nearest_2d-1" ),
7070 pytest .param (interpolation ._nearest_2d , 0.49 , 0.51 , 4.0 , id = "nearest_2d-2" ),
@@ -75,12 +75,12 @@ class TestInterpolationMethods:
7575 # pytest.param(interpolation._linear_invdist_land_tracer_2d, ...),
7676 ],
7777 )
78- def test_2d (self , data_2d , func , eta , xi , expected ):
79- ctx = interpolation .InterpolationContext2D (data_2d , eta , xi , self .ti , self .yi , self .xi )
78+ def test_2d (self , data_2d , func , eta , xsi , expected ):
79+ ctx = interpolation .InterpolationContext2D (data_2d , eta , xsi , self .ti , self .yi , self .xi )
8080 assert func (ctx ) == expected
8181
8282 @pytest .mark .parametrize (
83- "func, eta, xi , expected" ,
83+ "func, eta, xsi , expected" ,
8484 [
8585 # pytest.param(interpolation._nearest_3d, ...),
8686 # pytest.param(interpolation._cgrid_velocity_3d, ...),
@@ -89,6 +89,43 @@ def test_2d(self, data_2d, func, eta, xi, expected):
8989 # pytest.param(interpolation._tracer_3d, ...),
9090 ],
9191 )
92- def test_3d (self , data_3d , func , zeta , eta , xi , expected ):
93- ctx = interpolation .InterpolationContext3D (data_2d , zeta , eta , xi , self .ti , self .zi , self .yi , self .xi )
92+ def test_3d (self , data_3d , func , zeta , eta , xsi , expected ):
93+ ctx = interpolation .InterpolationContext3D (data_3d , zeta , eta , xsi , self .ti , self .zi , self .yi , self .xi )
9494 assert func (ctx ) == expected
95+
96+
97+ @pytest .mark .parametrize ("zeta" , np .linspace (0 , 1 , 5 ))
98+ @pytest .mark .parametrize ("eta" , np .linspace (0 , 1 , 5 ))
99+ @pytest .mark .parametrize ("xsi" , np .linspace (0 , 1 , 5 ))
100+ @pytest .mark .parametrize (
101+ "interp_method" ,
102+ [
103+ "linear" ,
104+ "bgrid_velocity" ,
105+ "bgrid_w_velocity" ,
106+ "partialslip" ,
107+ "freeslip" ,
108+ ],
109+ )
110+ @pytest .mark .parametrize (
111+ "gridindexingtype" ,
112+ [
113+ "mom5" ,
114+ "pop" ,
115+ ],
116+ )
117+ def test_interpolation_3d_refactor (data_3d , zeta , eta , xsi , interp_method , gridindexingtype ):
118+ # fmt: off
119+ f_old , f_new = {
120+ "linear" : (interpolation ._linear_3d_old , interpolation ._linear_3d_old ),
121+ "bgrid_velocity" : (interpolation ._linear_3d_old , interpolation ._linear_3d_old ),
122+ "bgrid_w_velocity" : (interpolation ._linear_3d_old , interpolation ._linear_3d_old ),
123+ "partialslip" : (interpolation ._linear_3d_old , interpolation ._linear_3d_old ),
124+ "freeslip" : (interpolation ._linear_3d_old , interpolation ._linear_3d_old ),
125+ }[interp_method ]
126+ # fmt: on
127+
128+ ti , zi , yi , xi = 1 , 1 , 1 , 1
129+
130+ ctx = interpolation .InterpolationContext3D (data_3d , zeta , eta , xsi , ti , zi , yi , xi , interp_method , gridindexingtype )
131+ assert np .isclose (f_old (ctx ), f_new (ctx ))
0 commit comments