diff --git a/environment.yml b/environment.yml
index 7225490..550ca1a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -7,6 +7,7 @@ dependencies:
 - numpy
 - pandas
 - earthkit-utils>=0.0.1
+- earthkit-transforms>=0.5.0
 - make
 - mypy
 - myst-parser
@@ -24,3 +25,4 @@ dependencies:
 - nbconvert
 - nbsphinx
 - ipykernel
+- xarray
diff --git a/pyproject.toml b/pyproject.toml
index 4e590a9..4f012a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,8 +27,10 @@ classifiers = [
 ]
 dynamic = [ "version" ]
 dependencies = [
+  "earthkit-transforms>=0.5",
   "earthkit-utils>=0.0.1",
   "numpy",
+  "xarray",
 ]
 optional-dependencies.gpu = [
   "cupy",
diff --git a/src/earthkit/meteo/extreme/__init__.py b/src/earthkit/meteo/extreme/__init__.py
index d3ee0b6..aefbbee 100644
--- a/src/earthkit/meteo/extreme/__init__.py
+++ b/src/earthkit/meteo/extreme/__init__.py
@@ -15,6 +15,7 @@
 planned to work with objects like *earthkit.data FieldLists* or *xarray DataSets*.
 """
 
+from . import excess_heat  # noqa
 from .cpf import *  # noqa
 from .efi import *  # noqa
 from .sot import *  # noqa
diff --git a/src/earthkit/meteo/extreme/excess_heat.py b/src/earthkit/meteo/extreme/excess_heat.py
new file mode 100644
index 0000000..bbee843
--- /dev/null
+++ b/src/earthkit/meteo/extreme/excess_heat.py
@@ -0,0 +1,252 @@
+# (C) Copyright 2025 - ECMWF and individual contributors.
+
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation nor
+# does it submit to any jurisdiction.
+
+import functools
+import numbers
+
+import earthkit.transforms._tools as _ekt_tools
+import earthkit.transforms.temporal
+import numpy as np
+import xarray as xr
+
+
+class _with_metadata:
+    """Decorator to attach metadata to an output DataArray
+
+    The result of the decorated function is renamed to `name` and `attrs` are
+    added to its attributes. Results that are not xarray DataArrays (e.g. numpy
+    arrays computed from array_like inputs) are returned unchanged.
+    """
+
+    # TODO just a quick solution until something better is in place
+    # TODO input-dependent unit handling (take input unit and transform)
+
+    def __init__(self, name, **attrs):
+        self.name = name
+        self.attrs = attrs
+
+    def __call__(self, f):
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            out = f(*args, **kwargs)
+            if not isinstance(out, xr.DataArray):
+                return out
+            return out.rename(self.name).assign_attrs(self.attrs)
+
+        return wrapped
+
+
+def _rolling_mean(da, n, shift_days=0):
+    """Non-centred rolling mean over n steps, with a time shift of shift_days days"""
+    return earthkit.transforms.temporal.rolling_reduce(
+        da, n, center=False, how_reduce="mean", time_shift={"days": shift_days}, how_dropna="any"
+    )
+
+
+__DMT_TIME_SHIFT_COORD = "__daily_mean_temperature_time_shift"
+
+
+@_with_metadata("dmt", long_name="Daily mean temperature")
+def daily_mean_temperature(t2m, day_start=9, time_shift=0, **kwargs):
+    """Daily mean temperature, computed from min and max.
+
+    It is recommended to install flox for efficient aggregations.
+
+    Supports custom definitions of "day". E.g., by defining a later start of
+    the day (positive `day_start`), the usual early morning temperature minimum
+    can be attributed to the previous day, so that the mean value is derived
+    from the daytime maximum and the minimum of the following night (rather
+    than the previous night). At the same time, local time zones can be
+    accounted with the `time_shift` parameter by specifying the time zone
+    offsets with respect to the time coordinate of the input data as a function
+    of the spatial coordinates.
+
+    Example
+    -------
+
+    Assume that the time coordinate of the data is given in UTC and we want to
+    define the day from 10:00 to 10:00 Atlantic Standard Time (UTC -4 hours),
+    then we need to call
+
+        daily_mean_temperature(..., day_start=10, time_shift=-4)
+
+    Note that while both offset parameters lead to a "later" start of the day
+    relative to the reference time coordinate of the data, the sign of their
+    value is different to match their use as outlined in the text above. Set
+    `day_start` to zero and integrate the start offset into `time_shift` if
+    you prefer to specify only a single offset for both settings. The above
+    call is equivalent to
+
+        daily_mean_temperature(..., day_start=0, time_shift=-14)
+
+    Parameters
+    ----------
+    t2m : xr.DataArray
+        2-metre temperature.
+    day_start : number | np.timedelta64
+        Constant offset for the start of the day in the aggregations. By
+        default, the day is defined from 09:00 to 09:00. Positive offsets
+        indicate a late start of the day (see default), negative an early
+        start. A numeric value is interpreted as hours.
+    time_shift : np.timedelta64 | number | str | xr.DataArray
+        Offset relative to time coordinate of input. Can be a function of
+        space to specify timezones. A numeric value is interpreted as hours.
+        Provide a string value to take a correspondingly named coordinate
+        from the input DataArray.
+    **kwargs
+        Keyword arguments for the daily_min and daily_max functions of
+        earthkit.transforms.temporal.
+
+    Returns
+    -------
+    xr.DataArray
+        Daily mean temperature.
+    """
+    assert isinstance(t2m, xr.DataArray)
+
+    if isinstance(day_start, numbers.Number):
+        day_start = np.timedelta64(day_start, "h")
+    assert isinstance(day_start, np.timedelta64)
+
+    if isinstance(time_shift, numbers.Number):
+        time_shift = np.timedelta64(time_shift, "h")
+    if isinstance(time_shift, str):
+        time_shift = t2m.coords[time_shift]
+    if isinstance(time_shift, xr.DataArray):
+        unique_shifts = np.unique(time_shift.values)
+        # Can only proceed if all timeseries are in the same time zone. If
+        # there are multiple time zones: split, process separately, merge.
+        if unique_shifts.size == 1:
+            time_shift = unique_shifts[0]
+        else:
+            assert __DMT_TIME_SHIFT_COORD not in t2m.coords
+            # Merging of groups where different partial days were removed
+            # fails after map when the time coordinate is of pandas period
+            # dtype. The period dtype works as long as the same partial days
+            # are in all groups removed.
+            return (
+                t2m
+                # Groups don't know their time shift unless we attach it here.
+                # Grouping by the time shift means the shift value per group
+                # is unique and the recursion ends immediately.
+                .assign_coords({__DMT_TIME_SHIFT_COORD: time_shift})
+                .groupby(__DMT_TIME_SHIFT_COORD)
+                .map(daily_mean_temperature, day_start=day_start, time_shift=__DMT_TIME_SHIFT_COORD, **kwargs)
+                # Don't expose the internal shift coordinate to the user
+                .drop_vars(__DMT_TIME_SHIFT_COORD)
+            )
+    time_shift = np.asarray(time_shift)
+    assert time_shift.size == 1
+    assert np.issubdtype(time_shift.dtype, np.timedelta64)
+
+    agg_kwargs = {"time_shift": time_shift - day_start, "remove_partial_periods": True, **kwargs}
+    tmin = earthkit.transforms.temporal.daily_min(t2m, **agg_kwargs)
+    tmax = earthkit.transforms.temporal.daily_max(t2m, **agg_kwargs)
+    return 0.5 * (tmin + tmax)
+
+
+@_with_metadata("ehi_sig", long_name="Significance index")
+def significance_index(dmt, threshold=("quantile", 0.95), ndays=3, time_dim=None):
+    """Significance index
+
+    Parameters
+    ----------
+    dmt : xr.DataArray
+        Daily mean temperature.
+    threshold : number | xr.DataArray | tuple
+        Threshold that the rolling mean is compared against. A two-element
+        tuple ``("quantile", q)`` computes the threshold as the q-quantile of
+        `dmt` along the time dimension; other tuple kinds are not implemented.
+        Any other value is subtracted as given.
+    ndays : number
+        Length of evaluation time window.
+    time_dim : None | str
+        Name of time dimension in dmt DataArray.
+
+    Returns
+    -------
+    xr.DataArray
+        Significance index.
+    """
+    # Time coordinate detection compatible with earthkit.transforms
+    if time_dim is None:
+        time_dim = _ekt_tools.get_dim_key(dmt, "t", raise_error=True)
+    # Compute threshold as quantile
+    if isinstance(threshold, tuple):
+        assert len(threshold) == 2
+        if threshold[0] == "quantile":
+            threshold = dmt.quantile(threshold[1], dim=time_dim)
+        else:
+            raise NotImplementedError
+    # TODO: also support day-of-year climatology to detect warm spells
+    current = _rolling_mean(dmt, ndays, shift_days=(1 - ndays))
+    return current - threshold
+
+
+@_with_metadata("ehi_accl", long_name="Acclimatisation index")
+def acclimatisation_index(dmt, ndays=3, ndays_ref=30):
+    """Acclimatisation index
+
+    Parameters
+    ----------
+    dmt : xr.DataArray
+        Daily mean temperature.
+    ndays : number
+        Length of evaluation time window.
+    ndays_ref : number
+        Length of reference time window (recent past).
+
+    Returns
+    -------
+    xr.DataArray
+        Acclimatisation index.
+    """
+    current = _rolling_mean(
+        dmt, ndays, shift_days=(1 - ndays)
+    )  # TODO: shared with significance index, would be nice to not compute it twice
+    reference = _rolling_mean(dmt, ndays_ref, shift_days=1)
+    return current - reference
+
+
+# https://codes.ecmwf.int/grib/param-db/261024
+@_with_metadata("exhf", long_name="Excess heat factor")
+def excess_heat_factor(ehi_sig, ehi_accl, nonnegative=True):
+    """Excess heat factor
+
+    Parameters
+    ----------
+    ehi_sig : xr.DataArray | array_like
+        Significance index.
+    ehi_accl : xr.DataArray | array_like
+        Acclimatisation index.
+    nonnegative : bool
+        Whether to clip the lower value range by zero.
+
+    Returns
+    -------
+    xr.DataArray | array_like
+        Excess heat factor.
+    """
+    if nonnegative:
+        ehi_sig = np.maximum(0, ehi_sig)
+    return ehi_sig * np.maximum(1.0, ehi_accl)
+
+
+# https://codes.ecmwf.int/grib/param-db/261025
+@_with_metadata("excf", long_name="Excess cold factor")
+def excess_cold_factor(ehi_sig, ehi_accl, nonnegative=True):
+    """Excess cold factor (not implemented yet)
+
+    Raises
+    ------
+    NotImplementedError
+        Always, until the computation is implemented.
+    """
+    # Raise rather than return: a returned exception class would be handed
+    # back to the caller as if it were a valid result.
+    raise NotImplementedError("excess_cold_factor is not implemented yet")  # TODO
diff --git a/tests/extreme/test_excess_heat.py b/tests/extreme/test_excess_heat.py
new file mode 100644
index 0000000..681dd63
--- /dev/null
+++ b/tests/extreme/test_excess_heat.py
@@ -0,0 +1,177 @@
+# (C) Copyright 2025 - ECMWF and individual contributors.
+
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation nor
+# does it submit to any jurisdiction.
+
+import numpy as np
+import pandas as pd
+import pytest
+import xarray as xr
+
+from earthkit.meteo.extreme import excess_heat
+
+
+class TestDailyMeanTemperatureArgDayStart:
+
+    @pytest.fixture
+    def t2m(self):
+        nt = 4 * 24
+        return xr.DataArray(
+            data=np.arange(nt),
+            dims=["valid_time"],
+            coords={"valid_time": pd.date_range("2026-01-01", periods=nt, freq="1h")},
+            name="t2m",
+        )
+
+    @pytest.mark.parametrize("start", [-5, 0, 3])
+    def test_number_value_interpreted_as_hours(self, t2m, start):
+        dmt_tdelta = excess_heat.daily_mean_temperature(t2m, day_start=np.timedelta64(start, "h"))
+        dmt_number = excess_heat.daily_mean_temperature(t2m, day_start=start)
+        np.testing.assert_allclose(dmt_tdelta, dmt_number)
+
+    def test_zero_means_day_starts_at_midnight(self, t2m):
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=0)
+        np.testing.assert_allclose(dmt, [0.5 * (0 + 23), 0.5 * (24 + 47), 0.5 * (48 + 71), 0.5 * (72 + 95)])
+
+    @pytest.mark.parametrize("start", [-9, -5, -2])
+    def test_negative_value_means_day_starts_early(self, t2m, start):
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=start, time_shift=0)
+        # Partial periods removed: days 2, 3, 4 returned; 1 and 5 partial
+        np.testing.assert_allclose(
+            dmt, [0.5 * (24 + 47) + start, 0.5 * (48 + 71) + start, 0.5 * (72 + 95) + start]
+        )
+
+    @pytest.mark.parametrize("start", [2, 5, 9])
+    def test_positive_means_day_starts_late(self, t2m, start):
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=start, time_shift=0)
+        # Partial periods removed: days 1, 2, 3 returned; 0 and 4 partial
+        np.testing.assert_allclose(
+            dmt, [0.5 * (0 + 23) + start, 0.5 * (24 + 47) + start, 0.5 * (48 + 71) + start]
+        )
+
+
+class TestDailyMeanTemperatureArgTimeShift:
+
+    @pytest.fixture
+    def t2m(self):
+        nx, nt = 3, 4 * 24
+        return xr.DataArray(
+            # Values differ by order of magnitude in space, increase over time
+            data=np.logspace(0, nx - 1, nx)[None, :] * np.arange(nt)[:, None],
+            dims=["valid_time", "x"],
+            coords={
+                "valid_time": pd.date_range("2026-01-01", periods=nt, freq="1h"),
+                "x": np.arange(nx),
+            },
+            name="t2m",
+        )
+
+    @pytest.mark.parametrize("shift", [-5, 0, 3])
+    def test_number_value_interpreted_as_hours(self, t2m, shift):
+        dmt_tdelta = excess_heat.daily_mean_temperature(
+            t2m, day_start=0, time_shift=np.timedelta64(shift, "h")
+        )
+        dmt_number = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=shift)
+        np.testing.assert_allclose(dmt_tdelta, dmt_number)
+
+    def test_field_access_with_str_value(self, t2m):
+        tz = xr.DataArray(
+            data=[np.timedelta64(1, "h"), np.timedelta64(5, "h"), np.timedelta64(7, "h")],
+            dims=["x"],
+            coords={"x": t2m.coords["x"]},
+            name="timezone",
+        )
+        dmt_array = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=tz)
+        dmt_coord = excess_heat.daily_mean_temperature(
+            t2m.assign_coords({"timezone": tz}), day_start=0, time_shift="timezone"
+        )
+        np.testing.assert_allclose(dmt_array, dmt_coord)
+
+    @pytest.mark.parametrize("shift", [-5, 0, 3])
+    def test_scalar_value_is_applied_everywhere(self, t2m, shift):
+        dmt_scalar = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=shift)
+        dmt_array = excess_heat.daily_mean_temperature(
+            t2m.assign_coords({"timezone": ("x", 3 * [np.timedelta64(shift, "h")])}),
+            day_start=0,
+            time_shift="timezone",
+        )
+        np.testing.assert_allclose(dmt_scalar, dmt_array)
+
+    @pytest.mark.parametrize("shift", [-2, -5, -9])
+    def test_negative_value_means_day_starts_late(self, t2m, shift):
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=shift)
+        ref = (np.asarray([11.5, 35.5, 59.5]) - shift)[:, None] * np.asarray([1e0, 1e1, 1e2])[None, :]
+        np.testing.assert_allclose(dmt, ref)
+
+    @pytest.mark.parametrize("shift", [2, 5, 9])
+    def test_positive_value_means_day_starts_early(self, t2m, shift):
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift=shift)
+        ref = (np.asarray([35.5, 59.5, 83.5]) - shift)[:, None] * np.asarray([1e0, 1e1, 1e2])[None, :]
+        np.testing.assert_allclose(dmt, ref)
+
+    def test_multiple_timezones_grouping_of_duplicate_values(self, t2m):
+        t2m = t2m.assign_coords(
+            {"timezone": ("x", [np.timedelta64(1, "h"), np.timedelta64(4, "h"), np.timedelta64(4, "h")])}
+        )
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift="timezone")
+        np.testing.assert_allclose(dmt, [[34.5, 315.0, 3150.0], [58.5, 555.0, 5550.0], [82.5, 795.0, 7950.0]])
+
+    def test_multiple_timezones_fills_with_nan_to_preserve_outputs(self, t2m):
+        t2m = t2m.assign_coords(
+            {"timezone": ("x", [np.timedelta64(-2, "h"), np.timedelta64(0, "h"), np.timedelta64(10, "h")])}
+        )
+        dmt = excess_heat.daily_mean_temperature(t2m, day_start=0, time_shift="timezone")
+        np.testing.assert_allclose(
+            dmt,
+            [
+                [13.5, 115.0, np.nan],
+                [37.5, 355.0, 2550.0],
+                [61.5, 595.0, 4950.0],
+                [np.nan, 835.0, 7350.0],
+            ],
+        )
+
+
+def test_daily_mean_temperature_combined_day_start_and_time_shift_args():
+    nx, nt = 4, 4 * 24
+    t2m = xr.DataArray(
+        data=np.arange(nt).repeat(nx).reshape((nt, nx)),
+        dims=["valid_time", "x"],
+        coords={
+            "valid_time": pd.date_range("2026-01-01", periods=nt, freq="1h"),
+            "x": np.arange(nx),
+            "timezone": (
+                "x",
+                [
+                    np.timedelta64(-5, "h"),
+                    np.timedelta64(1, "h"),
+                    np.timedelta64(3, "h"),
+                    np.timedelta64(12, "h"),
+                ],
+            ),
+        },
+        name="t2m",
+    )
+    dmt = excess_heat.daily_mean_temperature(t2m, day_start=3, time_shift="timezone")
+    np.testing.assert_allclose(
+        dmt,
+        [
+            [19.5, 13.5, 11.5, np.nan],
+            [43.5, 37.5, 35.5, 26.5],
+            [67.5, 61.5, 59.5, 50.5],
+            [np.nan, np.nan, 83.5, 74.5],
+        ],
+    )
+
+
+def test_excess_heat_factor_accepts_plain_arrays():
+    exhf = excess_heat.excess_heat_factor(np.asarray([-1.0, 2.0]), np.asarray([3.0, 0.5]))
+    np.testing.assert_allclose(exhf, [0.0, 2.0])
+
+
+def test_excess_cold_factor_is_not_implemented():
+    with pytest.raises(NotImplementedError):
+        excess_heat.excess_cold_factor(None, None)