Skip to content

Commit 68bb513

Browse files
committed
GH-45167: [C++] Implement Compute Equals for List Types
1 parent e434536 commit 68bb513

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

Diff for: cpp/src/arrow/compute/kernels/codegen_internal.h

+49
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value ||
141141
static T LogicalValue(PhysicalType value) { return value; }
142142
};
143143

144+
template <typename Type>
145+
struct GetViewType<Type, enable_if_list_type<Type>> {
146+
using T = typename TypeTraits<Type>::ScalarType;
147+
148+
static T LogicalValue(T value) { return value; }
149+
};
150+
144151
template <>
145152
struct GetViewType<Decimal32Type> {
146153
using T = Decimal32;
@@ -322,6 +329,26 @@ struct ArrayIterator<Type, enable_if_base_binary<Type>> {
322329
}
323330
};
324331

332+
template <typename Type>
333+
struct ArrayIterator<Type, enable_if_list_type<Type>> {
334+
using T = typename TypeTraits<Type>::ScalarType;
335+
using ArrayT = typename TypeTraits<Type>::ArrayType;
336+
using offset_type = typename Type::offset_type;
337+
338+
const ArraySpan& arr;
339+
int64_t position;
340+
341+
explicit ArrayIterator(const ArraySpan& arr) : arr(arr), position(0) {}
342+
343+
T operator()() {
344+
const auto array_ptr = arr.ToArray();
345+
const auto array = checked_cast<const ArrayT*>(array_ptr.get());
346+
347+
T result{array->value_slice(position++)};
348+
return result;
349+
}
350+
};
351+
325352
template <>
326353
struct ArrayIterator<FixedSizeBinaryType> {
327354
const ArraySpan& arr;
@@ -390,6 +417,12 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
390417
}
391418
};
392419

420+
template <typename Type>
421+
struct UnboxScalar<Type, enable_if_list_type<Type>> {
422+
using T = typename TypeTraits<Type>::ScalarType;
423+
static const T& Unbox(const Scalar& val) { return checked_cast<const T&>(val); }
424+
};
425+
393426
template <>
394427
struct UnboxScalar<Decimal32Type> {
395428
using T = Decimal32;
@@ -1383,6 +1416,22 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
13831416
}
13841417
}
13851418

1419+
// Generate a kernel given a templated functor for list types
1420+
//
1421+
// See "Numeric" above for description of the generator functor
1422+
template <template <typename...> class Generator, typename Type0, typename... Args>
1423+
ArrayKernelExec GenerateList(detail::GetTypeId get_id) {
1424+
switch (get_id.id) {
1425+
case Type::LIST:
1426+
return Generator<Type0, ListType, Args...>::Exec;
1427+
case Type::LARGE_LIST:
1428+
return Generator<Type0, LargeListType, Args...>::Exec;
1429+
default:
1430+
DCHECK(false);
1431+
return nullptr;
1432+
}
1433+
}
1434+
13861435
// END of kernel generator-dispatchers
13871436
// ----------------------------------------------------------------------
13881437
// BEGIN of DispatchBest helpers

Diff for: cpp/src/arrow/compute/kernels/scalar_compare.cc

+8
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,14 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
445445
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
446446
}
447447

448+
if constexpr (std::is_same_v<Op, Equal> || std::is_same_v<Op, NotEqual>) {
449+
for (const auto id : {Type::LIST, Type::LARGE_LIST}) {
450+
auto exec = GenerateList<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
451+
DCHECK_OK(
452+
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
453+
}
454+
}
455+
448456
return func;
449457
}
450458

Diff for: cpp/src/arrow/compute/kernels/scalar_compare_test.cc

+93
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,99 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
680680
}
681681
}
682682

683+
template <typename ArrowType>
684+
class TestCompareList : public ::testing::Test {};
685+
TYPED_TEST_SUITE(TestCompareList, ListArrowTypes);
686+
687+
TYPED_TEST(TestCompareList, ArrayScalar) {
688+
const auto int_value_typ = std::make_shared<Int32Type>();
689+
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
690+
const auto bin_value_typ = std::make_shared<StringType>();
691+
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));
692+
693+
const std::vector<std::pair<std::string, std::string>> cases = {
694+
{"equal", "[1, 0, 0, null]"},
695+
{"not_equal", "[0, 1, 1, null]"},
696+
};
697+
const auto lhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
698+
const auto lhs_bin = ArrayFromJSON(
699+
bin_ty, R"([["a", "b", "c"], ["foo", "bar", "baz"], ["hello"], null])");
700+
const auto rhs_int = ScalarFromJSON(int_ty, R"([1, 2, 3])");
701+
const auto rhs_bin = ScalarFromJSON(bin_ty, R"(["a", "b", "c"])");
702+
for (const auto& op : cases) {
703+
const auto& function = op.first;
704+
const auto& expected = op.second;
705+
706+
SCOPED_TRACE(function);
707+
CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));
708+
CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
709+
}
710+
}
711+
712+
TYPED_TEST(TestCompareList, ScalarArray) {
713+
const auto int_value_typ = std::make_shared<Int32Type>();
714+
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
715+
const auto bin_value_typ = std::make_shared<StringType>();
716+
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));
717+
718+
const std::vector<std::pair<std::string, std::string>> cases = {
719+
{"equal", "[1, 0, 0, null]"},
720+
{"not_equal", "[0, 1, 1, null]"},
721+
};
722+
const auto lhs_int = ScalarFromJSON(int_ty, R"([1, 2, 3])");
723+
const auto lhs_bin = ScalarFromJSON(bin_ty, R"(["a", "b", "c"])");
724+
const auto rhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
725+
const auto rhs_bin = ArrayFromJSON(
726+
bin_ty, R"([["a", "b", "c"], ["foo", "bar"], ["baz", "hello", "world"], null])");
727+
for (const auto& op : cases) {
728+
const auto& function = op.first;
729+
const auto& expected = op.second;
730+
731+
SCOPED_TRACE(function);
732+
CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));
733+
CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
734+
}
735+
}
736+
737+
TYPED_TEST(TestCompareList, ArrayArray) {
738+
const auto int_value_typ = std::make_shared<Int32Type>();
739+
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
740+
const auto bin_value_typ = std::make_shared<StringType>();
741+
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));
742+
743+
const std::vector<std::pair<std::string, std::string>> cases = {
744+
{"equal", "[1, 0, 0, null]"},
745+
{"not_equal", "[0, 1, 1, null]"},
746+
};
747+
const auto lhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [7], null])");
748+
const auto lhs_bin = ArrayFromJSON(
749+
bin_ty, R"([["a", "b", "c"], ["foo", "bar", "baz"], ["hello"], null])");
750+
const auto rhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5], [6, 7, 8], null])");
751+
const auto rhs_bin = ArrayFromJSON(
752+
bin_ty, R"([["a", "b", "c"], ["foo", "bar"], ["baz", "hello", "world"], null])");
753+
for (const auto& op : cases) {
754+
const auto& function = op.first;
755+
const auto& expected = op.second;
756+
757+
SCOPED_TRACE(function);
758+
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([])"),
759+
ArrayFromJSON(int_ty, R"([])"), ArrayFromJSON(boolean(), "[]"));
760+
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([null])"),
761+
ArrayFromJSON(int_ty, R"([null])"),
762+
ArrayFromJSON(boolean(), "[null]"));
763+
764+
CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));
765+
766+
CheckScalarBinary(function, ArrayFromJSON(bin_ty, R"([])"),
767+
ArrayFromJSON(int_ty, R"([])"), ArrayFromJSON(boolean(), "[]"));
768+
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([null])"),
769+
ArrayFromJSON(bin_ty, R"([null])"),
770+
ArrayFromJSON(boolean(), "[null]"));
771+
772+
CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
773+
}
774+
}
775+
683776
// Helper to organize tests for fixed size binary comparisons
684777
struct CompareCase {
685778
std::shared_ptr<DataType> lhs_type;

0 commit comments

Comments
 (0)