Skip to content

Commit

Permalink
Add tests for contains (scalar and column_view)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Jun 13, 2024
1 parent afb4061 commit 0a98c7a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 26 deletions.
4 changes: 2 additions & 2 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def extract_element_column(Column col, Column index):


@acquire_spill_lock()
def contains_scalar(Column col, object py_search_key):
def contains_scalar(Column col, py_search_key):
return Column.from_pylibcudf(
pylibcudf.lists.contains(
col.to_pylibcudf(mode="read"),
py_search_key,
py_search_key.device_value,
)
)

Expand Down
6 changes: 2 additions & 4 deletions python/cudf/cudf/_lib/pylibcudf/lists.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from libcpp cimport bool

from cudf._lib.pylibcudf.libcudf.types cimport size_type
from cudf._lib.scalar cimport DeviceScalar

from .column cimport Column
from .scalar cimport Scalar
Expand All @@ -11,6 +12,7 @@ from .table cimport Table
ctypedef fused ColumnOrScalar:
Column
Scalar
DeviceScalar

cpdef Table explode_outer(Table, size_type explode_column_idx)

Expand All @@ -23,7 +25,3 @@ cpdef Column contains(Column, ColumnOrScalar)
# cpdef Column contains_nulls(Column)

# ctypedef Column index_of(Column, ColumnOrScalar)

# from cudf._lib.pylibcudf.libcudf.binaryop import \
# binary_operator as BinaryOperator # no-cython-lint
# from cudf._lib.pylibcudf.libcudf.lists.contains cimport duplicate_find_option
35 changes: 17 additions & 18 deletions python/cudf/cudf/_lib/pylibcudf/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport (
from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar
from cudf._lib.pylibcudf.libcudf.table.table cimport table
from cudf._lib.pylibcudf.libcudf.types cimport size_type

from cudf._lib.pylibcudf.libcudf.lists.contains import \
duplicate_find_option as DuplicateFindOption # no-cython-lint

from cudf._lib.scalar cimport DeviceScalar

from .column cimport Column
Expand Down Expand Up @@ -107,15 +103,10 @@ cpdef Column concatenate_list_elements(Column input, bool dropna):

return Column.from_libcudf(move(c_result))


cpdef Column contains(Column input, ColumnOrScalar search_key):
"""Create a column of bool values based upon the search key.
``search_key`` may be a
:py:class:`~cudf._lib.pylibcudf.column.Column` or a
:py:class:`~cudf._lib.pylibcudf.scalar.Scalar`.
For details, see :cpp:func:`contains`.
Parameters
----------
input : Column
Expand All @@ -132,18 +123,26 @@ cpdef Column contains(Column input, ColumnOrScalar search_key):
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](input.view())
)
cdef const scalar* search_key_value = NULL

if ColumnOrScalar is Column:
with nogil:
c_result = move(cpp_contains.contains(
list_view.get()[0],
search_key.view(),
))
return Column.from_libcudf(move(c_result))
cdef DeviceScalar key = search_key.device_value
cdef const scalar* key_value = key.get_raw_ptr()
with nogil:
c_result = move(cpp_contains.contains(
list_view.get()[0],
key_value[0],
))
elif ColumnOrScalar is DeviceScalar:
search_key_value = search_key.get_raw_ptr()
with nogil:
c_result = move(cpp_contains.contains(
list_view.get()[0],
search_key_value[0],
))
else:
search_key_value = search_key.get()
with nogil:
c_result = move(cpp_contains.contains(
list_view.get()[0],
search_key_value[0],
))
return Column.from_libcudf(move(c_result))
29 changes: 27 additions & 2 deletions python/cudf/cudf/pylibcudf_tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,30 @@ def test_concatenate_list_elements(test_data, dropna, expected):
assert_column_eq(expect, res)


def test_contains():
pass
def test_contains_scalar():
list_column = [[1, 2], [1, 3, 4], [5, 6]]
arr = pa.array(list_column)
scalar = pa.scalar(1)

plc_column = plc.interop.from_arrow(arr)
plc_scalar = plc.interop.from_arrow(scalar)
res = plc.lists.contains(plc_column, plc_scalar)

expect = pa.array([True, True, False])

assert_column_eq(expect, res)


def test_contains_list_column():
list_column1 = [[1, 2], [1, 3, 4], [5, 6]]
list_column2 = [1, 3, 6]
arr1 = pa.array(list_column1)
arr2 = pa.array(list_column2)

plc_column1 = plc.interop.from_arrow(arr1)
plc_column2 = plc.interop.from_arrow(arr2)
res = plc.lists.contains(plc_column1, plc_column2)

expect = pa.array([True, True, True])

assert_column_eq(expect, res)

0 comments on commit 0a98c7a

Please sign in to comment.