From 2eb78ab1e8813d036b0c2b92e245656c9998d831 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Wed, 5 Feb 2025 16:00:31 -0500 Subject: [PATCH] fix: allow axis + other Signed-off-by: Henry Schreiner --- src/boost_histogram/tag.py | 22 +++++++++++----------- tests/test_histogram.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/boost_histogram/tag.py b/src/boost_histogram/tag.py index ab7bc6c2..16b9939b 100644 --- a/src/boost_histogram/tag.py +++ b/src/boost_histogram/tag.py @@ -130,19 +130,19 @@ def __init__( edges: Sequence[int | float] | None = None, axis: PlottableAxis | None = None, ) -> None: - if ( - sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis]) - != 1 - ): + if isinstance(factor_or_axis, int): + factor = factor_or_axis + elif factor_or_axis is not None: + axis = factor_or_axis + + total_args = sum(i is not None for i in [factor, groups, edges]) + if total_args !=1 and axis is None: raise ValueError("Exactly one argument should be provided") + self.groups = groups self.edges = edges self.axis = axis self.factor = factor - if isinstance(factor_or_axis, int): - self.factor = factor_or_axis - elif factor_or_axis is not None: - self.axis = factor_or_axis def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}" @@ -177,10 +177,10 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]: return [self.factor] * len(axis) if self.edges is not None or self.axis is not None: newedges = None - if self.axis is not None and hasattr(self.axis, "edges"): - newedges = self.axis.edges - elif self.edges is not None: + if self.edges is not None: newedges = self.edges + elif self.axis is not None and hasattr(self.axis, "edges"): + newedges = self.axis.edges if newedges is not None and hasattr(axis, "edges"): assert newedges[0] == axis.edges[0], "Edges must start at first bin" diff --git a/tests/test_histogram.py b/tests/test_histogram.py index 946562a5..0847d5ec 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -688,6 +688,24 @@ def test_rebin_1d_flow(): assert_array_equal(hs.view(flow=True), [2, 2]) +def test_rebin_change_axis_int(): + h = bh.Histogram(bh.axis.Regular(5, 0, 5)) + h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5]) + hs = h[bh.rebin(edges=[0, 3, 5.0], axis=bh.axis.Integer(10,12))] + assert_array_equal(hs.view(), [2, 2]) + assert_array_equal(hs.view(flow=True), [1, 2, 2, 1]) + assert_array_equal(hs.axes.edges[0], [10, 11, 12]) + + +def test_rebin_change_axis_cat(): + h = bh.Histogram(bh.axis.Regular(5, 0, 5)) + h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5]) + hs = h[bh.rebin(groups=[2,2,1], axis=bh.axis.StrCategory(["a", "b"]))] + assert_array_equal(hs.view(), [2, 2]) + assert_array_equal(hs.view(flow=True), [1, 2, 2, 1]) + assert_array_equal(hs.axes.edges[0], [10, 11, 12]) + + def test_shrink_rebin_1d(): h = bh.Histogram(bh.axis.Regular(20, 0, 4)) h.fill(1.1)