diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 68f0a0b7aee..a49564649cf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ Bug Fixes ``np.isclose`` by default to handle accumulated floating point errors from slicing operations. Use ``exact=True`` for exact comparison (:pull:`11035`). By `Ian Hunt-Isaak `_. +- Ensure the :py:class:`~xarray.groupers.SeasonResampler` preserves the datetime + unit of the underlying time index when resampling (:issue:`11048`, + :pull:`11049`). By `Spencer Clark `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/groupers.py b/xarray/groupers.py index a16933e690f..a26741ff3fe 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -12,8 +12,9 @@ import operator from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Callable, Hashable, Mapping, Sequence from dataclasses import dataclass, field +from functools import partial from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast @@ -38,8 +39,10 @@ from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import ( Bins, + CFTimeDatetime, DatetimeLike, GroupIndices, + PDDatetimeUnitOptions, ResampleCompatible, Self, SideOptions, @@ -61,6 +64,16 @@ RESAMPLE_DIM = "__resample_dim__" +def _datetime64_via_timestamp(unit: PDDatetimeUnitOptions, **kwargs) -> np.datetime64: + """Construct a numpy.datetime64 object through the pandas.Timestamp + constructor with a specific resolution.""" + # TODO: when pandas 3 is our minimum requirement we will no longer need to + # convert to np.datetime64 values prior to passing to the DatetimeIndex + # constructor. With pandas < 3 the DatetimeIndex constructor does not + # infer the resolution from the resolution of the Timestamp values. + return pd.Timestamp(**kwargs).as_unit(unit).to_numpy() + + @dataclass(init=False) class EncodedGroups: """ @@ -955,19 +968,28 @@ def factorize(self, group: T_Group) -> EncodedGroups: counts = agged["count"] index_class: type[CFTimeIndex | pd.DatetimeIndex] + datetime_class: CFTimeDatetime | Callable[..., np.datetime64] if _contains_cftime_datetimes(group.data): index_class = CFTimeIndex datetime_class = type(first_n_items(group.data, 1).item()) else: index_class = pd.DatetimeIndex - datetime_class = datetime.datetime + unit, _ = np.datetime_data(group.dtype) + unit = cast(PDDatetimeUnitOptions, unit) + datetime_class = partial(_datetime64_via_timestamp, unit) # these are the seasons that are present + + # TODO: when pandas 3 is our minimum requirement we will no longer need + # to cast the list to a NumPy array prior to passing to the index + # constructor. unique_coord = index_class( - [ - datetime_class(year=year, month=season_tuples[season][0], day=1) - for year, season in first_items.index - ] + np.array( + [ + datetime_class(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) ) # This sorted call is a hack. It's hard to figure out how @@ -975,15 +997,21 @@ def factorize(self, group: T_Group) -> EncodedGroups: # for example "DJF" as first entry or last entry # So we construct the largest possible index and slice it to the # range present in the data. + + # TODO: when pandas 3 is our minimum requirement we will no longer need + # to cast the list to a NumPy array prior to passing to the index + # constructor. complete_index = index_class( - sorted( - [ - datetime_class(year=y, month=m, day=1) - for y, m in itertools.product( - range(year[0].item(), year[-1].item() + 1), - [s[0] for s in season_inds], - ) - ] + np.array( + sorted( + [ + datetime_class(year=y, month=m, day=1) + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) ) ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b9b9fb151c7..c320931098a 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -14,7 +14,7 @@ import xarray as xr from xarray import DataArray, Dataset, Variable, date_range from xarray.core.groupby import _consolidate_slices -from xarray.core.types import InterpOptions, ResampleCompatible +from xarray.core.types import InterpOptions, PDDatetimeUnitOptions, ResampleCompatible from xarray.groupers import ( BinGrouper, EncodedGroups, @@ -3605,6 +3605,16 @@ def test_season_resampler_groupby_identical(self): gb = da.groupby(time=resampler).sum() assert_identical(rs, gb) + def test_season_resampler_preserves_time_unit( + self, time_unit: PDDatetimeUnitOptions + ) -> None: + time = date_range("2000", periods=12, freq="MS", unit=time_unit) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) + result = da.resample(time=resampler).sum() + result_unit, _ = np.datetime_data(result.time.dtype) + assert result_unit == time_unit + @pytest.mark.parametrize( "chunk",