Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions docs/examples/example_mitgcm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import timedelta
from pathlib import Path
from typing import Literal

import numpy as np
import parcels
Expand All @@ -7,7 +9,7 @@
ptype = {"scipy": parcels.ScipyParticle, "jit": parcels.JITParticle}


def run_mitgcm_zonally_reentrant(mode):
def run_mitgcm_zonally_reentrant(mode: Literal["scipy", "jit"], path: Path):
"""Function that shows how to load MITgcm data in a zonally periodic domain."""
data_folder = parcels.download_example_dataset("MITgcm_example_data")
filenames = {
Expand Down Expand Up @@ -41,7 +43,7 @@ def periodicBC(particle, fieldset, time):
size=10,
)
pfile = parcels.ParticleFile(
"MIT_particles_" + str(mode) + ".zarr",
str(path),
pset,
outputdt=timedelta(days=1),
chunks=(len(pset), 1),
Expand All @@ -52,12 +54,15 @@ def periodicBC(particle, fieldset, time):
)


def test_mitgcm_output_compare():
run_mitgcm_zonally_reentrant("scipy")
run_mitgcm_zonally_reentrant("jit")
def test_mitgcm_output_compare(tmpdir):
def get_path(mode: Literal["scipy", "jit"]) -> Path:
return tmpdir / f"MIT_particles_{mode}.zarr"

ds_jit = xr.open_zarr("MIT_particles_jit.zarr")
ds_scipy = xr.open_zarr("MIT_particles_scipy.zarr")
for mode in ["scipy", "jit"]:
run_mitgcm_zonally_reentrant(mode, get_path(mode))

ds_jit = xr.open_zarr(get_path("jit"))
ds_scipy = xr.open_zarr(get_path("scipy"))

np.testing.assert_allclose(ds_jit.lat.data, ds_scipy.lat.data)
np.testing.assert_allclose(ds_jit.lon.data, ds_scipy.lon.data)
30 changes: 30 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import runpy
import shutil
import sys
import time
from pathlib import Path

import pytest
Expand All @@ -8,6 +11,33 @@
example_fnames = [path.name for path in example_folder.glob("*.py")]


@pytest.fixture(autouse=True)
def cleanup_generated_data_files():
"""Clean up generated data files from test run.
Records current folder contents before test, and cleans up any generated `.nc` files
and `.zarr` folders afterwards. For safety this is non-recursive. This function is
only necessary as the scripts being run aren't native pytest tests, so they don't
have access to the `tmpdir` fixture.
"""
folder_contents = os.listdir()
yield
time.sleep(0.1) # Buffer so that files are closed before we try to delete them.
for fname in os.listdir():
if fname in folder_contents:
continue
if not (fname.endswith(".nc") or fname.endswith(".zarr")):
continue

Check warning on line 31 in tests/test_examples.py

View check run for this annotation

Codecov / codecov/patch

tests/test_examples.py#L31

Added line #L31 was not covered by tests

path = Path(fname)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
print(f"Removed {path}")


@pytest.mark.parametrize("example_fname", example_fnames)
def test_example_script(example_fname):
script = str(example_folder / example_fname)
Expand Down
Loading