Skip to content

Commit

Permalink
fix: ak.typetracer.length_one_if_typetracer with option and union typ…
Browse files Browse the repository at this point in the history
…es (#3266)

* fix: ak.typetracer.length_one_if_typetracer with option and union types

* forgot to add the test

* style: pre-commit fixes

* no, not the Emacs backup file

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jpivarski and pre-commit-ci[bot] authored Oct 5, 2024
1 parent d7264c1 commit 88ba3e5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
64 changes: 55 additions & 9 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,52 @@ def max_prefer_unknown(this: ShapeItem, that: ShapeItem) -> ShapeItem:

container = {}

def prepare_empty(form):
form_key = f"node-{len(container)}"

if isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)):
container[form_key] = b""
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.IndexedOptionForm):
container[form_key] = b""
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.EmptyForm):
return form

elif isinstance(form, ak.forms.UnmaskedForm):
return form.copy(content=prepare_empty(form.content))

elif isinstance(form, (ak.forms.IndexedForm, ak.forms.ListForm)):
container[form_key] = b""
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.ListOffsetForm):
container[form_key] = b""
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.RegularForm):
return form.copy(content=prepare_empty(form.content))

elif isinstance(form, ak.forms.NumpyForm):
container[form_key] = b""
return form.copy(form_key=form_key)

elif isinstance(form, ak.forms.RecordForm):
return form.copy(contents=[prepare_empty(x) for x in form.contents])

elif isinstance(form, ak.forms.UnionForm):
# both tags and index will get this buffer
container[form_key] = b""
return form.copy(
contents=[prepare_empty(x) for x in form.contents],
form_key=form_key,
)

else:
raise AssertionError(f"not a Form: {form!r}")

def prepare(form, multiplier):
form_key = f"node-{len(container)}"

Expand All @@ -566,11 +612,13 @@ def prepare(form, multiplier):
container[form_key] = b"\x00" * multiplier
else:
container[form_key] = b"\xff" * multiplier
return form.copy(form_key=form_key) # DO NOT RECURSE
# switch from recursing down `prepare` to `prepare_empty`
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.IndexedOptionForm):
container[form_key] = b"\xff\xff\xff\xff\xff\xff\xff\xff" # -1
return form.copy(form_key=form_key) # DO NOT RECURSE
# switch from recursing down `prepare` to `prepare_empty`
return form.copy(content=prepare_empty(form.content), form_key=form_key)

elif isinstance(form, ak.forms.EmptyForm):
# no error if protected by non-recursing node type
Expand Down Expand Up @@ -624,13 +672,11 @@ def prepare(form, multiplier):
elif isinstance(form, ak.forms.UnionForm):
# both tags and index will get this buffer, but index is 8 bytes
container[form_key] = b"\x00" * (8 * multiplier)
return form.copy(
# only recurse down contents[0] because all index == 0
contents=(
[prepare(form.contents[0], multiplier)] + form.contents[1:]
),
form_key=form_key,
)
# recurse down contents[0] with `prepare`, but others with `prepare_empty`
contents = [prepare(form.contents[0], multiplier)]
for x in form.contents[1:]:
contents.append(prepare_empty(x))
return form.copy(contents=contents, form_key=form_key)

else:
raise AssertionError(f"not a Form: {form!r}")
Expand Down
14 changes: 14 additions & 0 deletions tests/test_3264_length_one_if_typetracer_with_option_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
# ruff: noqa: E402

from __future__ import annotations

import awkward as ak


def test():
arr = ak.Array([[1], [2, 3], [1, 2, 4, 5]])[[0, None, 2]]
l1 = ak.typetracer.length_one_if_typetracer(ak.to_backend(arr, "typetracer"))

assert l1.to_list() == [None]
assert str(l1.type) == "1 * option[var * int64]"

0 comments on commit 88ba3e5

Please sign in to comment.