diff --git a/docs/examples/example_mitgcm.py b/docs/examples/example_mitgcm.py index 6a7542fe56..f0f6a62c6c 100644 --- a/docs/examples/example_mitgcm.py +++ b/docs/examples/example_mitgcm.py @@ -1,4 +1,6 @@ from datetime import timedelta +from pathlib import Path +from typing import Literal import numpy as np import parcels @@ -7,7 +9,7 @@ ptype = {"scipy": parcels.ScipyParticle, "jit": parcels.JITParticle} -def run_mitgcm_zonally_reentrant(mode): +def run_mitgcm_zonally_reentrant(mode: Literal["scipy", "jit"], path: Path): """Function that shows how to load MITgcm data in a zonally periodic domain.""" data_folder = parcels.download_example_dataset("MITgcm_example_data") filenames = { @@ -41,7 +43,7 @@ def periodicBC(particle, fieldset, time): size=10, ) pfile = parcels.ParticleFile( - "MIT_particles_" + str(mode) + ".zarr", + str(path), pset, outputdt=timedelta(days=1), chunks=(len(pset), 1), @@ -52,12 +54,15 @@ def periodicBC(particle, fieldset, time): ) -def test_mitgcm_output_compare(): - run_mitgcm_zonally_reentrant("scipy") - run_mitgcm_zonally_reentrant("jit") +def test_mitgcm_output_compare(tmpdir): + def get_path(mode: Literal["scipy", "jit"]) -> Path: + return tmpdir / f"MIT_particles_{mode}.zarr" - ds_jit = xr.open_zarr("MIT_particles_jit.zarr") - ds_scipy = xr.open_zarr("MIT_particles_scipy.zarr") + for mode in ["scipy", "jit"]: + run_mitgcm_zonally_reentrant(mode, get_path(mode)) + + ds_jit = xr.open_zarr(get_path("jit")) + ds_scipy = xr.open_zarr(get_path("scipy")) np.testing.assert_allclose(ds_jit.lat.data, ds_scipy.lat.data) np.testing.assert_allclose(ds_jit.lon.data, ds_scipy.lon.data) diff --git a/tests/test_examples.py b/tests/test_examples.py index 30bbd58355..de42fbb86f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,5 +1,8 @@ +import os import runpy +import shutil import sys +import time from pathlib import Path import pytest @@ -8,6 +11,33 @@ example_fnames = [path.name for path in example_folder.glob("*.py")] +@pytest.fixture(autouse=True) +def cleanup_generated_data_files(): + """Clean up generated data files from test run. + + Records current folder contents before test, and cleans up any generated `.nc` files + and `.zarr` folders afterwards. For safety this is non-recursive. This function is + only necessary as the scripts being run aren't native pytest tests, so they don't + have access to the `tmpdir` fixture. + + """ + folder_contents = os.listdir() + yield + time.sleep(0.1) # Buffer so that files are closed before we try to delete them. + for fname in os.listdir(): + if fname in folder_contents: + continue + if not (fname.endswith(".nc") or fname.endswith(".zarr")): + continue + + path = Path(fname) + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + print(f"Removed {path}") + + @pytest.mark.parametrize("example_fname", example_fnames) def test_example_script(example_fname): script = str(example_folder / example_fname)