diff --git a/src/boost_histogram/histogram.py b/src/boost_histogram/histogram.py index 2c62be50..205b9731 100644 --- a/src/boost_histogram/histogram.py +++ b/src/boost_histogram/histogram.py @@ -881,7 +881,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: reduced: CppHistogram | None = None # Compute needed slices and projections - for i, ind in enumerate(indexes): + for i, ind in enumerate(indexes): # pylint: disable=too-many-nested-blocks if isinstance(ind, SupportsIndex): pick_each[i] = ind.__index__() + ( 1 if self.axes[i].traits.underflow else 0 @@ -967,14 +967,26 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: new_reduced = reduced.__class__(axes) new_view = new_reduced.view(flow=True) - - j = 1 + j = 0 + new_j_base = 0 + if self.axes[i].traits.underflow: + groups.insert(0, 1) + else: + new_j_base = 1 + if self.axes[i].traits.overflow: + groups.append(1) for new_j, group in enumerate(groups): for _ in range(group): pos = [slice(None)] * (i) - new_view[(*pos, new_j + 1, ...)] += _to_view( - reduced_view[(*pos, j, ...)] - ) + if new_view.dtype.names: + for field in new_view.dtype.names: + new_view[(*pos, new_j + new_j_base, ...)][ + field + ] += reduced_view[(*pos, j, ...)][field] + else: + new_view[(*pos, new_j + new_j_base, ...)] += ( + reduced_view[(*pos, j, ...)] + ) j += 1 reduced = new_reduced diff --git a/src/boost_histogram/tag.py b/src/boost_histogram/tag.py index dcedf7b9..d681db46 100644 --- a/src/boost_histogram/tag.py +++ b/src/boost_histogram/tag.py @@ -6,6 +6,8 @@ from builtins import sum from typing import TYPE_CHECKING, Sequence, TypeVar +import numpy as np + if TYPE_CHECKING: from uhi.typing.plottable import PlottableAxis @@ -112,26 +114,43 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002 class rebin: __slots__ = ( + "axis", + "edges", "factor", "groups", ) def __init__( self, - factor: int | None = None, + factor_or_axis: int | PlottableAxis | None = None, + /, *, + factor: int | None = None, groups: Sequence[int] | None = None, + edges: Sequence[int | float] | None = None, + axis: PlottableAxis | None = None, ) -> None: - if not sum(i is None for i in [factor, groups]) == 1: - raise ValueError("Exactly one, a factor or groups should be provided") - self.factor = factor + if ( + sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis]) + != 1 + ): + raise ValueError("Exactly one argument should be provided") self.groups = groups + self.edges = edges + self.axis = axis + self.factor = factor + if isinstance(factor_or_axis, int): + self.factor = factor_or_axis + elif factor_or_axis is not None: + self.axis = factor_or_axis def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}" - args: dict[str, int | Sequence[int] | None] = { + args: dict[str, int | Sequence[int | float] | PlottableAxis | None] = { "factor": self.factor, "groups": self.groups, + "edges": self.edges, + "axis": self.axis, } for k, v in args.items(): if v is not None: @@ -147,4 +166,30 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]: return self.groups if self.factor is not None: return [self.factor] * len(axis) + if self.edges is not None or self.axis is not None: + newedges = None + if self.axis is not None and hasattr(self.axis, "edges"): + newedges = self.axis.edges + elif self.edges is not None: + newedges = self.edges + + if newedges is not None and hasattr(axis, "edges"): + assert newedges[0] == axis.edges[0], "Edges must start at first bin" + assert newedges[-1] == axis.edges[-1], "Edges must end at last bin" + assert all( + np.isclose( + axis.edges[np.abs(axis.edges - edge).argmin()], + edge, + ) + for edge in newedges + ), "Edges must be in the axis" + matched_ixes = np.where( + np.isin( + axis.edges, + newedges, + ) + )[0] + return [ + int(ix - matched_ixes[i]) for i, ix in enumerate(matched_ixes[1:]) + ] raise ValueError("No rebinning factor or groups provided") diff --git a/tests/test_histogram.py b/tests/test_histogram.py index f26df9f7..4dedef2e 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -650,6 +650,33 @@ def test_rebin_1d(metadata): assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0]) assert h.axes[0].metadata is hs.axes[0].metadata + hs = h[bh.rebin(edges=[1.0, 1.2, 1.6, 2.2, 5.0])] + assert_array_equal(hs.view(), [1, 0, 0, 3]) + assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0]) + + hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0]))] + assert_array_equal(hs.view(), [1, 0, 0, 3]) + assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0]) + + +def test_rebin_1d_flow(): + h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=True)) + h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5]) + hs = h[bh.rebin(edges=[0, 3, 5.0])] + assert_array_equal(hs.view(), [2, 2]) + assert_array_equal(hs.view(flow=True), [1, 2, 2, 1]) + assert_array_equal(hs.axes.edges[0], [0.0, 3.0, 5.0]) + + h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=False, overflow=False)) + h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5]) + hs = h[bh.rebin(edges=[0, 3, 5.0])] + assert_array_equal(hs.view(flow=True), [0, 2, 2, 0]) + + h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=False)) + h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5]) + hs = h[bh.rebin(edges=[0, 3, 5.0])] + assert_array_equal(hs.view(flow=True), [1, 2, 2, 0]) + def test_shrink_rebin_1d(): h = bh.Histogram(bh.axis.Regular(20, 0, 4))