From 0e59f87161508b5154f874e14435cb7844a87fd2 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Thu, 1 Feb 2024 17:41:49 -0500 Subject: [PATCH] fix: partial fix Signed-off-by: Henry Schreiner --- setup.cfg | 1 - src/boost_histogram/__init__.py | 4 ++++ src/boost_histogram/_internal/hist.py | 9 ++++----- src/boost_histogram/tag.py | 28 +++++++++++---------------- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/setup.cfg b/setup.cfg index 70669f84..266c9b85 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,7 +57,6 @@ packages = boost_histogram._internal boost_histogram.axis install_requires = - uhi numpy>=1.26.0b1;python_version>='3.12' numpy;python_version<'3.12' typing-extensions;python_version<'3.8' diff --git a/src/boost_histogram/__init__.py b/src/boost_histogram/__init__.py index 8faa5c6c..8700f90b 100644 --- a/src/boost_histogram/__init__.py +++ b/src/boost_histogram/__init__.py @@ -3,6 +3,9 @@ from . import accumulators, axis, numpy, storage from ._internal.enum import Kind from ._internal.hist import Histogram, IndexingExpr +from .tag import ( + Rebinner as rebin, +) from .tag import ( # pylint: disable=redefined-builtin loc, overflow, @@ -34,6 +37,7 @@ "accumulators", "numpy", "loc", + "rebin", "sum", "underflow", "overflow", diff --git a/src/boost_histogram/_internal/hist.py b/src/boost_histogram/_internal/hist.py index 9db74ff0..8e6f1676 100644 --- a/src/boost_histogram/_internal/hist.py +++ b/src/boost_histogram/_internal/hist.py @@ -854,12 +854,12 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: if ind != slice(None): merge = 1 if ind.step is not None: - if ind.step.factor is not None: + if getattr(ind.step, "factor", None) is not None: merge = ind.step.factor elif callable(ind.step): if ind.step is sum: integrations.add(i) - elif ind.step.groups is not None: + elif getattr(ind.step, "groups", None) is not None: groups = ind.step.groups else: raise NotImplementedError @@ -904,9 +904,8 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator: new_axes_indices += [axes[i].edges[j + 1 : j + group + 1][-1]] j = group - variable_axis = Variable(new_axes_indices) - variable_axis.metadata = axes[i].metadata - axes[i] = variable_axis + variable_axis = Variable(new_axes_indices, metadata=axes[i].metadata) + axes[i] = variable_axis._axis reduced_view = np.take(reduced_view, range(len(reduced_view)), axis=i) logger.debug("Axes: %s", axes) diff --git a/src/boost_histogram/tag.py b/src/boost_histogram/tag.py index adcd3372..5cf82892 100644 --- a/src/boost_histogram/tag.py +++ b/src/boost_histogram/tag.py @@ -4,12 +4,12 @@ import copy from builtins import sum -from typing import Sequence, TypeVar +from typing import TYPE_CHECKING, Sequence, TypeVar -from uhi.typing.plottable import PlottableAxis +if TYPE_CHECKING: + from uhi.typing.plottable import PlottableAxis from ._internal.typing import AxisLike -from .axis import Regular, Variable __all__ = ("Slicer", "Locator", "at", "loc", "overflow", "underflow", "Rebinner", "sum") @@ -114,21 +114,17 @@ class Rebinner: __slots__ = ( "factor", "groups", - "category_map", ) def __init__( self, + factor: int | None = None, *, - value: int | None = None, groups: Sequence[int] | None = None, ) -> None: - if ( - sum(i is None for i in [value, groups]) == 2 - or sum(i is not None for i in [value, groups]) > 1 - ): + if not sum(i is None for i in [factor, groups]) == 1: raise ValueError("exactly one, a value or groups should be provided") - self.factor = value + self.factor = factor self.groups = groups def __repr__(self) -> str: @@ -144,12 +140,10 @@ def __repr__(self) -> str: return return_str def __call__(self, axis: PlottableAxis) -> int | Sequence[int]: - if isinstance(axis, Regular): - if self.factor is None: - raise ValueError("must provide a value") - return self.factor - if isinstance(axis, Variable): - if self.groups is None: - raise ValueError("must provide bin groups") + if self.factor is not None: + return [self.factor] * (len(axis) // self.factor) + + if self.groups is not None: return self.groups + raise NotImplementedError(axis)