Skip to content

Commit 8c81345

Browse files
authored
fix: fill flattened pass-through strings (#629)
Close #609. Signed-off-by: Henry Schreiner <[email protected]>
1 parent 41559d6 commit 8c81345

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ dask = [
9292
test = [
9393
{ include-group = "test-core" },
9494
{ include-group = "fit" },
95+
"awkward>=2.0.7",
9596
]
9697
test-core = [
9798
"pytest >=8",

src/hist/interop.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,22 @@ def destructure(obj: Any) -> dict[str, Any] | None:
6767
raise TypeError(f"No histogram module found for {obj!r}")
6868

6969

70-
def broadcast_and_flatten(args: Sequence[Any]) -> tuple[np.typing.NDArray[Any], ...]:
70+
def broadcast_and_flatten(
71+
args: Sequence[Any],
72+
) -> tuple[str | np.typing.NDArray[Any], ...]:
7173
"""
7274
Convert the given histogram-module arrays into a set of consistent 1D NumPy arrays
7375
for histogram filling. For NumPy this entails broadcasting and flattening.
76+
77+
This skips passing strings to the backend, they are left inplace.
7478
"""
75-
for module in find_histogram_modules(*args):
76-
result = module.broadcast_and_flatten(args)
79+
80+
non_strings = [x for x in args if not isinstance(x, str)]
81+
for module in find_histogram_modules(*non_strings):
82+
result = module.broadcast_and_flatten(non_strings)
7783
if result is not NotImplemented:
78-
return result
84+
it = iter(result)
85+
return tuple(next(it) if not isinstance(x, str) else x for x in args)
7986

8087
raise TypeError(f"No histogram module found for {args!r}")
8188

tests/test_interop.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,17 @@ def test_named_keyword(unnamed_hist):
383383
axis.Regular(2, 0, 1, name="z"),
384384
).fill_flattened(x=x, y=y, z=z)
385385
assert np.allclose(h.values(), expected)
386+
387+
388+
def test_string_fill_flattened():
389+
ark = np.array([1, 2, 3, 4, 5])
390+
h = hist.new.Reg(10, 0, 10, name="x").StrCat([], growth=True, name="cat").Weight()
391+
h.fill_flattened(x=ark, cat="A")
392+
393+
394+
def test_ak_fill_flattened():
395+
ak = pytest.importorskip("awkward")
396+
397+
ark = ak.Array([[1, 2, 3], [4, 5, 6], [9]])
398+
h = hist.new.Reg(10, 0, 10, name="x").StrCat([], growth=True, name="cat").Weight()
399+
h.fill_flattened(x=ark, cat="A")

0 commit comments

Comments
 (0)