Skip to content

Commit

Permalink
feat: use pyarrow for string functions (#2616)
Browse files Browse the repository at this point in the history
* First function is working: is_alnum.

* is_alpha

* is_decimal

* is_lower

* is_digit

* is_numeric

* is_printable

* is_space

* is_upper

* is_title

* is_ascii; done with string predicates

* capitalize

* lower

* upper

* upper

* title

* T -> T operations on bytestrings should return bytestrings.

* repeat (the first that needs a broadcastable argument)

* reverse (because it's easy)

* replace_slice

* replace_substring

* Also test 'max_replacements' in replace_substring.

* replace_substring_regex: done with string transforms

* center

* lpad and rpad

* trim

* trim_whitespace

* ltrim

* rtrim

* rtrim_whitespace

* ltrim_whitespace

* slice

* feat: add `split_whitespace`

* test: add test for `split_whitespace`

* test: correct test

* feat: add `split_pattern`

* refactor: rename `_get_action`

* feat: add `ak_split_pattern_regex`

* test: update tests for new features

* Fixed UnmaskedArray._drop_none.

* fix: adjust for numexpr 2.8.5, which hid getContext's frame_depth argument (#2617)

* extract_regex.

* join (almost entirely from https://gist.github.com/agoose77/28e5bb0250678e454356a85861a16368)

* use dispatch correctly

* fix: drop unused arg

* join_element_wise

* Revert "use dispatch correctly"

This reverts commit 559073b.

* fix: broadcast `num_repeats`

* feat: add `count_substring[_pattern]`

* docs: fixup docstring

* feat: add `ends_with`

* feat: add `starts_with`

* docs: fix link

* feat: add `find_substring`

* docs: fix typo

* feat: add `find_substring_regex`

* docs: fix link

* feat: add `match_like`

* test: improve test

* feat: add `match_substring`, `match_substring_regex`

* feat: add `is_in` and `index_in`

* fix: operate at leaf depth

* refactor: add internal `pyarrow.compute` helper

* refactor: use pyarrow import helper

* refactor: add `module` and `name` arguments to `high_level_function`

* fix: pass `module` to str `high_level_function`

* docs: homogenize docstrings

* docs: add see also

* docs: include `ak.str` in toctree

* chore: update pre-commit hooks (#2619)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* refactor: cleanup error handling

* Rename ak_*.py modules -> akstr_*.py.

---------

Co-authored-by: Angus Hollands <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 8, 2023
1 parent 85ca6d5 commit 1cfea2f
Show file tree
Hide file tree
Showing 57 changed files with 4,601 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/prepare_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def dofunction(link, linelink, shortname, name, astfcn):
.replace(".behaviors.string", "")
)
shortname = re.sub(r"\.operations\.ak_\w+", "", shortname)
shortname = re.sub(r"\.operations\.str\.akstr_\w+", ".str", shortname)
shortname = re.sub(r"\.(contents|types|forms)\.\w+", r".\1", shortname)

if (
Expand Down
73 changes: 73 additions & 0 deletions docs/reference/toctree.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,79 @@
generated/ak.argcartesian
generated/ak.argcombinations

.. toctree::
:caption: String predicates

generated/ak.str.is_alnum
generated/ak.str.is_alpha
generated/ak.str.is_ascii
generated/ak.str.is_decimal
generated/ak.str.is_digit
generated/ak.str.is_lower
generated/ak.str.is_numeric
generated/ak.str.is_printable
generated/ak.str.is_space
generated/ak.str.is_title
generated/ak.str.is_upper

.. toctree::
:caption: String transforms

generated/ak.str.capitalize
generated/ak.str.length
generated/ak.str.lower
generated/ak.str.repeat
generated/ak.str.replace_slice
generated/ak.str.replace_substring
generated/ak.str.replace_substring_regex
generated/ak.str.reverse
generated/ak.str.swapcase
generated/ak.str.title
generated/ak.str.upper

.. toctree::
:caption: String padding and trimming

generated/ak.str.center
generated/ak.str.lpad
generated/ak.str.rpad
generated/ak.str.ltrim
generated/ak.str.ltrim_whitespace
generated/ak.str.rtrim
generated/ak.str.rtrim_whitespace
generated/ak.str.trim
generated/ak.str.trim_whitespace

.. toctree::
:caption: String splitting and joining

generated/ak.str.split_pattern
generated/ak.str.split_pattern_regex
generated/ak.str.split_whitespace
generated/ak.str.join
generated/ak.str.join_element_wise

.. toctree::
:caption: String slicing and decomposition

generated/ak.str.slice
generated/ak.str.extract_regex

.. toctree::
:caption: String containment tests

generated/ak.str.count_substring
generated/ak.str.count_substring_regex
generated/ak.str.ends_with
generated/ak.str.find_substring
generated/ak.str.find_substring_regex
generated/ak.str.index_in
generated/ak.str.is_in
generated/ak.str.match_like
generated/ak.str.match_substring
generated/ak.str.match_substring_regex
generated/ak.str.starts_with

.. toctree::
:caption: Value and type conversions

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ mccabe.max-complexity = 100
"src/awkward/_connect/*" = ["TID251"]
"src/awkward/__init__.py" = ["E402", "F401", "F403", "I001"]
"src/awkward/_ext.py" = ["F401"]
"src/awkward/operations/__init__.py" = ["F403"]
"src/awkward/operations/__init__.py" = ["F401", "F403"]
"src/awkward/operations/str/__init__.py" = ["F401", "F403", "I001"]
"src/awkward/_nplikes/*" = ["TID251"]
"src/awkward/_operators.py" = ["TID251"]
"tests*/*" = ["T20", "TID251"]
Expand Down
17 changes: 14 additions & 3 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

import json
from collections.abc import Iterable, Sized
from types import ModuleType

from packaging.version import parse as parse_version

Expand Down Expand Up @@ -36,13 +38,13 @@
error_message = "pyarrow 7.0.0 or later required for {0}"


def import_pyarrow(name):
def import_pyarrow(name: str) -> ModuleType:
    """Return the module-level ``pyarrow`` object, or fail with a clear error.

    ``name`` is substituted into ``error_message`` so the raised
    ``ImportError`` says which feature needed pyarrow.
    """
    # The module-level ``pyarrow`` is None when it is unusable
    # (presumably: not installed or too old -- set outside this block).
    if pyarrow is not None:
        return pyarrow
    raise ImportError(error_message.format(name))


def import_pyarrow_parquet(name):
def import_pyarrow_parquet(name: str) -> ModuleType:
if pyarrow is None:
raise ImportError(error_message.format(name))

Expand All @@ -51,7 +53,16 @@ def import_pyarrow_parquet(name):
return out


def import_fsspec(name):
def import_pyarrow_compute(name: str) -> ModuleType:
    """Return the ``pyarrow.compute`` submodule, or fail with a clear error.

    ``name`` is substituted into ``error_message`` so the raised
    ``ImportError`` says which feature needed pyarrow.
    """
    # The module-level ``pyarrow`` is None when it is unusable
    # (presumably: not installed or too old -- set outside this block).
    if pyarrow is not None:
        # Imported lazily so that the submodule is only loaded on demand.
        import pyarrow.compute as out

        return out

    raise ImportError(error_message.format(name))


def import_fsspec(name: str) -> ModuleType:
try:
import fsspec

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def _remove_structure(self, backend, options):
return [self]

def _drop_none(self) -> Content:
return self.to_ByteMaskedArray(True)._drop_none()
return self.content

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
# ruff: noqa: F401

import awkward.operations.str
from awkward.operations.ak_all import *
from awkward.operations.ak_almost_equal import *
from awkward.operations.ak_any import *
Expand Down
205 changes: 205 additions & 0 deletions src/awkward/operations/str/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

# https://arrow.apache.org/docs/python/api/compute.html#string-predicates

# string predicates
from awkward.operations.str.akstr_is_alnum import *
from awkward.operations.str.akstr_is_alpha import *
from awkward.operations.str.akstr_is_decimal import *
from awkward.operations.str.akstr_is_digit import *
from awkward.operations.str.akstr_is_lower import *
from awkward.operations.str.akstr_is_numeric import *
from awkward.operations.str.akstr_is_printable import *
from awkward.operations.str.akstr_is_space import *
from awkward.operations.str.akstr_is_upper import *
from awkward.operations.str.akstr_is_title import *
from awkward.operations.str.akstr_is_ascii import *

# string transforms
from awkward.operations.str.akstr_capitalize import *
from awkward.operations.str.akstr_length import *
from awkward.operations.str.akstr_lower import *
from awkward.operations.str.akstr_swapcase import *
from awkward.operations.str.akstr_title import *
from awkward.operations.str.akstr_upper import *
from awkward.operations.str.akstr_repeat import *
from awkward.operations.str.akstr_replace_slice import *
from awkward.operations.str.akstr_reverse import *
from awkward.operations.str.akstr_replace_substring import *
from awkward.operations.str.akstr_replace_substring_regex import *

# string padding
from awkward.operations.str.akstr_center import *
from awkward.operations.str.akstr_lpad import *
from awkward.operations.str.akstr_rpad import *

# string trimming
from awkward.operations.str.akstr_ltrim import *
from awkward.operations.str.akstr_ltrim_whitespace import *
from awkward.operations.str.akstr_rtrim import *
from awkward.operations.str.akstr_rtrim_whitespace import *
from awkward.operations.str.akstr_trim import *
from awkward.operations.str.akstr_trim_whitespace import *

# string splitting
from awkward.operations.str.akstr_split_whitespace import *
from awkward.operations.str.akstr_split_pattern import *
from awkward.operations.str.akstr_split_pattern_regex import *

# string component extraction

from awkward.operations.str.akstr_extract_regex import *

# string joining

from awkward.operations.str.akstr_join import *
from awkward.operations.str.akstr_join_element_wise import *

# string slicing

from awkward.operations.str.akstr_slice import *

# containment tests

from awkward.operations.str.akstr_count_substring import *
from awkward.operations.str.akstr_count_substring_regex import *
from awkward.operations.str.akstr_ends_with import *
from awkward.operations.str.akstr_find_substring import *
from awkward.operations.str.akstr_find_substring_regex import *
from awkward.operations.str.akstr_index_in import *
from awkward.operations.str.akstr_is_in import *
from awkward.operations.str.akstr_match_like import *
from awkward.operations.str.akstr_match_substring import *
from awkward.operations.str.akstr_match_substring_regex import *
from awkward.operations.str.akstr_starts_with import *


def _get_ufunc_action(
    utf8_function,
    ascii_function,
    *args,
    bytestring_to_string=False,
    **kwargs,
):
    """Build an ``action(layout, **absorb)`` callback that applies an Arrow
    function to string and bytestring layouts.

    Args:
        utf8_function: callable applied to layouts whose ``__array__``
            parameter is ``"string"``.
        ascii_function: callable applied to layouts whose ``__array__``
            parameter is ``"bytestring"``.
        *args, **kwargs: forwarded unchanged to whichever function is called,
            after the Arrow array argument.
        bytestring_to_string: if true, a bytestring layout is relabelled as a
            string before calling ``ascii_function``, and a string result is
            relabelled back to a bytestring afterwards.

    Returns:
        The ``action`` callback. It returns a new layout for string and
        bytestring lists and ``None`` for every other layout (presumably the
        signal for the caller to keep descending -- confirm against the
        caller).
    """
    from awkward.operations.ak_from_arrow import from_arrow
    from awkward.operations.ak_to_arrow import to_arrow

    def through_arrow(function, strings):
        # Awkward layout -> Arrow array -> function -> Awkward layout.
        as_arrow = to_arrow(strings, extensionarray=False)
        return from_arrow(function(as_arrow, *args, **kwargs), highlevel=False)

    def action(layout, **absorb):
        if not layout.is_list:
            return None

        kind = layout.parameter("__array__")
        if kind == "string":
            return through_arrow(utf8_function, layout)
        if kind != "bytestring":
            return None

        if not bytestring_to_string:
            # Hand the bytestring to the function as-is.
            return through_arrow(ascii_function, layout)

        # Relabel bytestring -> string so the function sees string data.
        relabelled = layout.copy(
            content=layout.content.copy(parameters={"__array__": "char"}),
            parameters={"__array__": "string"},
        )
        result = through_arrow(ascii_function, relabelled)

        # Only string-typed results are relabelled back; anything else
        # (e.g. a non-list result) is returned untouched.
        if result.is_list and result.parameter("__array__") == "string":
            result = result.copy(
                content=result.content.copy(parameters={"__array__": "byte"}),
                parameters={"__array__": "bytestring"},
            )
        return result

    return action


def _erase_list_option(layout):
    """Strip an option-type wrapper from the content of a list layout.

    ``layout`` must be a list. If its content is option-type, that content
    must be an ``UnmaskedArray``; a copy of ``layout`` is returned with the
    wrapper replaced by the wrapped content. Otherwise ``layout`` itself is
    returned unchanged.
    """
    from awkward.contents.unmaskedarray import UnmaskedArray

    assert layout.is_list

    inner = layout.content
    if not inner.is_option:
        return layout

    # Only the UnmaskedArray kind of option is expected here.
    assert isinstance(inner, UnmaskedArray)
    return layout.copy(content=inner.content)


def _get_split_action(
    utf8_function, ascii_function, *args, bytestring_to_string=False, **kwargs
):
    """Build an ``action(layout, **absorb)`` callback for functions whose
    result is a list of strings per input string.

    Args:
        utf8_function: callable applied to layouts whose ``__array__``
            parameter is ``"string"``.
        ascii_function: callable applied to layouts whose ``__array__``
            parameter is ``"bytestring"``.
        *args, **kwargs: forwarded unchanged to whichever function is called,
            after the Arrow array argument.
        bytestring_to_string: if true, a bytestring layout is relabelled as a
            string before calling ``ascii_function``, and the strings inside
            the resulting list are relabelled back to bytestrings.

    Returns:
        The ``action`` callback. It returns a new layout for string and
        bytestring lists and ``None`` for every other layout (presumably the
        signal for the caller to keep descending -- confirm against the
        caller). Every result is passed through ``_erase_list_option``.
    """
    from awkward.operations.ak_from_arrow import from_arrow
    from awkward.operations.ak_to_arrow import to_arrow

    def split_through_arrow(function, strings):
        # Awkward layout -> Arrow array -> function -> Awkward layout, then
        # drop the option wrapper around the list content, if there is one.
        as_arrow = to_arrow(strings, extensionarray=False)
        pieces = from_arrow(function(as_arrow, *args, **kwargs), highlevel=False)
        return _erase_list_option(pieces)

    def action(layout, **absorb):
        if not layout.is_list:
            return None

        kind = layout.parameter("__array__")
        if kind == "string":
            return split_through_arrow(utf8_function, layout)
        if kind != "bytestring":
            return None

        if not bytestring_to_string:
            # Hand the bytestring to the function as-is.
            return split_through_arrow(ascii_function, layout)

        # Relabel bytestring -> string so the function sees string data.
        relabelled = layout.copy(
            content=layout.content.copy(parameters={"__array__": "char"}),
            parameters={"__array__": "string"},
        )
        out = split_through_arrow(ascii_function, relabelled)

        # The result must be a list whose items are strings.
        assert out.is_list

        assert (
            out.content.is_list
            and out.content.parameter("__array__") == "string"
        )

        # Relabel the inner strings back to bytestrings.
        pieces = out.content
        bytestring_pieces = pieces.copy(
            content=pieces.content.copy(parameters={"__array__": "byte"}),
            parameters={"__array__": "bytestring"},
        )
        return out.copy(content=bytestring_pieces)

    return action
Loading

0 comments on commit 1cfea2f

Please sign in to comment.