diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 6f7c432c7582..e5e7d6c00f08 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -23,14 +23,14 @@ use std::sync::Arc; use crate::utils::scatter; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{new_empty_array, ArrayRef, BooleanArray}; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -90,36 +90,69 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { self.nullable(input_schema)?, ))) } - /// Evaluate an expression against a RecordBatch after first applying a - /// validity array + /// Evaluate an expression against a RecordBatch after first applying a validity array + /// + /// # Errors + /// + /// Returns an `Err` if the expression could not be evaluated or if the length of the + /// `selection` validity array and the number of row in `batch` is not equal. fn evaluate_selection( &self, batch: &RecordBatch, selection: &BooleanArray, ) -> Result { - let tmp_batch = filter_record_batch(batch, selection)?; - - let tmp_result = self.evaluate(&tmp_batch)?; - - if batch.num_rows() == tmp_batch.num_rows() { - // All values from the `selection` filter are true. - Ok(tmp_result) - } else if let ColumnarValue::Array(a) = tmp_result { - scatter(selection, a.as_ref()).map(ColumnarValue::Array) - } else if let ColumnarValue::Scalar(ScalarValue::Boolean(value)) = &tmp_result { - // When the scalar is true or false, skip the scatter process - if let Some(v) = value { - if *v { - Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef)) + let row_count = batch.num_rows(); + if row_count != selection.len() { + return exec_err!("Selection array length does not match batch row count: {} != {row_count}", selection.len()); + } + + let selection_count = selection.true_count(); + + // First, check if we can avoid filtering altogether. + if selection_count == row_count { + // All values from the `selection` filter are true and match the input batch. + // No need to perform any filtering. + return self.evaluate(batch); + } + + // Next, prepare the result array for each 'true' row in the selection vector. + let filtered_result = if selection_count == 0 { + // Do not call `evaluate` when the selection is empty. + // `evaluate_selection` is used to conditionally evaluate expressions. + // When the expression in question is fallible, evaluating it with an empty + // record batch may trigger a runtime error (e.g. division by zero). + // + // Instead, create an empty array matching the expected return type. + let datatype = self.data_type(batch.schema_ref().as_ref())?; + ColumnarValue::Array(new_empty_array(&datatype)) + } else { + // If we reach this point, there's no other option than to filter the batch. + // This is a fairly costly operation since it requires creating partial copies + // (worst case of length `row_count - 1`) of all the arrays in the record batch. + // The resulting `filtered_batch` will contain `selection_count` rows. + let filtered_batch = filter_record_batch(batch, selection)?; + self.evaluate(&filtered_batch)? + }; + + // Finally, scatter the filtered result array so that the indices match the input rows again. + match &filtered_result { + ColumnarValue::Array(a) => { + scatter(selection, a.as_ref()).map(ColumnarValue::Array) + } + ColumnarValue::Scalar(ScalarValue::Boolean(value)) => { + // When the scalar is true or false, skip the scatter process + if let Some(v) = value { + if *v { + Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef)) + } else { + Ok(filtered_result) + } } else { - Ok(tmp_result) + let array = BooleanArray::from(vec![None; row_count]); + scatter(selection, &array).map(ColumnarValue::Array) } - } else { - let array = BooleanArray::from(vec![None; batch.num_rows()]); - scatter(selection, &array).map(ColumnarValue::Array) } - } else { - Ok(tmp_result) + ColumnarValue::Scalar(_) => Ok(filtered_result), } } @@ -601,3 +634,190 @@ pub fn is_volatile(expr: &Arc) -> bool { .expect("infallible closure should not fail"); is_volatile } + +#[cfg(test)] +mod test { + use crate::physical_expr::PhysicalExpr; + use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch}; + use arrow::datatypes::{DataType, Schema}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use std::fmt::{Display, Formatter}; + use std::sync::Arc; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestExpr {} + + impl PhysicalExpr for TestExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn nullable(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(false) + } + + fn evaluate( + &self, + batch: &RecordBatch, + ) -> datafusion_common::Result { + let data = vec![1; batch.num_rows()]; + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self {})) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("TestExpr") + } + } + + impl Display for TestExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.fmt_sql(f) + } + } + + macro_rules! assert_arrays_eq { + ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => { + let expected = $EXPECTED.to_array(1).unwrap(); + let actual = $ACTUAL; + + let actual_array = actual.to_array(expected.len()).unwrap(); + let actual_ref = actual_array.as_ref(); + let expected_ref = expected.as_ref(); + assert!( + actual_ref == expected_ref, + "{}: expected: {:?}, actual: {:?}", + $MESSAGE, + $EXPECTED, + actual_ref + ); + }; + } + + fn test_evaluate_selection( + batch: &RecordBatch, + selection: &BooleanArray, + expected: &ColumnarValue, + ) { + let expr = TestExpr {}; + + // First check that the `evaluate_selection` is the expected one + let selection_result = expr.evaluate_selection(batch, selection).unwrap(); + assert_eq!( + expected.to_array(1).unwrap().len(), + selection_result.to_array(1).unwrap().len(), + "evaluate_selection should output row count should match input record batch" + ); + assert_arrays_eq!( + expected, + &selection_result, + "evaluate_selection returned unexpected value" + ); + + // If we're selecting all rows, the result should be the same as calling `evaluate` + // with the full record batch. + if (0..batch.num_rows()) + .all(|row_idx| row_idx < selection.len() && selection.value(row_idx)) + { + let empty_result = expr.evaluate(batch).unwrap(); + + assert_arrays_eq!( + empty_result, + &selection_result, + "evaluate_selection does not match unfiltered evaluate result" + ); + } + } + + fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) { + let expr = TestExpr {}; + + // First check that the `evaluate_selection` is the expected one + let selection_result = expr.evaluate_selection(batch, selection); + assert!(selection_result.is_err(), "evaluate_selection should fail"); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch() { + test_evaluate_selection( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![false; 0]), + &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))), + ); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() { + test_evaluate_selection_error( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![false; 10]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() { + test_evaluate_selection_error( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![true; 10]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch() { + test_evaluate_selection( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 10]), + &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![false; 20]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 20]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![false; 5]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 5]), + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5409cfe8e7e4..d14146a20d8b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -155,10 +155,7 @@ impl CaseExpr { && else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.as_ref().is_some_and(is_cheap_and_infallible) - { + } else if when_then_expr.len() == 1 && else_expr.is_some() { EvalMethod::ExpressionOrExpression } else { EvalMethod::NoExpression @@ -425,6 +422,16 @@ impl CaseExpr { ) })?; + // For the true and false/null selection vectors, bypass `evaluate_selection` and merging + // results. This avoids materializing the array for the other branch which we will discard + // entirely anyway. + let true_count = when_value.true_count(); + if true_count == batch.num_rows() { + return self.when_then_expr[0].1.evaluate(batch); + } else if true_count == 0 { + return self.else_expr.as_ref().unwrap().evaluate(batch); + } + // Treat 'NULL' as false value let when_value = match when_value.null_count() { 0 => Cow::Borrowed(when_value), diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 69f80f459394..9bc1f83ed119 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -467,6 +467,7 @@ FROM t; ---- [{foo: blarg}] +# mix of then and else query II SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v) ---- @@ -474,12 +475,38 @@ SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v 1 10 2 5 +# when expressions is always false, then branch should never be evaluated query II SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v) ---- 1 1 2 1 +# when expressions is always true, else branch should never be evaluated +query II +SELECT v, CASE WHEN v > 0 THEN 1 ELSE 10/0 END FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 + + +# lazy evaluation of multiple when branches, else branch should never be evaluated +query II +SELECT v, CASE WHEN v == 1 THEN -1 WHEN v == 2 THEN -2 WHEN v == 3 THEN -3 ELSE 10/0 END FROM (VALUES (1), (2), (3)) t(v) +---- +1 -1 +2 -2 +3 -3 + +# covers the InfallibleExprOrNull evaluation strategy +query II +SELECT v, CASE WHEN v THEN 1 END FROM (VALUES (1), (2), (3), (NULL)) t(v) +---- +1 1 +2 1 +3 1 +NULL NULL + statement ok drop table t