Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow rebinning by passing edges or a new axis #977

Merged
merged 10 commits into from
Jan 31, 2025
24 changes: 18 additions & 6 deletions src/boost_histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
reduced: CppHistogram | None = None

# Compute needed slices and projections
for i, ind in enumerate(indexes):
for i, ind in enumerate(indexes): # pylint: disable=too-many-nested-blocks
if isinstance(ind, SupportsIndex):
pick_each[i] = ind.__index__() + (
1 if self.axes[i].traits.underflow else 0
Expand Down Expand Up @@ -967,14 +967,26 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:

new_reduced = reduced.__class__(axes)
new_view = new_reduced.view(flow=True)

j = 1
j = 0
new_j_base = 0
if self.axes[i].traits.underflow:
groups.insert(0, 1)
else:
new_j_base = 1
if self.axes[i].traits.overflow:
groups.append(1)
for new_j, group in enumerate(groups):
for _ in range(group):
pos = [slice(None)] * (i)
new_view[(*pos, new_j + 1, ...)] += _to_view(
reduced_view[(*pos, j, ...)]
)
if new_view.dtype.names:
for field in new_view.dtype.names:
new_view[(*pos, new_j + new_j_base, ...)][
field
] += reduced_view[(*pos, j, ...)][field]
else:
new_view[(*pos, new_j + new_j_base, ...)] += (
reduced_view[(*pos, j, ...)]
)
j += 1

reduced = new_reduced
Expand Down
55 changes: 50 additions & 5 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from builtins import sum
from typing import TYPE_CHECKING, Sequence, TypeVar

import numpy as np

if TYPE_CHECKING:
from uhi.typing.plottable import PlottableAxis

Expand Down Expand Up @@ -112,26 +114,43 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002

class rebin:
__slots__ = (
"axis",
"edges",
"factor",
"groups",
)

def __init__(
self,
factor: int | None = None,
factor_or_axis: int | PlottableAxis | None = None,
henryiii marked this conversation as resolved.
Show resolved Hide resolved
/,
*,
factor: int | None = None,
groups: Sequence[int] | None = None,
edges: Sequence[int | float] | None = None,
axis: PlottableAxis | None = None,
) -> None:
if not sum(i is None for i in [factor, groups]) == 1:
raise ValueError("Exactly one, a factor or groups should be provided")
self.factor = factor
if (
sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis])
!= 1
):
raise ValueError("Exactly one argument should be provided")
self.groups = groups
self.edges = edges
self.axis = axis
self.factor = factor
if isinstance(factor_or_axis, int):
self.factor = factor_or_axis
elif factor_or_axis is not None:
self.axis = factor_or_axis
henryiii marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self) -> str:
repr_str = f"{self.__class__.__name__}"
args: dict[str, int | Sequence[int] | None] = {
args: dict[str, int | Sequence[int | float] | PlottableAxis | None] = {
"factor": self.factor,
"groups": self.groups,
"edges": self.edges,
"axis": self.axis,
}
for k, v in args.items():
if v is not None:
Expand All @@ -147,4 +166,30 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
return self.groups
if self.factor is not None:
return [self.factor] * len(axis)
if self.edges is not None or self.axis is not None:
newedges = None
if self.axis is not None and hasattr(self.axis, "edges"):
newedges = self.axis.edges
elif self.edges is not None:
newedges = self.edges

if newedges is not None and hasattr(axis, "edges"):
assert newedges[0] == axis.edges[0], "Edges must start at first bin"
assert newedges[-1] == axis.edges[-1], "Edges must end at last bin"
assert all(
np.isclose(
axis.edges[np.abs(axis.edges - edge).argmin()],
edge,
)
for edge in newedges
), "Edges must be in the axis"
matched_ixes = np.where(
np.isin(
axis.edges,
newedges,
)
)[0]
return [
int(ix - matched_ixes[i]) for i, ix in enumerate(matched_ixes[1:])
]
raise ValueError("No rebinning factor or groups provided")
27 changes: 27 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,33 @@ def test_rebin_1d(metadata):
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
assert h.axes[0].metadata is hs.axes[0].metadata

hs = h[bh.rebin(edges=[1.0, 1.2, 1.6, 2.2, 5.0])]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])

hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0]))]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])


def test_rebin_1d_flow():
    """Rebin a 1D histogram by edges under three flow-bin configurations.

    Each case fills a five-bin regular axis on [0, 5) with the same six
    samples, one of which (-1) lies below the range and one (5.5) above it,
    then merges the bins down to the edges ``[0, 3, 5]``.
    """

    def rebinned_by_edges(underflow, overflow):
        # Build, fill and rebin a fresh histogram for one flow configuration.
        source = bh.Histogram(
            bh.axis.Regular(5, 0, 5, underflow=underflow, overflow=overflow)
        )
        source.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
        return source[bh.rebin(edges=[0, 3, 5.0])]

    # Underflow and overflow both enabled: the out-of-range samples are
    # expected at either end of the flow view.
    both = rebinned_by_edges(True, True)
    assert_array_equal(both.view(), [2, 2])
    assert_array_equal(both.view(flow=True), [1, 2, 2, 1])
    assert_array_equal(both.axes.edges[0], [0.0, 3.0, 5.0])

    # No flow bins on the source axis: the flow view is expected to hold
    # zeros at both ends.
    neither = rebinned_by_edges(False, False)
    assert_array_equal(neither.view(flow=True), [0, 2, 2, 0])

    # Underflow only: the low sample is kept, the high end is expected to be 0.
    under_only = rebinned_by_edges(True, False)
    assert_array_equal(under_only.view(flow=True), [1, 2, 2, 0])


def test_shrink_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 0, 4))
Expand Down
Loading