Skip to content

Commit 2f7b05f

Browse files
Fixing failing unit tests
1 parent e086bfa commit 2f7b05f

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

tests/test_fieldset_sampling.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create_fieldset_geometric(xdim=200, ydim=100):
4848
"""Standard earth fieldset with U and V equivalent to lon/lat in m."""
4949
lon = np.linspace(-180, 180, xdim, dtype=np.float32)
5050
lat = np.linspace(-90, 90, ydim, dtype=np.float32)
51-
U, V = np.meshgrid(lat, lon)
51+
V, U = np.meshgrid(lon, lat)
5252
U *= 1000.0 * 1.852 * 60.0
5353
V *= 1000.0 * 1.852 * 60.0
5454
data = {"U": U, "V": V}
@@ -70,10 +70,10 @@ def create_fieldset_geometric_polar(xdim=200, ydim=100):
7070
"""
7171
lon = np.linspace(-180, 180, xdim, dtype=np.float32)
7272
lat = np.linspace(-90, 90, ydim, dtype=np.float32)
73-
U, V = np.meshgrid(lat, lon)
73+
V, U = np.meshgrid(lon, lat)
7474
# Apply inverse of pole correction to U
7575
for i, y in enumerate(lat):
76-
U[:, i] *= cos(y * pi / 180)
76+
U[i, :] *= cos(y * pi / 180)
7777
U *= 1000.0 * 1.852 * 60.0
7878
V *= 1000.0 * 1.852 * 60.0
7979
data = {"U": U, "V": V}
@@ -147,9 +147,9 @@ def test_pset_from_field():
147147
"lon": np.linspace(0.0, 1.0, xdim, dtype=np.float32),
148148
"lat": np.linspace(0.0, 1.0, ydim, dtype=np.float32),
149149
}
150-
startfield = np.ones((xdim, ydim), dtype=np.float32)
150+
startfield = np.ones((ydim, xdim), dtype=np.float32)
151151
for x in range(xdim):
152-
startfield[x, :] = x
152+
startfield[:, x] = x
153153
data = {
154154
"U": np.zeros((ydim, xdim), dtype=np.float32),
155155
"V": np.zeros((ydim, xdim), dtype=np.float32),
@@ -166,7 +166,7 @@ def test_pset_from_field():
166166

167167
fieldset.add_field(densfield)
168168
pset = ParticleSet.from_field(fieldset, size=npart, pclass=Particle, start_field=fieldset.start)
169-
pdens = np.histogram2d(pset.lon, pset.lat, bins=[np.linspace(0.0, 1.0, xdim + 1), np.linspace(0.0, 1.0, ydim + 1)])[
169+
pdens = np.histogram2d(pset.lat, pset.lon, bins=[np.linspace(0.0, 1.0, ydim + 1), np.linspace(0.0, 1.0, xdim + 1)])[
170170
0
171171
]
172172
assert np.allclose(pdens / sum(pdens.flatten()), startfield / sum(startfield.flatten()), atol=1e-2)
@@ -184,7 +184,7 @@ def test_nearest_neighbor_interpolation2D():
184184
"V": np.zeros(dims, dtype=np.float32),
185185
"P": np.zeros(dims, dtype=np.float32),
186186
}
187-
data["P"][0, 1] = 1.0
187+
data["P"][1, 0] = 1.0
188188
fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
189189
fieldset.P.interp_method = "nearest"
190190
xv, yv = np.meshgrid(np.linspace(0.0, 1.0, int(np.sqrt(npart))), np.linspace(0.0, 1.0, int(np.sqrt(npart))))
@@ -207,7 +207,7 @@ def test_nearest_neighbor_interpolation3D():
207207
"V": np.zeros(dims, dtype=np.float32),
208208
"P": np.zeros(dims, dtype=np.float32),
209209
}
210-
data["P"][0, 1, 1] = 1.0
210+
data["P"][1, 1, 0] = 1.0
211211
fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
212212
fieldset.P.interp_method = "nearest"
213213
xv, yv = np.meshgrid(np.linspace(0, 1.0, int(np.sqrt(npart))), np.linspace(0, 1.0, int(np.sqrt(npart))))
@@ -391,7 +391,7 @@ def test_fieldset_sample_particle(lat_flip):
391391
lat = np.linspace(90, -90, 100, dtype=np.float32)
392392
else:
393393
lat = np.linspace(-90, 90, 100, dtype=np.float32)
394-
U, V = np.meshgrid(lat, lon)
394+
V, U = np.meshgrid(lon, lat)
395395
data = {"U": U, "V": V}
396396
dimensions = {"lon": lon, "lat": lat}
397397

@@ -557,7 +557,7 @@ def test_sampling_out_of_bounds_time(allow_time_extrapolation):
557557
data = {
558558
"U": np.zeros((tdim, ydim, xdim), dtype=np.float32),
559559
"V": np.zeros((tdim, ydim, xdim), dtype=np.float32),
560-
"P": np.ones((1, ydim, xdim), dtype=np.float32) * dimensions["time"],
560+
"P": np.transpose(np.ones((xdim, ydim, 1), dtype=np.float32) * dimensions["time"]),
561561
}
562562

563563
fieldset = FieldSet.from_data(data, dimensions, mesh="flat", allow_time_extrapolation=allow_time_extrapolation)
@@ -590,6 +590,9 @@ def test_sampling_out_of_bounds_time(allow_time_extrapolation):
590590
pset.execute(SampleP, runtime=0.1, dt=0.1)
591591

592592

593+
test_sampling_out_of_bounds_time(True)
594+
595+
593596
def test_sampling_3DCROCO():
594597
data_path = os.path.join(os.path.dirname(__file__), "test_data/")
595598
fieldset = FieldSet.from_modulefile(data_path + "fieldset_CROCO3D.py")

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def create_fieldset_unit_mesh(xdim=20, ydim=20, mesh="flat") -> FieldSet:
1717
"""Standard unit mesh fieldset with U and V equivalent to longitude and latitude."""
1818
lon = np.linspace(0.0, 1.0, xdim, dtype=np.float32)
1919
lat = np.linspace(0.0, 1.0, ydim, dtype=np.float32)
20-
U, V = np.meshgrid(lat, lon)
20+
V, U = np.meshgrid(lon, lat)
2121
data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32)}
2222
dimensions = {"lat": lat, "lon": lon}
2323
return FieldSet.from_data(data, dimensions, mesh=mesh)
@@ -55,7 +55,7 @@ def create_fieldset_global(xdim=200, ydim=100):
5555
"""Standard fieldset spanning the earth's coordinates with U and V equivalent to longitude and latitude in deg."""
5656
lon = np.linspace(-180, 180, xdim, dtype=np.float32)
5757
lat = np.linspace(-90, 90, ydim, dtype=np.float32)
58-
U, V = np.meshgrid(lat, lon)
58+
V, U = np.meshgrid(lon, lat)
5959
data = {"U": U, "V": V}
6060
dimensions = {"lon": lon, "lat": lat}
6161
return FieldSet.from_data(data, dimensions, mesh="flat")

0 commit comments

Comments
 (0)