Skip to content

Commit

Permalink
fix: allow axis + other
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii committed Feb 5, 2025
1 parent 39ca0e3 commit 2eb78ab
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,19 @@ def __init__(
edges: Sequence[int | float] | None = None,
axis: PlottableAxis | None = None,
) -> None:
if (
sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis])
!= 1
):
if isinstance(factor_or_axis, int):
factor = factor_or_axis
elif factor_or_axis is not None:
axis = factor_or_axis

total_args = sum(i is not None for i in [factor, groups, edges])
if total_args !=1 and axis is None:
raise ValueError("Exactly one argument should be provided")

self.groups = groups
self.edges = edges
self.axis = axis
self.factor = factor
if isinstance(factor_or_axis, int):
self.factor = factor_or_axis
elif factor_or_axis is not None:
self.axis = factor_or_axis

def __repr__(self) -> str:
repr_str = f"{self.__class__.__name__}"
Expand Down Expand Up @@ -177,10 +177,10 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
return [self.factor] * len(axis)
if self.edges is not None or self.axis is not None:
newedges = None
if self.axis is not None and hasattr(self.axis, "edges"):
newedges = self.axis.edges
elif self.edges is not None:
if self.edges is not None:
newedges = self.edges
elif self.axis is not None and hasattr(self.axis, "edges"):
newedges = self.axis.edges

if newedges is not None and hasattr(axis, "edges"):
assert newedges[0] == axis.edges[0], "Edges must start at first bin"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,24 @@ def test_rebin_1d_flow():
assert_array_equal(hs.view(flow=True), [2, 2])


def test_rebin_change_axis_int():
h = bh.Histogram(bh.axis.Regular(5, 0, 5))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[bh.rebin(edges=[0, 3, 5.0], axis=bh.axis.Integer(10,12))]
assert_array_equal(hs.view(), [2, 2])
assert_array_equal(hs.view(flow=True), [1, 2, 2, 1])
assert_array_equal(hs.axes.edges[0], [10, 11, 12])


def test_rebin_change_axis_cat():
h = bh.Histogram(bh.axis.Regular(5, 0, 5))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[bh.rebin(groups=[2,2,1], axis=bh.axis.StrCategory(["a", "b"]))]
assert_array_equal(hs.view(), [2, 2])
assert_array_equal(hs.view(flow=True), [1, 2, 2, 1])
assert_array_equal(hs.axes.edges[0], [10, 11, 12])


def test_shrink_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 0, 4))
h.fill(1.1)
Expand Down

0 comments on commit 2eb78ab

Please sign in to comment.