Skip to content

Commit

Permalink
fix fixtures, add pytest-xdist
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeitsperre committed Sep 13, 2024
1 parent 42ff945 commit 3556b1c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 179 deletions.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies:
- typer >=0.12.3
- pytest <8.0.0
- pytest-cov >=5.0.0
- pytest-xdist >=3.2.0
- black ==24.8.0
- blackdoc ==0.3.9
- isort ==5.13.2
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ dev = [
"numpydoc >=1.8.0; python_version >='3.9'",
"pytest <8.0.0",
"pytest-cov >=5.0.0",
"pytest-xdist >=3.2.0",
"black ==24.8.0",
"blackdoc ==0.3.9",
"isort ==5.13.2",
Expand Down Expand Up @@ -300,10 +301,13 @@ override_SS05 = [

[tool.pytest.ini_options]
addopts = [
"--verbose",
"--color=yes"
"--color=yes",
"--numprocesses=0",
"--maxprocesses=8",
"--dist=worksteal"
]
filterwarnings = ["ignore::UserWarning"]
strict_markers = true
testpaths = "tests"
usefixtures = "xdoctest_namespace"

Expand Down
277 changes: 100 additions & 177 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,64 @@

import os
import re
import shutil
import sys
import time
import warnings
from datetime import datetime as dt
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
import xarray as xr

# from filelock import FileLock
from packaging.version import Version
from xclim.testing import helpers
from xclim.testing.utils import (
TESTDATA_BRANCH,
TESTDATA_CACHE_DIR,
TESTDATA_REPO_URL,
default_testdata_cache,
gather_testing_data,
)
from xclim.testing.utils import nimbus as _nimbus
from xclim.testing.utils import open_dataset as _open_dataset

from xsdba.testing import TESTDATA_BRANCH # , generate_atmos
from xsdba.testing import open_dataset as _open_dataset
import xsdba
from xsdba import __version__ as __xsdba_version__
from xsdba.testing import (
test_cannon_2015_dist,
test_cannon_2015_rvs,
test_timelonlatseries,
test_timeseries,
)
from xsdba.utils import apply_correction, equally_spaced_nodes

# import xclim
# from xclim import __version__ as __xclim_version__
# from xclim.core.calendar import max_doy
# from xclim.testing import helpers
# from xclim.testing.utils import _default_cache_dir
# from xclim.testing.utils import get_file
# from xclim.testing.utils import open_dataset as _open_dataset
from xsdba.utils import apply_correction

# ADAPT
# if (
# re.match(r"^\d+\.\d+\.\d+$", __xclim_version__)
# and helpers.TESTDATA_BRANCH == "main"
# ):
# # This does not need to be emitted on GitHub Workflows and ReadTheDocs
# if not os.getenv("CI") and not os.getenv("READTHEDOCS"):
# warnings.warn(
# f'`xclim` {__xclim_version__} is running tests against the "main" branch of `Ouranosinc/xclim-testdata`. '
# "It is possible that changes in xclim-testdata may be incompatible with test assertions in this version. "
# "Please be sure to check https://github.com/Ouranosinc/xclim-testdata for more information.",
# UserWarning,
# )
# Warn local developers when a tagged `xsdba` release is tested against the
# moving "main" branch of the testdata repository, since testdata changes may
# break assertions in a released version.
if re.match(r"^\d+\.\d+\.\d+$", __xsdba_version__) and TESTDATA_BRANCH == "main":
    # This does not need to be emitted on GitHub Workflows and ReadTheDocs
    if not os.getenv("CI") and not os.getenv("READTHEDOCS"):
        warnings.warn(
            # Bug fix: message previously named `xclim` while reporting the
            # `xsdba` version; the sibling warning below consistently says `xsdba`.
            f'`xsdba` {__xsdba_version__} is running tests against the "main" branch of `Ouranosinc/xclim-testdata`. '
            "It is possible that changes in xclim-testdata may be incompatible with test assertions in this version. "
            "Please be sure to check https://github.com/Ouranosinc/xclim-testdata for more information.",
            UserWarning,
        )

# if re.match(r"^v\d+\.\d+\.\d+", helpers.TESTDATA_BRANCH):
# # Find the date of last modification of xclim source files to generate a calendar version
# install_date = dt.strptime(
# time.ctime(os.path.getmtime(xclim.__file__)),
# "%a %b %d %H:%M:%S %Y",
# )
# install_calendar_version = (
# f"{install_date.year}.{install_date.month}.{install_date.day}"
# )

# if Version(helpers.TESTDATA_BRANCH) > Version(install_calendar_version):
# warnings.warn(
# f"Installation date of `xclim` ({install_date.ctime()}) "
# f"predates the last release of `xclim-testdata` ({helpers.TESTDATA_BRANCH}). "
# "It is very likely that the testing data is incompatible with this build of `xclim`.",
# UserWarning,
# )
# Derive a calendar-style version (YYYY.M.D) from the last modification time of
# the installed `xsdba` sources, and warn when the pinned testdata branch tag is
# newer than that install date (data likely incompatible with this build).
if re.match(r"^v\d+\.\d+\.\d+", TESTDATA_BRANCH):
    _mtime = Path(xsdba.__file__).stat().st_mtime
    install_date = dt.strptime(time.ctime(_mtime), "%a %b %d %H:%M:%S %Y")
    install_calendar_version = ".".join(
        str(part)
        for part in (install_date.year, install_date.month, install_date.day)
    )

    if Version(TESTDATA_BRANCH) > Version(install_calendar_version):
        warnings.warn(
            f"Installation date of `xsdba` ({install_date.ctime()}) "
            f"predates the last release of `xclim-testdata` ({TESTDATA_BRANCH}). "
            "It is very likely that the testing data is incompatible with this build of `xsdba`.",
            UserWarning,
        )


@pytest.fixture
Expand Down Expand Up @@ -121,30 +114,6 @@ def random() -> np.random.Generator:
return np.random.default_rng(seed=list(map(ord, "𝕽𝔞𝖓𝔡𝖔𝔪")))


# ADAPT
# @pytest.fixture
# def tmp_netcdf_filename(tmpdir) -> Path:
# yield Path(tmpdir).joinpath("testfile.nc")


@pytest.fixture(autouse=True, scope="session")
def threadsafe_data_dir(tmp_path_factory) -> Path:
    """Yield a session-wide scratch data directory under pytest's base temp dir.

    Each pytest-xdist worker resolves its own copy, keeping data access safe
    across parallel workers.
    """
    base = tmp_path_factory.getbasetemp()
    yield base / "data"


@pytest.fixture(scope="session")
def open_dataset(threadsafe_data_dir):
    """Return a session-scoped opener that caches test files in the worker's data dir."""

    def _open_session_scoped_file(
        file: str | os.PathLike, branch: str = TESTDATA_BRANCH, **xr_kwargs
    ):
        # Prefer the h5netcdf engine unless the caller chose one explicitly.
        if "engine" not in xr_kwargs:
            xr_kwargs["engine"] = "h5netcdf"
        return _open_dataset(
            file, branch=branch, cache_dir=threadsafe_data_dir, **xr_kwargs
        )

    return _open_session_scoped_file


# XC
@pytest.fixture
def mon_triangular():
Expand Down Expand Up @@ -216,20 +185,6 @@ def areacella() -> xr.DataArray:
areacello = areacella


# ADAPT?
# @pytest.fixture(scope="session")
# def open_dataset(threadsafe_data_dir):
# def _open_session_scoped_file(
# file: str | os.PathLike, branch: str = helpers.TESTDATA_BRANCH, **xr_kwargs
# ):
# xr_kwargs.setdefault("engine", "h5netcdf")
# return _open_dataset(
# file, cache_dir=threadsafe_data_dir, branch=branch, **xr_kwargs
# )

# return _open_session_scoped_file


# ADAPT?
# @pytest.fixture(autouse=True, scope="session")
# def add_imports(xdoctest_namespace, threadsafe_data_dir) -> None:
Expand Down Expand Up @@ -277,103 +232,71 @@ def atmosds(threadsafe_data_dir) -> xr.Dataset:
).load()


# @pytest.fixture(scope="function")
# def ensemble_dataset_objects() -> dict:
# edo = dict()
# edo["nc_files_simple"] = [
# "EnsembleStats/BCCAQv2+ANUSPLIN300_ACCESS1-0_historical+rcp45_r1i1p1_1950-2100_tg_mean_YS.nc",
# "EnsembleStats/BCCAQv2+ANUSPLIN300_BNU-ESM_historical+rcp45_r1i1p1_1950-2100_tg_mean_YS.nc",
# "EnsembleStats/BCCAQv2+ANUSPLIN300_CCSM4_historical+rcp45_r1i1p1_1950-2100_tg_mean_YS.nc",
# "EnsembleStats/BCCAQv2+ANUSPLIN300_CCSM4_historical+rcp45_r2i1p1_1950-2100_tg_mean_YS.nc",
# ]
# edo["nc_files_extra"] = [
# "EnsembleStats/BCCAQv2+ANUSPLIN300_CNRM-CM5_historical+rcp45_r1i1p1_1970-2050_tg_mean_YS.nc"
# ]
# edo["nc_files"] = edo["nc_files_simple"] + edo["nc_files_extra"]
# return edo


# @pytest.fixture(scope="session")
# def lafferty_sriver_ds() -> xr.Dataset:
# """Get data from Lafferty & Sriver unit test.

# Notes
# -----
# https://github.com/david0811/lafferty-sriver_2023_npjCliAtm/tree/main/unit_test
# """
# fn = get_file(
# "uncertainty_partitioning/seattle_avg_tas.csv",
# cache_dir=_default_cache_dir,
# branch=helpers.TESTDATA_BRANCH,
# )

# df = pd.read_csv(fn, parse_dates=["time"]).rename(
# columns={"ssp": "scenario", "ensemble": "downscaling"}
# )
@pytest.fixture(scope="session")
def threadsafe_data_dir(tmp_path_factory):
    """Return the per-session data directory below pytest's base temporary directory."""
    # getbasetemp() already yields a Path, so the `/` operator suffices.
    return tmp_path_factory.getbasetemp() / "data"

# # Make xarray dataset
# return xr.Dataset.from_dataframe(
# df.set_index(["scenario", "model", "downscaling", "time"])
# )

# ADAPT or REMOVE?
# @pytest.fixture(scope="session", autouse=True)
# def gather_session_data(threadsafe_data_dir):
# """Gather testing data on pytest run.

# When running pytest with multiple workers, one worker will copy data remotely to _default_cache_dir while
# other workers wait using lockfile. Once the lock is released, all workers will then copy data to their local
# threadsafe_data_dir. As this fixture is scoped to the session, it will only run once per pytest run.

# Additionally, this fixture is also used to generate the `atmosds` synthetic testing dataset as well as add the
# example file paths to the xdoctest_namespace, used when running doctests.
# """
# generate_atmos(threadsafe_data_dir)


# if (
# not _default_cache_dir.joinpath(helpers.TESTDATA_BRANCH).exists()
# or helpers.PREFETCH_TESTING_DATA
# ):
# if helpers.PREFETCH_TESTING_DATA:
# print("`XCLIM_PREFETCH_TESTING_DATA` set. Prefetching testing data...")
# if sys.platform == "win32":
# raise OSError(
# "UNIX-style file-locking is not supported on Windows. "
# "Consider running `$ xclim prefetch_testing_data` to download testing data."
# )
# elif worker_id in ["master"]:
# helpers.populate_testing_data(branch=helpers.TESTDATA_BRANCH)
# else:
# _default_cache_dir.mkdir(exist_ok=True, parents=True)
# lockfile = _default_cache_dir.joinpath(".lock")
# test_data_being_written = FileLock(lockfile)
# with test_data_being_written:
# # This flag prevents multiple calls from re-attempting to download testing data in the same pytest run
# helpers.populate_testing_data(branch=helpers.TESTDATA_BRANCH)
# _default_cache_dir.joinpath(".data_written").touch()
# with test_data_being_written.acquire():
# if lockfile.exists():
# lockfile.unlink()
# shutil.copytree(_default_cache_dir, threadsafe_data_dir)
# xdoctest_namespace.update(helpers.add_example_file_paths(threadsafe_data_dir))


# @pytest.fixture(scope="session", autouse=True)
# def cleanup(request):
# """Cleanup a testing file once we are finished.

# This flag prevents remote data from being downloaded multiple times in the same pytest run.
# """

# def remove_data_written_flag():
# flag = _default_cache_dir.joinpath(".data_written")
# if flag.exists():
# flag.unlink()

# request.addfinalizer(remove_data_written_flag)
@pytest.fixture(scope="session")
def nimbus(threadsafe_data_dir, worker_id):
    """Return a testing-data fetcher for this worker.

    The master (non-xdist) run uses the shared default cache directory, while
    each xdist worker caches into its own thread-safe directory.
    """
    if worker_id == "master":
        cache = TESTDATA_CACHE_DIR
    else:
        cache = threadsafe_data_dir
    return _nimbus(repo=TESTDATA_REPO_URL, branch=TESTDATA_BRANCH, cache_dir=cache)


@pytest.fixture
def tmp_netcdf_filename(tmpdir) -> Path:
    """Yield a per-test path for a scratch NetCDF file inside pytest's tmpdir."""
    target = Path(tmpdir) / "testfile.nc"
    yield target


@pytest.fixture(scope="session")
def open_dataset(threadsafe_data_dir):
    """Return a session-scoped `open_dataset` bound to the worker's cache directory."""

    def _open_session_scoped_file(
        file: str | os.PathLike, branch: str = helpers.TESTDATA_BRANCH, **xr_kwargs
    ):
        # Merge in the default engine; caller-supplied kwargs win.
        xr_kwargs = {"engine": "h5netcdf", **xr_kwargs}
        return _open_dataset(
            file, cache_dir=threadsafe_data_dir, branch=branch, **xr_kwargs
        )

    return _open_session_scoped_file


@pytest.fixture(autouse=True, scope="session")
def gather_session_data(request, nimbus, worker_id):
    """Fetch the testing data once per pytest session.

    Under pytest-xdist, one worker copies data into the default cache while the
    others wait on a lockfile; once the lock is released, every worker copies the
    data into its own thread-safe directory. Windows lacks UNIX sockets, so the
    lockfile mechanism is unavailable there: users must run
    ``$ xclim prefetch_testing_data`` once beforehand to populate the default cache.

    NOTE(review): an earlier docstring claimed this fixture also generates the
    synthetic ``atmosds`` dataset, but only ``gather_testing_data`` is called
    here — confirm whether that generation now happens elsewhere.
    """
    gather_testing_data(worker_cache_dir=nimbus.path, worker_id=worker_id)


@pytest.fixture(scope="session", autouse=True)
def cleanup(request):
    """Remove the `.data_written` marker file when the session ends.

    The marker prevents remote data from being downloaded multiple times within
    a single pytest run; clearing it lets the next run re-fetch as needed.
    """

    def remove_data_written_flag():
        flag = default_testdata_cache / ".data_written"
        if flag.exists():
            flag.unlink()

    request.addfinalizer(remove_data_written_flag)


def timeseries():
    # Expose the `test_timeseries` helper under a shorter module-level name.
    # NOTE(review): this returns the helper function itself (callers must then
    # invoke it) and is not decorated as a pytest fixture — confirm this
    # indirection is intended rather than a leftover from the fixture refactor.
    return test_timeseries

0 comments on commit 3556b1c

Please sign in to comment.