Skip to content

Commit

Permalink
WIP: fix axis passthrough
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii committed Jan 31, 2025
1 parent 789d325 commit 08e5b1c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 8 deletions.
24 changes: 23 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,31 @@

## Version 1.5

### Version 1.5.2

Fix for axis metadata not passing through non-uniform rebinning correctly. Flow
bins are now preserved when doing a non-uniform rebinning. Also adds the
ability to rebin by edges or an existing axis.

#### Features

- Support `edges=` and `axis=` in `bh.rebin` [#977][]

#### Bug fixes

- Axis metadata was broken when rebinning [#978][]
- Flow bins were lost when using variable rebinning [#977][]


[#977]: https://github.com/scikit-hep/boost-histogram/pull/977
[#978]: https://github.com/scikit-hep/boost-histogram/pull/978


### Version 1.5.1

Make non-uniform rebinning work for Weight() and friends [#972][]
#### Bug fixes

- Make non-uniform rebinning work for Weight() and friends [#972][]

[#972]: https://github.com/scikit-hep/boost-histogram/pull/972

Expand Down
23 changes: 17 additions & 6 deletions src/boost_histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,11 +908,18 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
start, stop = self.axes[i]._process_loc(ind.start, ind.stop)

groups = []
new_axis = None
if ind != slice(None):
merge = 1
if ind.step is not None:
if getattr(ind.step, "factor", None) is not None:
merge = ind.step.factor
elif (
hasattr(ind.step, "axis_mapping")
and (tmp_both := ind.step.axis_mapping(self.axes[i]))
is not None
):
groups, new_axis = tmp_both
elif (
hasattr(ind.step, "group_mapping")
and (tmp_groups := ind.step.group_mapping(self.axes[i]))
Expand Down Expand Up @@ -958,22 +965,26 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
new_axes_indices += [axes[i].edges[j + group]]
j += group

variable_axis = Variable(
new_axes_indices, __dict__=axes[i].metadata
)
axes[i] = variable_axis._ax
if new_axis is None:
new_axis = Variable(
new_axes_indices,
__dict__=axes[i].metadata,
underflow=axes[i].traits_underflow,
overflow=axes[i].traits_overflow,
)
axes[i] = new_axis._ax

logger.debug("Axes: %s", axes)

new_reduced = reduced.__class__(axes)
new_view = new_reduced.view(flow=True)
j = 0
new_j_base = 0
if self.axes[i].traits.underflow:
if axes[i].traits_underflow:
groups.insert(0, 1)
else:
new_j_base = 1
if self.axes[i].traits.overflow:
if axes[i].traits_overflow:
groups.append(1)
for new_j, group in enumerate(groups):
for _ in range(group):
Expand Down
9 changes: 9 additions & 0 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ def __repr__(self) -> str:
break
return return_str

# NOTE: the axis handed back is exactly the object stored on `self.axis`, with
# its type left untouched. That is safe while staying inside one UHI library,
# but not when mixing libraries. When no axis was provided the second element
# is None; in that case the caller is expected to build the axis itself.
def axis_mapping(
    self, axis: PlottableAxis
) -> tuple[Sequence[int], PlottableAxis | None]:
    """Return the bin grouping for ``axis`` plus the stored replacement axis.

    Args:
        axis: The axis being rebinned; forwarded to ``group_mapping``.

    Returns:
        A ``(groups, new_axis)`` pair, where ``groups`` is the result of
        ``self.group_mapping(axis)`` and ``new_axis`` is ``self.axis``
        (``None`` if no axis was provided).
    """
    # Compute the grouping first so any error it raises surfaces before the
    # stored axis is read.
    groups = self.group_mapping(axis)
    replacement = self.axis
    return groups, replacement

def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
if self.groups is not None:
if sum(self.groups) != len(axis):
Expand Down
11 changes: 10 additions & 1 deletion tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,12 @@ def test_rebin_1d(metadata):
hs = h[bh.rebin(edges=[1.0, 1.2, 1.6, 2.2, 5.0])]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
assert h.axes[0].metadata is hs.axes[0].metadata

hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0]))]
hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0], metadata="hi"))]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
assert hs.axes[0].metadata == "hi"


def test_rebin_1d_flow():
Expand All @@ -677,6 +679,13 @@ def test_rebin_1d_flow():
hs = h[bh.rebin(edges=[0, 3, 5.0])]
assert_array_equal(hs.view(flow=True), [1, 2, 2, 0])

h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=True))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[
bh.rebin(axes=bh.axis.Variable([0, 3, 5.0], underflow=False, overflow=False))
]
assert_array_equal(hs.view(flow=True), [2, 2, 0])


def test_shrink_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 0, 4))
Expand Down

0 comments on commit 08e5b1c

Please sign in to comment.