Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions datafusion/physical-expr-common/src/datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
use arrow::array::BooleanArray;
use arrow::array::{make_comparator, ArrayRef, Datum};
use arrow::buffer::NullBuffer;
use arrow::compute::SortOptions;
use arrow::compute::kernels::cmp::{
distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct,
};
use arrow::compute::{ilike, like, nilike, nlike, SortOptions};
use arrow::error::ArrowError;
use datafusion_common::DataFusionError;
use datafusion_common::{arrow_datafusion_err, internal_err};
Expand Down Expand Up @@ -53,22 +56,49 @@ pub fn apply(
}
}

/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
pub fn apply_cmp(
op: Operator,
lhs: &ColumnarValue,
rhs: &ColumnarValue,
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
) -> Result<ColumnarValue> {
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
if lhs.data_type().is_nested() {
apply_cmp_for_nested(op, lhs, rhs)
} else {
let f = match op {
Operator::Eq => eq,
Operator::NotEq => neq,
Operator::Lt => lt,
Operator::LtEq => lt_eq,
Operator::Gt => gt,
Operator::GtEq => gt_eq,
Operator::IsDistinctFrom => distinct,
Operator::IsNotDistinctFrom => not_distinct,

Operator::LikeMatch => like,
Operator::ILikeMatch => ilike,
Operator::NotLikeMatch => nlike,
Operator::NotILikeMatch => nilike,

_ => {
return internal_err!("Invalid compare operator: {}", op);
}
};

apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
}
}

/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs` for nested types such as
/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
pub fn apply_cmp_for_nested(
op: Operator,
lhs: &ColumnarValue,
rhs: &ColumnarValue,
) -> Result<ColumnarValue> {
let left_data_type = lhs.data_type();
let right_data_type = rhs.data_type();

if matches!(
op,
Operator::Eq
Expand All @@ -79,12 +109,18 @@ pub fn apply_cmp_for_nested(
| Operator::GtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom
) {
) && left_data_type.equals_datatype(&right_data_type)
{
apply(lhs, rhs, |l, r| {
Ok(Arc::new(compare_op_for_nested(op, l, r)?))
})
} else {
internal_err!("invalid operator for nested")
internal_err!(
"invalid operator for nested data, op {} left {}, right {}",
op,
left_data_type,
right_data_type
)
}
}

Expand All @@ -97,7 +133,7 @@ pub fn compare_with_eq(
if is_nested {
compare_op_for_nested(Operator::Eq, lhs, rhs)
} else {
arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
}
}

Expand Down
103 changes: 65 additions & 38 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@

mod kernels;

use crate::expressions::binary::kernels::contains::{
collection_contains_all_strings_dyn, collection_contains_all_strings_dyn_scalar,
collection_contains_any_string_dyn, collection_contains_any_string_dyn_scalar,
collection_contains_dyn, collection_contains_dyn_scalar,
collection_contains_string_dyn, collection_contains_string_dyn_scalar,
};
use crate::expressions::binary::kernels::manipulate::{
collection_concat_dyn, collection_delete_key_dyn_scalar,
};
use crate::expressions::binary::kernels::select::{
cast_to_string_array, collection_select_dyn_scalar, collection_select_path_dyn_scalar,
};
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::PhysicalExpr;
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use arrow::array::*;
use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
use arrow::compute::kernels::cmp::*;
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
use arrow::compute::{
cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator,
};
use arrow::compute::{cast, filter_record_batch, SlicesIterator};
use arrow::datatypes::*;
use arrow::error::ArrowError;
use datafusion_common::cast::as_boolean_array;
Expand All @@ -42,7 +51,7 @@ use datafusion_expr::statistics::{
new_generic_from_binary_op, Distribution,
};
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
use datafusion_physical_expr_common::datum::{apply, apply_cmp};

use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
Expand Down Expand Up @@ -251,34 +260,31 @@ impl PhysicalExpr for BinaryExpr {
let schema = batch.schema();
let input_schema = schema.as_ref();

if left_data_type.is_nested() {
if !left_data_type.equals_datatype(&right_data_type) {
return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type);
}
return apply_cmp_for_nested(self.op, &lhs, &rhs);
}

match self.op {
Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
// TODO: exclude nested types
Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
Operator::Divide => return apply(&lhs, &rhs, div),
Operator::Modulo => return apply(&lhs, &rhs, rem),
Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
Operator::Gt => return apply_cmp(&lhs, &rhs, gt),
Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq),
Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq),
Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct),
Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct),
Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like),
Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike),
Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike),
Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike),

Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::Gt
| Operator::LtEq
| Operator::GtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom
| Operator::LikeMatch
| Operator::ILikeMatch
| Operator::NotLikeMatch
| Operator::NotILikeMatch => {
return apply_cmp(self.op, &lhs, &rhs);
}
_ => {}
}

Expand All @@ -290,7 +296,7 @@ impl PhysicalExpr for BinaryExpr {
{
if !scalar.is_null() {
if let Some(result_array) =
self.evaluate_array_scalar(array, scalar.clone())?
self.evaluate_array_scalar(Arc::clone(array), scalar.clone())?
{
let final_array = result_array
.and_then(|a| to_result_type_array(&self.op, a, &result_type));
Expand Down Expand Up @@ -575,20 +581,32 @@ impl BinaryExpr {
/// right is literal - use scalar operations
fn evaluate_array_scalar(
&self,
array: &dyn Array,
array: Arc<dyn Array>,
scalar: ScalarValue,
) -> Result<Option<Result<ArrayRef>>> {
use Operator::*;
let scalar_result = match &self.op {
RegexMatch => regex_match_dyn_scalar(array, scalar, false, false),
RegexIMatch => regex_match_dyn_scalar(array, scalar, false, true),
RegexNotMatch => regex_match_dyn_scalar(array, scalar, true, false),
RegexNotIMatch => regex_match_dyn_scalar(array, scalar, true, true),
BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
RegexMatch => regex_match_dyn_scalar(&array, scalar, false, false),
RegexIMatch => regex_match_dyn_scalar(&array, scalar, false, true),
RegexNotMatch => regex_match_dyn_scalar(&array, scalar, true, false),
RegexNotIMatch => regex_match_dyn_scalar(&array, scalar, true, true),
BitwiseAnd => bitwise_and_dyn_scalar(&array, scalar),
BitwiseOr => bitwise_or_dyn_scalar(&array, scalar),
BitwiseXor => bitwise_xor_dyn_scalar(&array, scalar),
BitwiseShiftRight => bitwise_shift_right_dyn_scalar(&array, scalar),
BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(&array, scalar),
Arrow => collection_select_dyn_scalar(&array, scalar),
LongArrow => collection_select_dyn_scalar(&array, scalar)
.map(|arr| arr.and_then(cast_to_string_array)),
HashArrow => collection_select_path_dyn_scalar(array, scalar),
HashLongArrow => collection_select_path_dyn_scalar(array, scalar)
.map(|arr| arr.and_then(cast_to_string_array)),
AtArrow => collection_contains_dyn_scalar(&array, scalar),
// TODO: ArrowAt
Question => collection_contains_string_dyn_scalar(&array, scalar),
QuestionPipe => collection_contains_any_string_dyn_scalar(&array, scalar),
QuestionAnd => collection_contains_all_strings_dyn_scalar(&array, scalar),
Minus => collection_delete_key_dyn_scalar(&array, scalar),
// If the scalar operation is not supported, fall back to the array implementation
_ => None,
};
Expand Down Expand Up @@ -623,6 +641,11 @@ impl BinaryExpr {
Or => {
if left_data_type == &DataType::Boolean {
Ok(boolean_op(&left, &right, or_kleene)?)
} else if matches!(
left_data_type,
DataType::List(_) | DataType::Struct(_)
) {
collection_concat_dyn(left, right)
} else {
internal_err!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
Expand All @@ -642,9 +665,13 @@ impl BinaryExpr {
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => concat_elements(left, right),
AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt
| HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe
| IntegerDivide => {
AtArrow => collection_contains_dyn(left, right),
ArrowAt => collection_contains_dyn(right, left),
Question => collection_contains_string_dyn(left, right),
QuestionPipe => collection_contains_any_string_dyn(left, right),
QuestionAnd => collection_contains_all_strings_dyn(left, right),
Arrow | LongArrow | HashArrow | HashLongArrow | AtAt | HashMinus
| AtQuestion | IntegerDivide => {
not_impl_err!(
"Binary operator '{:?}' is not supported in the physical expr",
self.op
Expand Down
4 changes: 4 additions & 0 deletions datafusion/physical-expr/src/expressions/binary/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ use datafusion_common::{Result, ScalarValue};

use std::sync::Arc;

pub mod contains;
pub mod manipulate;
pub mod select;

/// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT)
macro_rules! call_kernel {
($LEFT:expr, $RIGHT:expr, $KERNEL:expr, $ARRAY_TYPE:ident) => {{
Expand Down
Loading