Commit
add needed functions for generate_atmos - WIP
Zeitsperre committed Aug 5, 2024
1 parent 9164be9 commit 75a159e
Showing 4 changed files with 271 additions and 12 deletions.
41 changes: 41 additions & 0 deletions src/xsdba/testing.py
@@ -19,6 +19,7 @@
from scipy.stats import gamma
from xarray import open_dataset as _open_dataset

from xsdba.calendar import percentile_doy
from xsdba.utils import equally_spaced_nodes

__all__ = ["test_timelonlatseries", "test_timeseries"]
@@ -343,3 +344,43 @@ def nancov(X):
"""Drop observations with NaNs from Numpy's cov."""
X_na = np.isnan(X).any(axis=0)
return np.cov(X[:, ~X_na])


# XC (ported from xclim)
def generate_atmos(cache_dir: Path) -> dict[str, xr.DataArray]:
    """Create the `atmosds` synthetic testing dataset."""
    with open_dataset(
        "ERA5/daily_surface_cancities_1990-1993.nc",
        cache_dir=cache_dir,
        branch=TESTDATA_BRANCH,
        engine="h5netcdf",
    ) as ds:
        tn10 = percentile_doy(ds.tasmin, per=10)
        t10 = percentile_doy(ds.tas, per=10)
        t90 = percentile_doy(ds.tas, per=90)
        tx90 = percentile_doy(ds.tasmax, per=90)

        # rsus = shortwave_upwelling_radiation_from_net_downwelling(ds.rss, ds.rsds)
        # rlus = longwave_upwelling_radiation_from_net_downwelling(ds.rls, ds.rlds)

        ds = ds.assign(
            # rsus=rsus,
            # rlus=rlus,
            tn10=tn10,
            t10=t10,
            t90=t90,
            tx90=tx90,
        )

        # Create a file in the session-scoped temporary directory
        atmos_file = cache_dir.joinpath("atmosds.nc")
        ds.to_netcdf(atmos_file, engine="h5netcdf")

    # Give access to the dataset variables by name in the namespace
    namespace = {}
    with open_dataset(
        atmos_file, branch=TESTDATA_BRANCH, cache_dir=cache_dir, engine="h5netcdf"
    ) as ds:
        for variable in ds.data_vars:
            namespace[f"{variable}_dataset"] = ds.get(variable)
    return namespace
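
A minimal usage sketch for the new helper (the cache path here is hypothetical; in the test suite the session fixture passes its threadsafe_data_dir):

    from pathlib import Path

    namespace = generate_atmos(Path("/tmp/xsdba-testdata"))
    tn10 = namespace["tn10_dataset"]  # each data variable is exposed as "<name>_dataset"
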
217 changes: 216 additions & 1 deletion src/xsdba/utils.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@

from .base import Grouper, ensure_chunk_size, parse_group, uses_dask
from .calendar import ensure_longest_doy
from .nbutils import _extrapolate_on_quantiles
from .nbutils import _extrapolate_on_quantiles, _linear_interpolation

MULTIPLICATIVE = "*"
ADDITIVE = "+"
@@ -964,3 +964,218 @@ def load_module(path: os.PathLike, name: str | None = None):
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # This executes code, effectively loading the module
    return mod


# calc_perc and the helper functions needed for generate_atmos


# XC
def calc_perc(
    arr: np.ndarray,
    percentiles: Sequence[float] | None = None,
    alpha: float = 1.0,
    beta: float = 1.0,
    copy: bool = True,
) -> np.ndarray:
    """Compute percentiles using nan_calc_percentiles and move the percentiles' axis to the end."""
    if percentiles is None:
        _percentiles = [50.0]
    else:
        _percentiles = percentiles

    return np.moveaxis(
        nan_calc_percentiles(
            arr=arr,
            percentiles=_percentiles,
            axis=-1,
            alpha=alpha,
            beta=beta,
            copy=copy,
        ),
        source=0,
        destination=-1,
    )
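
A quick numeric sketch of the intended behaviour (the printed values are an assumption checked against NumPy: with alpha == beta == 1 the interpolation matches np.nanpercentile's default "linear" method):

    import numpy as np

    data = np.array([[1.0, 2.0, np.nan, 4.0]])       # shape (1, 4); the NaN is skipped
    out = calc_perc(data, percentiles=[10.0, 90.0])  # percentile axis moved to the end
    print(out.shape)  # (1, 2)
    print(out)        # [[1.2 3.6]], same as np.nanpercentile([1.0, 2.0, 4.0], [10, 90])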


# XC
def nan_calc_percentiles(
    arr: np.ndarray,
    percentiles: Sequence[float] | None = None,
    axis: int = -1,
    alpha: float = 1.0,
    beta: float = 1.0,
    copy: bool = True,
) -> np.ndarray:
    """Convert the percentiles to quantiles and compute them using _nan_quantile."""
    if percentiles is None:
        _percentiles = [50.0]
    else:
        _percentiles = percentiles

    if copy:
        # Bootstrapping already works on a copy of the data;
        # copying again is extremely costly, especially with dask.
        arr = arr.copy()
    quantiles = np.array([per / 100.0 for per in _percentiles])
    return _nan_quantile(arr, quantiles, axis, alpha, beta)
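
As a sanity check, the wrapper should agree with calling _nan_quantile directly on the converted quantiles (a sketch; copies are taken because _nan_quantile sorts its input in place):

    import numpy as np

    x = np.array([[1.0, 2.0, 3.0, 4.0]])
    a = nan_calc_percentiles(x.copy(), percentiles=[25.0, 75.0])
    b = _nan_quantile(x.copy(), np.array([0.25, 0.75]), axis=-1)
    assert np.allclose(a, b)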


# XC
def _nan_quantile(
    arr: np.ndarray,
    quantiles: np.ndarray,
    axis: int = 0,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> float | np.ndarray:
    """Get the quantiles of the array for the given axis.

    A linear interpolation is performed using alpha and beta.

    Notes
    -----
    By default, alpha == beta == 1, which performs the 7th method of :cite:t:`hyndman_sample_1996`.
    With alpha == beta == 1/3, we get the 8th method.
    """
    # --- Setup
    data_axis_length = arr.shape[axis]
    if data_axis_length == 0:
        return np.nan
    if data_axis_length == 1:
        result = np.take(arr, 0, axis=axis)
        return np.broadcast_to(result, (quantiles.size,) + result.shape)
    # The dimensions of `q` are prepended to the output shape, so we need the
    # axis being sampled from `arr` to be last.
    DATA_AXIS = 0
    if axis != DATA_AXIS:  # But moveaxis is slow, so only call it if axis != 0.
        arr = np.moveaxis(arr, axis, destination=DATA_AXIS)
    # nan_count is not a scalar
    nan_count = np.isnan(arr).sum(axis=DATA_AXIS).astype(float)
    valid_values_count = data_axis_length - nan_count
    # We need at least two values to do an interpolation
    too_few_values = valid_values_count < 2
    if too_few_values.any():
        # This will result in getting the only available value if it exists
        valid_values_count[too_few_values] = np.nan
    # --- Computation of indexes
    # Add axis for quantiles
    valid_values_count = valid_values_count[..., np.newaxis]
    virtual_indexes = _compute_virtual_index(valid_values_count, quantiles, alpha, beta)
    virtual_indexes = np.asanyarray(virtual_indexes)
    previous_indexes, next_indexes = _get_indexes(
        arr, virtual_indexes, valid_values_count
    )
    # --- Sorting
    arr.sort(axis=DATA_AXIS)
    # --- Get values from indexes
    arr = arr[..., np.newaxis]
    previous = np.squeeze(
        np.take_along_axis(arr, previous_indexes.astype(int)[np.newaxis, ...], axis=0),
        axis=0,
    )
    next_elements = np.squeeze(
        np.take_along_axis(arr, next_indexes.astype(int)[np.newaxis, ...], axis=0),
        axis=0,
    )
    # --- Linear interpolation
    gamma = _get_gamma(virtual_indexes, previous_indexes)
    interpolation = _linear_interpolation(previous, next_elements, gamma)
    # When an interpolation is in the NaN range (near the end of the sorted array),
    # it means we can clip to the array's max value.
    result = np.where(np.isnan(interpolation), np.nanmax(arr, axis=0), interpolation)
    # Move the quantile axis to the front
    result = np.moveaxis(result, axis, 0)
    return result
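
One consequence of the too_few_values branch above, sketched for illustration (the expected output is inferred from the code, including the final clip to np.nanmax, not from a test in this diff): a slice with a single valid value yields that value for every quantile.

    import numpy as np

    x = np.array([[1.0, np.nan, np.nan]])
    q = np.array([0.1, 0.5, 0.9])
    print(_nan_quantile(x.copy(), q, axis=-1))  # [[1.], [1.], [1.]] (quantile axis first; shape (3, 1))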


# XC
def _get_gamma(virtual_indexes: np.ndarray, previous_indexes: np.ndarray):
    """Compute gamma (a.k.a. 'm' or 'weight') for the linear interpolation of quantiles.

    Parameters
    ----------
    virtual_indexes : array_like
        The indexes where the percentile is supposed to be found in the sorted sample.
    previous_indexes : array_like
        The floor values of virtual_indexes.

    Notes
    -----
    `gamma` is usually the fractional part of virtual_indexes but can be modified by the interpolation method.
    """
    gamma = np.asanyarray(virtual_indexes - previous_indexes)
    return np.asanyarray(gamma)


# XC
def _compute_virtual_index(
    n: np.ndarray, quantiles: np.ndarray, alpha: float, beta: float
):
    """Compute the floating point indexes of an array for the linear interpolation of quantiles.

    Based on the approach used by :cite:t:`hyndman_sample_1996`.

    Parameters
    ----------
    n : array_like
        The sample sizes.
    quantiles : array_like
        The quantiles values.
    alpha : float
        A constant used to correct the index computed.
    beta : float
        A constant used to correct the index computed.

    Notes
    -----
    `alpha` and `beta` values depend on the chosen method (see quantile documentation).

    References
    ----------
    :cite:cts:`hyndman_sample_1996`
    """
    return n * quantiles + (alpha + quantiles * (1 - alpha - beta)) - 1
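
Substituting the two parameter choices mentioned in _nan_quantile's notes into this expression recovers the familiar zero-based Hyndman and Fan index formulas (a quick check, not part of the diff):

    h = n*q + alpha + q*(1 - alpha - beta) - 1

    alpha = beta = 1:    h = (n - 1)*q           (method 7, NumPy's default "linear")
    alpha = beta = 1/3:  h = (n + 1/3)*q - 2/3   (method 8)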


# XC
def _get_indexes(
    arr: np.ndarray, virtual_indexes: np.ndarray, valid_values_count: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Get the valid indexes of arr neighbouring virtual_indexes.

    Parameters
    ----------
    arr : array-like
    virtual_indexes : array-like
    valid_values_count : array-like

    Returns
    -------
    array-like, array-like
        A tuple of virtual_indexes neighbouring indexes (previous and next).

    Notes
    -----
    This is a companion function to the linear interpolation of quantiles.
    """
    previous_indexes = np.asanyarray(np.floor(virtual_indexes))
    next_indexes = np.asanyarray(previous_indexes + 1)
    indexes_above_bounds = virtual_indexes >= valid_values_count - 1
    # When an index is above the max index, take the max value of the array
    if indexes_above_bounds.any():
        previous_indexes[indexes_above_bounds] = -1
        next_indexes[indexes_above_bounds] = -1
    # When an index is below the min index, take the min value of the array
    indexes_below_bounds = virtual_indexes < 0
    if indexes_below_bounds.any():
        previous_indexes[indexes_below_bounds] = 0
        next_indexes[indexes_below_bounds] = 0
    if np.issubdtype(arr.dtype, np.inexact):
        # After the sort, slices having NaNs will have a NaN as their last element
        virtual_indexes_nans = np.isnan(virtual_indexes)
        if virtual_indexes_nans.any():
            previous_indexes[virtual_indexes_nans] = -1
            next_indexes[virtual_indexes_nans] = -1
    previous_indexes = previous_indexes.astype(np.intp)
    next_indexes = next_indexes.astype(np.intp)
    return previous_indexes, next_indexes
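
The _linear_interpolation helper imported from nbutils is not part of this diff; from its call site in _nan_quantile it presumably blends the two neighbouring sorted values by the fractional index. A sketch of the assumed behaviour (not the actual nbutils implementation):

    def _linear_interpolation(left: np.ndarray, right: np.ndarray, gamma: np.ndarray) -> np.ndarray:
        # Weighted blend: gamma == 0 returns `left`, gamma == 1 returns `right`.
        return left + gamma * (right - left)
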
24 changes: 13 additions & 11 deletions tests/conftest.py
@@ -20,7 +20,7 @@
# from filelock import FileLock
from packaging.version import Version

from xsdba.testing import TESTDATA_BRANCH
from xsdba.testing import TESTDATA_BRANCH, generate_atmos
from xsdba.testing import open_dataset as _open_dataset
from xsdba.testing import (
test_cannon_2015_dist,
@@ -317,17 +317,20 @@ def atmosds(threadsafe_data_dir) -> xr.Dataset:
# )


# @pytest.fixture(scope="session", autouse=True)
# def gather_session_data(threadsafe_data_dir, worker_id, xdoctest_namespace):
# """Gather testing data on pytest run.
@pytest.fixture(scope="session", autouse=True)
def gather_session_data(threadsafe_data_dir):
    """Gather testing data on pytest run.

    When running pytest with multiple workers, one worker will copy data remotely to _default_cache_dir while
    other workers wait using a lockfile. Once the lock is released, all workers will then copy data to their local
    threadsafe_data_dir. As this fixture is scoped to the session, it will only run once per pytest run.

    Additionally, this fixture is also used to generate the `atmosds` synthetic testing dataset as well as add the
    example file paths to the xdoctest_namespace, used when running doctests.
    """
    generate_atmos(threadsafe_data_dir)

# When running pytest with multiple workers, one worker will copy data remotely to _default_cache_dir while
# other workers wait using lockfile. Once the lock is released, all workers will then copy data to their local
# threadsafe_data_dir.As this fixture is scoped to the session, it will only run once per pytest run.

# Additionally, this fixture is also used to generate the `atmosds` synthetic testing dataset as well as add the
# example file paths to the xdoctest_namespace, used when running doctests.
# """
# if (
# not _default_cache_dir.joinpath(helpers.TESTDATA_BRANCH).exists()
# or helpers.PREFETCH_TESTING_DATA
@@ -353,7 +356,6 @@ def atmosds(threadsafe_data_dir) -> xr.Dataset:
# if lockfile.exists():
# lockfile.unlink()
# shutil.copytree(_default_cache_dir, threadsafe_data_dir)
# helpers.generate_atmos(threadsafe_data_dir)
# xdoctest_namespace.update(helpers.add_example_file_paths(threadsafe_data_dir))


1 change: 1 addition & 0 deletions tests/test_units.py
@@ -7,6 +7,7 @@
import pytest
import xarray as xr
from dask import array as dsk
from packaging.version import Version

from xsdba.logging import ValidationError
from xsdba.typing import Quantified
