Skip to content

Commit

Permalink
fix: support callable, add validation
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii committed Aug 23, 2024
1 parent 95ddb0f commit 049cebc
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 34 deletions.
4 changes: 1 addition & 3 deletions src/boost_histogram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from . import accumulators, axis, numpy, storage
from ._internal.enum import Kind
from ._internal.hist import Histogram, IndexingExpr
from .tag import (
Rebinner as rebin,
)
from .tag import ( # pylint: disable=redefined-builtin
loc,
overflow,
rebin,
sum,
underflow,
)
Expand Down
4 changes: 2 additions & 2 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,11 +861,11 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
if ind.step is not None:
if getattr(ind.step, "factor", None) is not None:
merge = ind.step.factor
elif getattr(ind.step, "group_mapping", None) is not None:
groups = ind.step.group_mapping(self.axes[i])
elif callable(ind.step):
if ind.step is sum:
integrations.add(i)
elif getattr(ind.step, "groups", None) is not None:
groups = ind.step.groups
else:
raise NotImplementedError

Expand Down
17 changes: 9 additions & 8 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ._internal.typing import AxisLike

__all__ = ("Slicer", "Locator", "at", "loc", "overflow", "underflow", "Rebinner", "sum")
__all__ = ("Slicer", "Locator", "at", "loc", "overflow", "underflow", "sum", "rebin")


class Slicer:
Expand Down Expand Up @@ -110,7 +110,7 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002
return self.value


class Rebinner:
class rebin:
__slots__ = (
"factor",
"groups",
Expand Down Expand Up @@ -139,11 +139,12 @@ def __repr__(self) -> str:
break
return return_str

def __call__(self, axis: PlottableAxis) -> int | Sequence[int]:
if self.factor is not None:
return [self.factor] * (len(axis) // self.factor)

def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
if self.groups is not None:
if sum(self.groups) != len(axis):
msg = f"The sum of the groups ({sum(self.groups)}) must be equal to the number of bins in the axis ({len(axis)})"
raise ValueError(msg)
return self.groups

raise NotImplementedError(axis)
if self.factor is not None:
return [self.factor] * len(axis)
raise ValueError("No rebinning factor or groups provided")
41 changes: 21 additions & 20 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,15 +634,15 @@ def test_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 1, 5))
h.fill([1.1, 2.2, 3.3, 4.4])

hs = h[{0: slice(None, None, bh.tag.Rebinner(4))}]
hs = h[{0: slice(None, None, bh.rebin(4))}]
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])

hs = h[{0: bh.tag.Rebinner(4)}]
hs = h[{0: bh.rebin(4)}]
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])

hs = h[{0: bh.tag.Rebinner(groups=[1, 2, 3, 4])}]
assert_array_equal(hs.view(), [1, 0, 0, 1])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 3.0])
hs = h[{0: bh.rebin(groups=[1, 2, 3, 14])}]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])


def test_shrink_rebin_1d():
Expand All @@ -663,56 +663,57 @@ def test_rebin_nd():
assert h[{1: s[:: bh.rebin(2)]}].axes.size == (20, 15, 40)
assert h[{2: s[:: bh.rebin(2)]}].axes.size == (20, 30, 20)

assert h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (3, 30, 40)
assert h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 3, 40)
assert h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 30, 3)
assert h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes.size == (3, 30, 40)
assert h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes.size == (20, 3, 40)
assert h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes.size == (20, 30, 3)
assert np.all(
np.isclose(
h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[0].edges, [1.0, 1.1, 1.3, 1.6]
h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes[0].edges,
[1.0, 1.1, 1.3, 3.0],
)
)
assert np.all(
np.isclose(
h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[1].edges,
[1.0, 1.06666667, 1.2, 1.4],
h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes[1].edges,
[1.0, 1.06666667, 1.2, 3.0],
)
)
assert np.all(
np.isclose(
h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[2].edges,
[1.0, 1.05, 1.15, 1.3],
h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes[2].edges,
[1.0, 1.05, 1.15, 3.0],
)
)

assert h[{0: s[:: bh.rebin(2)], 2: s[:: bh.rebin(2)]}].axes.size == (10, 30, 20)

assert h[
{0: s[:: bh.rebin(groups=[1, 2, 3])], 2: s[:: bh.rebin(groups=[1, 2, 3])]}
{0: s[:: bh.rebin(groups=[1, 2, 17])], 2: s[:: bh.rebin(groups=[1, 2, 37])]}
].axes.size == (3, 30, 3)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 3])],
2: s[:: bh.rebin(groups=[1, 2, 3])],
0: s[:: bh.rebin(groups=[1, 2, 17])],
2: s[:: bh.rebin(groups=[1, 2, 37])],
}
]
.axes[0]
.edges,
[1.0, 1.1, 1.3, 1.6],
[1.0, 1.1, 1.3, 3],
)
)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 3])],
2: s[:: bh.rebin(groups=[1, 2, 3])],
0: s[:: bh.rebin(groups=[1, 2, 17])],
2: s[:: bh.rebin(groups=[1, 2, 37])],
}
]
.axes[2]
.edges,
[1.0, 1.05, 1.15, 1.3],
[1.0, 1.05, 1.15, 3.0],
)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_histogram_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_repr():
assert repr(bh.overflow + 1) == "overflow + 1"
assert repr(bh.overflow - 1) == "overflow - 1"

assert repr(bh.rebin(2)) == "Rebinner(factor=2)"
assert repr(bh.rebin(2)) == "rebin(factor=2)"


# Was broken in 0.6.1
Expand Down

0 comments on commit 049cebc

Please sign in to comment.