diff --git a/docs/examples/index.rst b/docs/examples/index.rst index 9e60465..91f452e 100644 --- a/docs/examples/index.rst +++ b/docs/examples/index.rst @@ -10,3 +10,4 @@ Here is a list of example notebooks to illustrate how to use earthkit-data. :maxdepth: 1 return_period.ipynb + seven_weather_regimes.ipynb diff --git a/docs/examples/seven_weather_regimes.ipynb b/docs/examples/seven_weather_regimes.ipynb new file mode 100644 index 0000000..291b17a --- /dev/null +++ b/docs/examples/seven_weather_regimes.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9485a729", + "metadata": {}, + "source": [ + "# Year-round North Atlantic-European Weather Regimes\n", + "\n", + "Following the example in Part 4 of [github.com/cmgrams/wr_data_package_era5](https://github.com/cmgrams/wr_data_package_era5/blob/main/wr_data_package_V1.0/scripts_first_steps/WR_read_example.ipynb), compute the regime index for the seven-regime classification of [Grams (2026, in review)](https://doi.org/10.5194/egusphere-2025-6385)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "38c63573", + "metadata": {}, + "outputs": [], + "source": [ + "import pooch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "from earthkit.meteo import regimes" + ] + }, + { + "cell_type": "markdown", + "id": "1fb3f386", + "metadata": {}, + "source": [ + "## Get the regime classification data\n", + "\n", + "Retrieve data from https://zenodo.org/records/18154492." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "94c5d0a6", + "metadata": {}, + "outputs": [], + "source": [ + "files = pooch.retrieve(\n", + " url=\"doi:10.5281/zenodo.18154492/wr_data_package_V1.1.zip\",\n", + " known_hash=\"dc942ff2a1b3da6dedd3b0b2fadda017fff9e8fc10228ace31c2209a9be7dc62\",\n", + " processor=pooch.Unzip()\n", + ")\n", + "\n", + "def get_file(name):\n", + " for file in files:\n", + " if file.endswith(name):\n", + " return file\n", + " raise FileNotFoundError(name)" + ] + }, + { + "cell_type": "markdown", + "id": "0c18029b", + "metadata": {}, + "source": [ + "## Regime pattern normalisation based on day-of-year\n", + "\n", + "Load the file with normalisation weights from the repository." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f4fdd30f", + "metadata": {}, + "outputs": [], + "source": [ + "mod_ds = xr.open_dataset(get_file(\"wr_data/EOFs_WRs.nc\"))" + ] + }, + { + "cell_type": "markdown", + "id": "b73380a3", + "metadata": {}, + "source": [ + "Create a lookup table for the normalisation weights and define a function that finds the appropriate weight for a given date." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8e820f2d", + "metadata": {}, + "outputs": [], + "source": [ + "mod = 1. / mod_ds[\"normwgt\"].to_series()\n", + "mod.index = pd.Index(zip(mod.index.month, mod.index.day, mod.index.hour))\n", + "\n", + "def pattern_normalisation_weight(date):\n", + " date = np.asarray(date)\n", + " shp = date.shape\n", + " date = pd.to_datetime(date.flatten())\n", + " if isinstance(date, pd.Timestamp):\n", + " return mod.loc[(date.month, date.day, date.hour)]\n", + " else:\n", + " idx = list(zip(date.month, date.day, date.hour))\n", + " return mod.loc[idx].values.reshape(shp)" + ] + }, + { + "cell_type": "markdown", + "id": "651d73e0", + "metadata": {}, + "source": [ + "## Load the regime patterns\n", + "\n", + "Patterns are stored in a NetCDF file in the repository." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "418019a3", + "metadata": {}, + "outputs": [], + "source": [ + "pattern_ds = xr.open_dataset(get_file(\"wr_data/Normed_Z0500-patterns_EOFdomain.nc\"))" + ] + }, + { + "cell_type": "markdown", + "id": "fdba08fc", + "metadata": {}, + "source": [ + "Seven base patterns. The normalisation of projections is implemented via a modulation of the regime patterns, i.e., we vary the amplitude of the patterns." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "14ab5c58", + "metadata": {}, + "outputs": [], + "source": [ + "patterns = regimes.ModulatedRegimePatterns(\n", + " regimes=pattern_ds.attrs[\"ClassNames\"].split(),\n", + " grid={\n", + " \"grid\": [0.5, 0.5],\n", + " \"area\": [90, -80, 30, 40] # 30-90°N, 80°W-40°E\n", + " },\n", + " patterns=pattern_ds[\"Z0500_mean\"].values,\n", + " modulator=pattern_normalisation_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "69a9f27d", + "metadata": {}, + "source": [ + "**Note:** the regime data from this repository uses ascending order for the latitude coordinate. The latitude ordering of fields projected onto these patterns must match." + ] + }, + { + "cell_type": "markdown", + "id": "18c7b6ca", + "metadata": {}, + "source": [ + "## Project the test field onto the patterns" + ] + }, + { + "cell_type": "markdown", + "id": "8a2a0526", + "metadata": {}, + "source": [ + "Load the test field provided with the dataset. This field contains pre-processed Z500 anomalies ready for projection." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f0b8eeba", + "metadata": {}, + "outputs": [], + "source": [ + "ds_field = xr.open_dataset(get_file(\"wr_data_package_V1.1/example_data/Z0500_20250601_00.nc\")).squeeze()\n", + "\n", + "# Extract values for the EUR-ATL region and create associated coordinates\n", + "# following the example in the reference notebook\n", + "field = ds_field[\"Z0\"].values[240:361,200:441]" + ] + }, + { + "cell_type": "markdown", + "id": "a491d706", + "metadata": {}, + "source": [ + "Create area-based weights for the grid points to use in the projection." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6b0a13c0", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(ds_field.attrs[\"domymin\"], ds_field.attrs[\"domymax\"], ds_field[\"Z0\"].shape[0])\n", + "\n", + "weights = np.cos(np.deg2rad(lat))\n", + "weights = np.repeat(weights[240:361], field.shape[1]).reshape(field.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "ea43918a", + "metadata": {}, + "source": [ + "Project onto the regime patterns, supply the valid date of the field for the modulator function to select the normalisation weight." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d016fba8", + "metadata": {}, + "outputs": [], + "source": [ + "projections = regimes.project(field, patterns, weights, date=np.datetime64(\"2025-06-01 00:00\"))" + ] + }, + { + "cell_type": "markdown", + "id": "9b316d2c", + "metadata": {}, + "source": [ + "## Compute the regime index\n", + "\n", + "Standardise the projections to obtain the regime index.\n", + "Reference values are given in the notebook:\n", + "\n", + " weather regime indices for 20250601_00\n", + " 0.8218320408050166 AT\n", + " 1.1691867614026132 ZO\n", + " 1.0489664646400954 ScTr\n", + " -0.6948883613466961 AR\n", + " -0.5738999234610754 EuBL\n", + " -0.9323144170607468 ScBL\n", + " -0.6369122243737739 GL\n", + "\n", + "Read the standardisation parameters (mean and standard deviation) from the text file in the repo." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "342bdb6e", + "metadata": {}, + "outputs": [], + "source": [ + "def read_wri_std_parameters(path):\n", + " f.readline()\n", + " name = f.readline().strip().split()\n", + " mean = f.readline().strip().split()[1:]\n", + " std = f.readline().strip().split()[1:]\n", + " return (\n", + " {n: float(v) for n, v in zip(name, mean)},\n", + " {n: float(v) for n, v in zip(name, std)}\n", + " )\n", + " \n", + "with open(get_file(\"wr_data/WRI_std_params.txt\"), \"r\") as f:\n", + " norm_mean, norm_std = read_wri_std_parameters(f)" + ] + }, + { + "cell_type": "markdown", + "id": "9485431d", + "metadata": {}, + "source": [ + "Compute the regime index." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "913f8d12", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'AT': np.float32(0.8218321),\n", + " 'ZO': np.float32(1.169187),\n", + " 'ScTr': np.float32(1.0489665),\n", + " 'AR': np.float32(-0.6948884),\n", + " 'EuBL': np.float32(-0.5738999),\n", + " 'ScBL': np.float32(-0.9323146),\n", + " 'GL': np.float32(-0.6369122)}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "regimes.standardise(projections, norm_mean, norm_std)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "earthkit-meteo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/earthkit/meteo/regimes/__init__.py b/src/earthkit/meteo/regimes/__init__.py new file mode 100644 index 0000000..a53d8ec --- /dev/null +++ b/src/earthkit/meteo/regimes/__init__.py @@ -0,0 +1,13 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from .index import project +from .index import standardise +from .patterns import ConstantRegimePatterns +from .patterns import ModulatedRegimePatterns +from .patterns import RegimePatterns diff --git a/src/earthkit/meteo/regimes/array/__init__.py b/src/earthkit/meteo/regimes/array/__init__.py new file mode 100644 index 0000000..bf7b65a --- /dev/null +++ b/src/earthkit/meteo/regimes/array/__init__.py @@ -0,0 +1,9 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from .index import * diff --git a/src/earthkit/meteo/regimes/array/index.py b/src/earthkit/meteo/regimes/array/index.py new file mode 100644 index 0000000..39b553d --- /dev/null +++ b/src/earthkit/meteo/regimes/array/index.py @@ -0,0 +1,68 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +def project(field, patterns, weights, **patterns_kwargs): + """Project onto the given regime patterns. + + Parameters + ---------- + field : array_like + Input field(s) to project. + patterns : earthkit.meteo.regimes.RegimePatterns + Regime patterns. + weights : array_like + Weights for the summation over the spatial dimensions. + **patterns_kwargs : dict[str, Any], optional + Keyword arguments for the pattern generation. E.g., a sequence of + dates for date-modulated regime patterns. + + Returns + ------- + dict[str, array_like] + Results of the projection for each regime. + """ + ndim_field = len(patterns.shape) + if field.shape[-ndim_field:] != patterns.shape: + raise ValueError( + f"shape of input fields {field.shape} incompatible with shape of regime patterns {patterns.shape}" + ) + + ps = patterns.patterns(**patterns_kwargs) + + if weights is None: + # TODO generate area-based weights from grid of patterns with earthkit-geo + # TODO make weights an optional argument with None default and document + raise NotImplementedError("automatic generation of weights") + if weights.shape != patterns.shape: + raise ValueError(f"shape of weights {weights.shape} must match shape of patterns {patterns.shape}") + weights = weights / weights.sum() + + # Project onto each regime pattern + sum_axes = tuple(range(-ndim_field, 0, 1)) + return {regime: (field * pattern * weights).sum(axis=sum_axes) for regime, pattern in ps.items()} + + +def standardise(projections, mean, std): + """Regime index by standardisation of regime projections. + + Convenience function to work with dictionaries. + + Parameters + ---------- + projections : dict[str, array_like] + Projections onto regime patterns. + mean : dict[str, array_like] + std : dict[str, array_like] + + Returns + ------- + dict[str, array_like] + ``(projection - mean) / std`` for each regime + """ + return {regime: (proj - mean[regime]) / std[regime] for regime, proj in projections.items()} diff --git a/src/earthkit/meteo/regimes/index.py b/src/earthkit/meteo/regimes/index.py new file mode 100644 index 0000000..13dd9d6 --- /dev/null +++ b/src/earthkit/meteo/regimes/index.py @@ -0,0 +1,17 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from . import array + + +def project(*args, **kwargs): + return array.project(*args, **kwargs) + + +def standardise(*args, **kwargs): + return array.standardise(*args, **kwargs) diff --git a/src/earthkit/meteo/regimes/patterns.py b/src/earthkit/meteo/regimes/patterns.py new file mode 100644 index 0000000..d149c3a --- /dev/null +++ b/src/earthkit/meteo/regimes/patterns.py @@ -0,0 +1,138 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import abc + +from earthkit.utils.array import array_namespace + + +class RegimePatterns(abc.ABC): + """Collection of weather regime patterns. + + Parameters + ---------- + regimes : Iterable[str] + Names of the regimes. The ordering here determines the ordering of + regimes in all outputs. + grid : dict + The grid on which the regime patterns live. + """ + + def __init__(self, regimes, grid): + self._regimes = tuple(regimes) + self._grid = grid + + @property + def regimes(self): + """Names of the regime patterns.""" + return self._regimes + + @property + def grid(self): + """Grid specification of the regime patterns.""" + return self._grid + + @property + def shape(self): + """Shape of the regime patterns.""" + # TODO placeholder until this functionality is available from earthkit-geo + lat0, lon0, lat1, lon1 = self.grid["area"] + dlat, dlon = self.grid["grid"] + return (int(abs(lat0 - lat1) / dlat) + 1, int(abs(lon0 - lon1) / dlon) + 1) + + @abc.abstractmethod + def patterns(self, **kwargs) -> dict: + """Patterns for all regimes.""" + + def __repr__(self): + return f"{self.__class__.__name__}{self.regimes}" + + +class ConstantRegimePatterns(RegimePatterns): + """Constant regime patterns. + + Parameters + ---------- + regimes : Iterable[str] + Regime labels. + grid : dict + Grid specification of the patterns. + patterns : array_like + Regime patterns. + """ + + def __init__(self, regimes, grid, patterns): + super().__init__(regimes, grid) + self._xp = array_namespace(patterns) + self._patterns = self._xp.asarray(patterns) + if self._patterns.ndim != 1 + len(self.shape): + raise ValueError("must have exactly one regime dimension in the patterns") + if len(self.regimes) != self._patterns.shape[0]: + raise ValueError("number of regimes does not match number of patterns") + + def patterns(self): + """Regime patterns. + + Returns + ------- + dict[str, array_like] + Regime patterns. + """ + return dict(zip(self._regimes, self._patterns)) + + +class ModulatedRegimePatterns(RegimePatterns): + """Regime patterns modulated by a custom scalar function. + + Parameters + ---------- + regimes : Iterable[str] + Regime labels. + grid : dict + Grid specification of the patterns. + patterns : array_like + Base regime patterns. + modulator : Callable[Any, array_like] + Scalar function to modulate the base patterns. + """ + + def __init__(self, regimes, grid, patterns, modulator): + super().__init__(regimes, grid) + self._xp = array_namespace(patterns) + self._base_patterns = self._xp.asarray(patterns) + # Pattern verification + if self._base_patterns.ndim != 1 + len(self.shape): + raise ValueError("must have exactly one regime dimension in the patterns") + if len(self.regimes) != self._base_patterns.shape[0]: + raise ValueError("number of regimes does not match number of patterns") + self.modulator = modulator + if not callable(self.modulator): + raise ValueError("modulator must be callable") + + @property + def _base_patterns_dict(self): + return dict(zip(self._regimes, self._base_patterns)) + + def patterns(self, **kwargs): + """Regime patterns for a given input to the modulator function. + + Parameters + ---------- + **kwargs : dict[str, Any], optional + Keyword arguments for the modulator function. + + Returns + ------- + dict[str, array_like] + Modulated regime patterns. + """ + xp = self._xp + modulator = xp.asarray(self.modulator(**kwargs)) + # Adapt to shape of regime patterns + modulator = modulator[(..., *((xp.newaxis,) * len(self.shape)))] + return {regime: modulator * base_pattern for regime, base_pattern in self._base_patterns_dict.items()} diff --git a/tests/regimes/test_index.py b/tests/regimes/test_index.py new file mode 100644 index 0000000..cfd9f10 --- /dev/null +++ b/tests/regimes/test_index.py @@ -0,0 +1,124 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np +import pytest + +from earthkit.meteo import regimes + + +@pytest.fixture +def patterns(): + class MockRegimePatterns: + + _lat = np.linspace(90.0, 0.0, 91) + _lon = np.linspace(60.0, -60.0, 121) + _dipole = np.cos(np.deg2rad(_lon[None, :])) * np.cos(np.deg2rad(_lat[:, None]) * 2) + _monopole = np.cos(np.deg2rad(_lon[None, :])) * np.sin(np.deg2rad(_lat[:, None]) * 2) + shape = (91, 121) + grid = {"grid": [1.0, 1.0], "area": [max(_lat), min(_lon), min(_lat), max(_lon)]} + + def patterns(self, single=True): + return { + "dipole": self._dipole if single else np.stack([self._dipole, 2 * self._dipole]), + "monopole": self._monopole if single else np.stack([self._monopole, 2 * self._monopole]), + "dipole_inv": -self._dipole if single else np.stack([-self._dipole, -2 * self._dipole]), + } + + return MockRegimePatterns() + + +def test_project_matches_field_and_pattern_shapes(patterns): + with pytest.raises(ValueError): + regimes.project(np.zeros((91 * 121,)), patterns, weights=None) + with pytest.raises(ValueError): + regimes.project(np.zeros((20, 30)), patterns, weights=None) + with pytest.raises(ValueError): + regimes.project(np.zeros((91, 2, 3)), patterns, weights=None) + + +def test_project_matches_weights_and_pattern_shapes(patterns): + with pytest.raises(ValueError): + regimes.project(np.ones(patterns.shape), patterns, weights=np.ones((20, 30))) + + +def test_project_ones_with_uniform_weights(patterns): + proj = regimes.project(np.ones(patterns.shape), patterns, weights=np.ones(patterns.shape)) + assert np.isclose(proj["dipole"], np.mean(patterns._dipole)) + assert np.isclose(proj["monopole"], np.mean(patterns._monopole)) + assert np.isclose(proj["dipole_inv"], np.mean(patterns._dipole)) + # Pattern symmetry + assert np.isclose(proj["dipole"], -proj["dipole_inv"]) + + +def test_project_ones_with_coslat_weights(patterns): + lat_2d = np.repeat(patterns._lat, patterns._lon.size).reshape(patterns.shape) + coslat = np.cos(np.deg2rad(lat_2d)) + proj = regimes.project(np.ones(patterns.shape), patterns, weights=coslat) + assert proj["dipole"] > 0 # positive values where weights are heigher + assert proj["dipole_inv"] < 0 # negative values where weights are higher + assert np.isclose(proj["dipole"], -proj["dipole_inv"]) + + +def test_project_zeros_returns_zero(patterns): + proj = regimes.project(np.zeros(patterns.shape), patterns, weights=np.ones(patterns.shape)) + assert np.isclose(proj["dipole"], 0.0) + assert np.isclose(proj["monopole"], 0.0) + assert np.isclose(proj["dipole_inv"], 0.0) + + +def test_project_is_commutative(patterns): + fields = np.stack([patterns._dipole, patterns._monopole]) + proj = regimes.project(fields, patterns, weights=np.ones(patterns.shape)) + np.testing.assert_allclose(proj["dipole"][1], proj["monopole"][0]) + + +def test_project_maintains_shape(patterns): + fields = np.zeros((2, 3, 4, *patterns.shape)) + proj = regimes.project(fields, patterns, weights=np.ones(patterns.shape)) + assert proj["dipole"].shape == (2, 3, 4) + + +@pytest.mark.xfail(reason="grid info not available from earthkit-geo") +def test_project_generates_weights_by_default(patterns): + regimes.project(np.ones(patterns.shape), patterns) + + +def test_project_with_single_pattern_return(patterns): + proj = regimes.project( + np.ones((2, *patterns.shape)), patterns, weights=np.ones(patterns.shape), single=True + ) + # All patterns are the same + assert proj["dipole"].shape == (2,) + assert np.isclose(proj["dipole"][0], proj["dipole"][1]) + + +def test_project_with_multiple_pattern_return(patterns): + proj = regimes.project( + np.ones((2, *patterns.shape)), patterns, weights=np.ones(patterns.shape), single=False + ) + # Second pattern has twice the amplitude + assert proj["dipole"].shape == (2,) + assert np.isclose(proj["dipole"][0], 0.5 * proj["dipole"][1]) + + +def test_standardise_with_dict(): + proj = { + "foo": np.asarray([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + "bar": np.asarray([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + } + mean = {"foo": 2.0, "bar": -4.0} + std = {"foo": 10.0, "bar": 2.0} + + index = regimes.standardise(proj, mean, std) + + assert len(index) == 2 + assert "foo" in index + assert "bar" in index + np.testing.assert_allclose(index["foo"], [-0.2, -0.1, 0.0, 0.1, 0.2, 0.3]) + np.testing.assert_allclose(index["bar"], [2.0, 2.5, 3.0, 3.5, 4.0, 4.5]) diff --git a/tests/regimes/test_patterns.py b/tests/regimes/test_patterns.py new file mode 100644 index 0000000..2fc47ae --- /dev/null +++ b/tests/regimes/test_patterns.py @@ -0,0 +1,59 @@ +# (C) Copyright 2021 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np + +from earthkit.meteo import regimes + + +class TestConstantRegimePatterns: + + lat = np.linspace(90.0, 0.0, 91) + lon = np.linspace(60.0, -60.0, 121) + dipole = np.cos(np.deg2rad(lon[None, :])) * np.cos(np.deg2rad(lat[:, None]) * 2) + monopole = np.cos(np.deg2rad(lon[None, :])) * np.sin(np.deg2rad(lat[:, None]) * 2) + patterns = regimes.ConstantRegimePatterns( + regimes=["dipole", "monopole", "dipole_inv"], + grid={"grid": [1.0, 1.0], "area": [max(lat), min(lon), min(lat), max(lon)]}, + patterns=np.stack([dipole, monopole, -dipole]), + ) + + zeros = np.zeros(shape=patterns.shape) + ones = np.ones(shape=patterns.shape) + + def test_shape(self): + assert self.patterns.shape == (self.lat.size, self.lon.size) + + def test_patterns(self): + pat = self.patterns.patterns() + assert len(pat) == 3 + np.testing.assert_allclose(pat["dipole"], self.dipole) + np.testing.assert_allclose(pat["monopole"], self.monopole) + + +class TestModulatedRegimePatterns: + + lat = np.linspace(90.0, 0.0, 91) + lon = np.linspace(60.0, -60.0, 121) + dipole = np.cos(np.deg2rad(lon[None, :])) * np.cos(np.deg2rad(lat[:, None]) * 2) + patterns = regimes.ModulatedRegimePatterns( + regimes=["dipole"], + grid={"grid": [1.0, 1.0], "area": [max(lat), min(lon), min(lat), max(lon)]}, + patterns=np.stack([dipole]), + modulator=lambda x: np.sign(x), # clarify the arg name + ) + + def test_shape(self): + assert self.patterns.shape == (self.lat.size, self.lon.size) + + def test_patterns(self): + pat = self.patterns.patterns(x=[3.0, 0.0, -4.0]) + assert len(pat) == 1 + np.testing.assert_allclose(pat["dipole"][0], self.dipole) + np.testing.assert_allclose(pat["dipole"][1], 0.0) + np.testing.assert_allclose(pat["dipole"][2], -self.dipole)