Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ Bug Fixes
``np.isclose`` by default to handle accumulated floating point errors from
slicing operations. Use ``exact=True`` for exact comparison (:pull:`11035`).
By `Ian Hunt-Isaak <https://github.com/ianhi>`_.
- Ensure the :py:class:`~xarray.groupers.SeasonResampler` preserves the datetime
unit of the underlying time index when resampling (:issue:`11048`,
:pull:`11049`). By `Spencer Clark <https://github.com/spencerkclark>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
56 changes: 42 additions & 14 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import operator
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Hashable, Mapping, Sequence
from collections.abc import Callable, Hashable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from itertools import chain, pairwise
from typing import TYPE_CHECKING, Any, Literal, cast

Expand All @@ -38,8 +39,10 @@
from xarray.core.resample_cftime import CFTimeGrouper
from xarray.core.types import (
Bins,
CFTimeDatetime,
DatetimeLike,
GroupIndices,
PDDatetimeUnitOptions,
ResampleCompatible,
Self,
SideOptions,
Expand All @@ -61,6 +64,16 @@
RESAMPLE_DIM = "__resample_dim__"


def _datetime64_via_timestamp(unit: PDDatetimeUnitOptions, **kwargs) -> np.datetime64:
    """Build a ``numpy.datetime64`` scalar with resolution ``unit``.

    All keyword arguments are forwarded unchanged to the
    ``pandas.Timestamp`` constructor; the resulting timestamp is converted
    to ``unit`` and then to its NumPy representation.
    """
    # TODO: when pandas 3 is our minimum requirement we will no longer need to
    # convert to np.datetime64 values prior to passing to the DatetimeIndex
    # constructor. With pandas < 3 the DatetimeIndex constructor does not
    # infer the resolution from the resolution of the Timestamp values.
    timestamp = pd.Timestamp(**kwargs)
    return timestamp.as_unit(unit).to_numpy()


@dataclass(init=False)
class EncodedGroups:
"""
Expand Down Expand Up @@ -955,35 +968,50 @@ def factorize(self, group: T_Group) -> EncodedGroups:
counts = agged["count"]

index_class: type[CFTimeIndex | pd.DatetimeIndex]
datetime_class: CFTimeDatetime | Callable[..., np.datetime64]
if _contains_cftime_datetimes(group.data):
index_class = CFTimeIndex
datetime_class = type(first_n_items(group.data, 1).item())
else:
index_class = pd.DatetimeIndex
datetime_class = datetime.datetime
unit, _ = np.datetime_data(group.dtype)
unit = cast(PDDatetimeUnitOptions, unit)
datetime_class = partial(_datetime64_via_timestamp, unit)

# these are the seasons that are present

# TODO: when pandas 3 is our minimum requirement we will no longer need
# to cast the list to a NumPy array prior to passing to the index
# constructor.
unique_coord = index_class(
[
datetime_class(year=year, month=season_tuples[season][0], day=1)
for year, season in first_items.index
]
np.array(
[
datetime_class(year=year, month=season_tuples[season][0], day=1)
for year, season in first_items.index
]
)
)

# This sorted call is a hack. It's hard to figure out how
# to start the iteration for arbitrary season ordering,
# for example with "DJF" as the first or the last entry.
# So we construct the largest possible index and slice it to the
# range present in the data.

# TODO: when pandas 3 is our minimum requirement we will no longer need
# to cast the list to a NumPy array prior to passing to the index
# constructor.
complete_index = index_class(
sorted(
[
datetime_class(year=y, month=m, day=1)
for y, m in itertools.product(
range(year[0].item(), year[-1].item() + 1),
[s[0] for s in season_inds],
)
]
np.array(
sorted(
[
datetime_class(year=y, month=m, day=1)
for y, m in itertools.product(
range(year[0].item(), year[-1].item() + 1),
[s[0] for s in season_inds],
)
]
)
)
)

Expand Down
12 changes: 11 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import xarray as xr
from xarray import DataArray, Dataset, Variable, date_range
from xarray.core.groupby import _consolidate_slices
from xarray.core.types import InterpOptions, ResampleCompatible
from xarray.core.types import InterpOptions, PDDatetimeUnitOptions, ResampleCompatible
from xarray.groupers import (
BinGrouper,
EncodedGroups,
Expand Down Expand Up @@ -3605,6 +3605,16 @@ def test_season_resampler_groupby_identical(self):
gb = da.groupby(time=resampler).sum()
assert_identical(rs, gb)

def test_season_resampler_preserves_time_unit(
    self, time_unit: PDDatetimeUnitOptions
) -> None:
    """Check that seasonal resampling keeps the input time coordinate's unit."""
    seasons = ["DJF", "MAM", "JJA", "SON"]
    # Twelve month-start timestamps built with the requested resolution.
    time = date_range("2000", periods=12, freq="MS", unit=time_unit)
    da = DataArray(np.ones(time.size), dims="time", coords={"time": time})
    resampled = da.resample(time=SeasonResampler(seasons)).sum()
    # np.datetime_data returns (unit, count); only the unit is compared.
    assert np.datetime_data(resampled.time.dtype)[0] == time_unit


@pytest.mark.parametrize(
"chunk",
Expand Down
Loading