From 881b470c7b19e4f2fa9b51b3b0a3c4a90b6fd826 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Fri, 8 Mar 2024 16:50:23 +0100 Subject: [PATCH] make it work for nd hists --- src/boost_histogram/_internal/hist.py | 11 +++--- tests/test_histogram.py | 52 ++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/boost_histogram/_internal/hist.py b/src/boost_histogram/_internal/hist.py index c3c2d917..16c3079f 100644 --- a/src/boost_histogram/_internal/hist.py +++ b/src/boost_histogram/_internal/hist.py @@ -820,6 +820,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: slices: list[_core.algorithm.reduce_command] = [] pick_each: dict[int, int] = {} pick_set: dict[int, list[int]] = {} + reduced: CppHistogram | None = None # Compute needed slices and projections for i, ind in enumerate(indexes): @@ -884,7 +885,8 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: ) # rebinning with groups elif len(groups) != 0: - reduced = self._hist + if not reduced: + reduced = self._hist axes = [reduced.axis(x) for x in range(reduced.rank())] reduced_view = reduced.view(flow=True) new_axes_indices = [axes[i].edges[0]] @@ -907,7 +909,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: j = 1 for new_j, group in enumerate(groups): for _ in range(group): - pos = [slice] * (i) + pos = [slice(None)] * (i) new_view[(*pos, new_j + 1, ...)] += reduced_view[ # type: ignore[arg-type] (*pos, j, ...) # type: ignore[arg-type] ] @@ -916,10 +918,9 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: reduced = new_reduced # Will be updated below - if slices or pick_set or pick_each or integrations: + if (slices or pick_set or pick_each or integrations) and not reduced: reduced = self._hist - elif len(groups) == 0: - logger.debug("Reduce actions are all empty, just making a copy") + elif not reduced: reduced = copy.copy(self._hist) if pick_each: diff --git a/tests/test_histogram.py b/tests/test_histogram.py index a5c4e4d3..e572baa1 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -636,7 +636,6 @@ def test_rebin_1d(): hs = h[{0: slice(None, None, bh.tag.Rebinner(4))}] assert_array_equal(hs.view(), [1, 1, 1, 0, 1]) - print("Here") hs = h[{0: bh.tag.Rebinner(4)}] assert_array_equal(hs.view(), [1, 1, 1, 0, 1]) @@ -664,8 +663,59 @@ def test_rebin_nd(): assert h[{1: s[:: bh.rebin(2)]}].axes.size == (20, 15, 40) assert h[{2: s[:: bh.rebin(2)]}].axes.size == (20, 30, 20) + assert h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (3, 30, 40) + assert h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 3, 40) + assert h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 30, 3) + assert np.all( + np.isclose( + h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[0].edges, [1.0, 1.1, 1.3, 1.6] + ) + ) + assert np.all( + np.isclose( + h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[1].edges, + [1.0, 1.06666667, 1.2, 1.4], + ) + ) + assert np.all( + np.isclose( + h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[2].edges, + [1.0, 1.05, 1.15, 1.3], + ) + ) + assert h[{0: s[:: bh.rebin(2)], 2: s[:: bh.rebin(2)]}].axes.size == (10, 30, 20) + assert h[ + {0: s[:: bh.rebin(groups=[1, 2, 3])], 2: s[:: bh.rebin(groups=[1, 2, 3])]} + ].axes.size == (3, 30, 3) + assert np.all( + np.isclose( + h[ + { + 0: s[:: bh.rebin(groups=[1, 2, 3])], + 2: s[:: bh.rebin(groups=[1, 2, 3])], + } + ] + .axes[0] + .edges, + [1.0, 1.1, 1.3, 1.6], + ) + ) + assert np.all( + np.isclose( + h[ + { + 0: s[:: bh.rebin(groups=[1, 2, 3])], + 2: s[:: bh.rebin(groups=[1, 2, 3])], + } + ] + .axes[2] + .edges, + [1.0, 1.05, 1.15, 1.3], + ) + ) + assert h[{1: s[:: bh.sum]}].axes.size == (20, 40) assert h[{1: bh.sum}].axes.size == (20, 40)