datafusion/physical-expr/src/expressions/case.rs (223 changes: 157 additions & 66 deletions)
@@ -23,6 +23,7 @@ use arrow::array::*;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{
is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
SlicesIterator,
};
use arrow::datatypes::{DataType, Schema, UInt32Type};
use arrow::error::ArrowError;
@@ -246,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
}

/// Creates a [FilterPredicate] from a boolean array.
fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
let mut filter_builder = FilterBuilder::new(predicate);
// Always optimize the filter since we use them multiple times.
filter_builder = filter_builder.optimize();
if optimize {
// Always optimize the filter since we use them multiple times.
filter_builder = filter_builder.optimize();
}
filter_builder.build()
}

@@ -290,6 +293,84 @@ fn filter_array(
filter.filter(array)
}

Review comments on fn merge:

alamb (Contributor): This looks like an implementation of zip to me (rather than merge). It seems like it would be better to use consistent terminology if this is indeed the case. This looks very similar to the fancy new code that @rluvaton added to arrow for zip with scalars (will be released in arrow 57.1.0). If you agree it is the same, perhaps we can either avoid adding this method to DataFusion, or else add a comment saying we can revert to plain zip once apache/arrow-rs#8653 is available.

pepijnve (author), Nov 3, 2025: It's almost the same as zip, but different enough that it's necessary. Without this implementation you can't avoid the scatter step. I've added a test case to show the difference. The short version is that merge([true, false, true], [A, C], [B]) will get you [A, B, C], while zip would return an error stating that all arrays should have the same length. I agree that these two merge kernels would be better off in arrow-rs, which is why I made PR apache/arrow-rs#8753.

pepijnve: @rluvaton's work on zip only covers the case of two scalar inputs, BTW; that's why I chose to delegate to plain zip in that case. array/array, scalar/array and array/scalar still need the specific logic here. The subtle difference between this and zip is in

let falsy_length = start - filled;
let falsy_end = falsy_offset + falsy_length;
mutable.extend(1, falsy_offset, falsy_end);
falsy_offset = falsy_end;

vs

mutable.extend(1, filled, start);

where zip uses the slice indices from the mask directly, while merge only uses the lengths of the slices and tracks the amounts taken from truthy and falsy separately.

pepijnve: @alamb I realised this morning I had chosen a rather poor example in the arrow-rs PR. I've updated it to illustrate the truthy/falsy length difference.

fn merge(
mask: &BooleanArray,
truthy: ColumnarValue,
falsy: ColumnarValue,
) -> std::result::Result<ArrayRef, ArrowError> {
let (truthy, truthy_is_scalar) = match truthy {
ColumnarValue::Array(a) => (a, false),
ColumnarValue::Scalar(s) => (s.to_array()?, true),
};
let (falsy, falsy_is_scalar) = match falsy {
ColumnarValue::Array(a) => (a, false),
ColumnarValue::Scalar(s) => (s.to_array()?, true),
};

if truthy_is_scalar && falsy_is_scalar {
return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
}

let falsy = falsy.to_data();
let truthy = truthy.to_data();

let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());

// `SlicesIterator` yields only the slices of true values, so we need to
// fill the gaps it leaves with falsy values.

// keep track of how much is filled
let mut filled = 0;
let mut falsy_offset = 0;
let mut truthy_offset = 0;

SlicesIterator::new(mask).for_each(|(start, end)| {
// the gap needs to be filled with falsy values
if start > filled {
if falsy_is_scalar {
for _ in filled..start {
// Copy the first item from the 'falsy' array into the output buffer.
mutable.extend(1, 0, 1);
}
} else {
let falsy_length = start - filled;
let falsy_end = falsy_offset + falsy_length;
mutable.extend(1, falsy_offset, falsy_end);
falsy_offset = falsy_end;
}
}
// fill with truthy values
if truthy_is_scalar {
for _ in start..end {
// Copy the first item from the 'truthy' array into the output buffer.
mutable.extend(0, 0, 1);
}
} else {
let truthy_length = end - start;
let truthy_end = truthy_offset + truthy_length;
mutable.extend(0, truthy_offset, truthy_end);
truthy_offset = truthy_end;
}
filled = end;
});
// the remaining part is falsy
if filled < mask.len() {
if falsy_is_scalar {
for _ in filled..mask.len() {
// Copy the first item from the 'falsy' array into the output buffer.
mutable.extend(1, 0, 1);
}
} else {
let falsy_length = mask.len() - filled;
let falsy_end = falsy_offset + falsy_length;
mutable.extend(1, falsy_offset, falsy_end);
}
}

let data = mutable.freeze();
Ok(make_array(data))
}
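
To make the semantics discussed in the review thread concrete, here is a minimal sketch of a test (hypothetical name, written as if it lived in this file's tests module, which already has the arrow types and `ColumnarValue` in scope) exercising the unequal-length inputs that distinguish merge from zip:

#[test]
fn merge_accepts_unequal_length_inputs() {
    // Rows 0 and 2 of the mask take values from `truthy`, row 1 from `falsy`.
    let mask = BooleanArray::from(vec![true, false, true]);
    // Each side holds only the values it contributes, so both inputs are
    // shorter than the mask; `zip` would reject this shape outright.
    let truthy = ColumnarValue::Array(Arc::new(StringArray::from(vec!["A", "C"])));
    let falsy = ColumnarValue::Array(Arc::new(StringArray::from(vec!["B"])));

    let merged = merge(&mask, truthy, falsy).unwrap();
    let merged = merged.as_string::<i32>();
    assert_eq!(merged.value(0), "A");
    assert_eq!(merged.value(1), "B");
    assert_eq!(merged.value(2), "C");
}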

/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from
/// those values.
///
@@ -342,7 +423,7 @@ fn filter_array(
/// └───────────┘ └─────────┘ └─────────┘
/// values indices result
/// ```
fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
#[cfg(debug_assertions)]
for ix in indices {
if let Some(index) = ix.index() {
@@ -647,7 +728,7 @@ impl ResultBuilder {
}
Partial { arrays, indices } => {
// Merge partial results into a single array.
Ok(ColumnarValue::Array(merge(&arrays, &indices)?))
Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
}
Complete(v) => {
// If we have a complete result, we can just return it.
@@ -723,6 +804,26 @@ impl CaseExpr {
}

impl CaseBody {
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
// Since all THEN results have the same data type, we can use any of them
// as the return type, skipping over Null.
let mut data_type = DataType::Null;
for i in 0..self.when_then_expr.len() {
data_type = self.when_then_expr[i].1.data_type(input_schema)?;
if !data_type.equals_datatype(&DataType::Null) {
break;
}
}
// If all THEN results are Null, fall back to the ELSE expression's data type, if present.
if data_type.equals_datatype(&DataType::Null) {
if let Some(e) = &self.else_expr {
data_type = e.data_type(input_schema)?;
}
}

Ok(data_type)
}
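
As an illustration of this resolution rule (not part of the PR):

// CASE WHEN p THEN NULL            -- DataType::Null: keep scanning
//      WHEN q THEN CAST(x AS INT)  -- Int32: first non-Null THEN type, chosen
//      ELSE y                      -- consulted only when every THEN is Null
// END
//
// data_type() resolves to Int32 here; it would fall back to y's type only
// if both THEN branches were NULL literals.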

/// See [CaseExpr::case_when_with_expr].
fn case_when_with_expr(
&self,
@@ -767,7 +868,7 @@ impl CaseBody {
result_builder.add_branch_result(&remainder_rows, nulls_value)?;
} else {
// Filter out the null rows and evaluate the else expression for those
let nulls_filter = create_filter(&not(&base_not_nulls)?);
let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
let nulls_batch =
filter_record_batch(&remainder_batch, &nulls_filter)?;
let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
@@ -782,7 +883,7 @@
}

// Remove the null rows from the remainder batch
let not_null_filter = create_filter(&base_not_nulls);
let not_null_filter = create_filter(&base_not_nulls, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
@@ -802,8 +903,7 @@
compare_with_eq(&a, &base_values, base_value_is_nested)
}
ColumnarValue::Scalar(s) => {
let scalar = Scalar::new(s.to_array()?);
compare_with_eq(&scalar, &base_values, base_value_is_nested)
compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
}
}?;

@@ -829,7 +929,7 @@
// for the current branch
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
// this unconditionally.
let then_filter = create_filter(&when_value);
let then_filter = create_filter(&when_value, true);
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
let then_rows = filter_array(&remainder_rows, &then_filter)?;

@@ -852,7 +952,7 @@
not(&prep_null_mask_filter(&when_value))
}
}?;
let next_filter = create_filter(&next_selection);
let next_filter = create_filter(&next_selection, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -918,7 +1018,7 @@ impl CaseBody {
// for the current branch
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
// this unconditionally.
let then_filter = create_filter(when_value);
let then_filter = create_filter(when_value, true);
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
let then_rows = filter_array(&remainder_rows, &then_filter)?;

@@ -941,7 +1041,7 @@
not(&prep_null_mask_filter(when_value))
}
}?;
let next_filter = create_filter(&next_selection);
let next_filter = create_filter(&next_selection, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -964,24 +1064,38 @@
&self,
batch: &RecordBatch,
when_value: &BooleanArray,
return_type: &DataType,
) -> Result<ColumnarValue> {
let then_value = self.when_then_expr[0]
.1
.evaluate_selection(batch, when_value)?
.into_array(batch.num_rows())?;
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => {
// `prep_null_mask_filter` is required to ensure null is treated as false
Cow::Owned(prep_null_mask_filter(when_value))
}
};

let optimize_filter = batch.num_columns() > 1;
Review comments on optimize_filter:

Contributor: It might also be worth checking if there are any nested types (e.g. StructArrays) and optimizing the filter in that case too -- this is done elsewhere (maybe in the filter kernel itself 🤔).

pepijnve (author): Agreed. The logic that handles that isn't pub in arrow-rs, unfortunately. I can duplicate it here if you like.

pepijnve: Fixed for now. Would it be useful to make this a method of DataType? I can prepare an arrow-rs PR for that.

pepijnve: A DataType method didn't make much sense to me after all, since the predicate is very much tied to the actual filter implementation logic. I went for apache/arrow-rs#8782 instead.
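
A sketch of the broader check discussed in this thread (hypothetical helper name, not the exact code that landed; DataType::is_nested is an existing arrow-rs method): optimize the reusable filter when it will be applied to many columns, or to nested columns that are expensive to filter lazily.

fn should_optimize_filter(batch: &RecordBatch) -> bool {
    // Wide batches apply the same filter to every column, and nested
    // columns (structs, lists) are costly to filter without optimization.
    batch.num_columns() > 1
        || batch
            .schema()
            .fields()
            .iter()
            .any(|field| field.data_type().is_nested())
}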


let when_filter = create_filter(&when_value, optimize_filter);
let then_batch = filter_record_batch(batch, &when_filter)?;
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;

let else_selection = not(&when_value)?;
let else_filter = create_filter(&else_selection, optimize_filter);
let else_batch = filter_record_batch(batch, &else_filter)?;

// evaluate else expression on the values not covered by when_value
let remainder = not(when_value)?;
let e = self.else_expr.as_ref().unwrap();
// keep `else_expr`'s data type and return type consistent
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
let e = self.else_expr.as_ref().unwrap();
let return_type = self.data_type(&batch.schema())?;
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
.unwrap_or_else(|_| Arc::clone(e));
let else_ = expr
.evaluate_selection(batch, &remainder)?
.into_array(batch.num_rows())?;

Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
let else_value = else_expr.evaluate(&else_batch)?;

Ok(ColumnarValue::Array(merge(
&when_value,
then_value,
else_value,
)?))
}
}

@@ -1113,41 +1227,34 @@ impl CaseExpr {
batch: &RecordBatch,
projected: &ProjectedCaseBody,
) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;

// evaluate when condition on batch
let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(batch.num_rows())?;
// `num_rows == 1` is intentional to avoid expanding scalars.
// If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
// below will avoid incorrectly using the scalar as a merge/zip mask.
let when_value = when_value.into_array(1)?;
let when_value = as_boolean_array(&when_value).map_err(|e| {
DataFusionError::Context(
"WHEN expression did not return a BooleanArray".to_string(),
Box::new(e),
)
})?;

// For the true and false/null selection vectors, bypass `evaluate_selection` and merging
// results. This avoids materializing the array for the other branch which we will discard
// entirely anyway.
let true_count = when_value.true_count();
if true_count == batch.num_rows() {
return self.body.when_then_expr[0].1.evaluate(batch);
if true_count == when_value.len() {
// All input rows are true, just call the 'then' expression
self.body.when_then_expr[0].1.evaluate(batch)
} else if true_count == 0 {
return self.body.else_expr.as_ref().unwrap().evaluate(batch);
}

// Treat 'NULL' as false value
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => Cow::Owned(prep_null_mask_filter(when_value)),
};

if projected.projection.len() < batch.num_columns() {
// All input rows are false/null, just call the 'else' expression
self.body.else_expr.as_ref().unwrap().evaluate(batch)
} else if projected.projection.len() < batch.num_columns() {
// The case expressions do not use all the columns of the input batch.
// Project first to reduce time spent filtering.
let projected_batch = batch.project(&projected.projection)?;
projected
.body
.expr_or_expr(&projected_batch, &when_value, &return_type)
projected.body.expr_or_expr(&projected_batch, when_value)
} else {
self.body.expr_or_expr(batch, &when_value, &return_type)
// All columns are used in the case expressions, so there is no need to project.
self.body.expr_or_expr(batch, when_value)
}
}
}
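
A small sketch (hypothetical test; assumes ScalarValue and the as_boolean_array cast helper are in scope, as they are elsewhere in this file) of why into_array(1) is safe for scalar WHEN values:

#[test]
fn scalar_when_value_stays_one_element() -> Result<()> {
    let value = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
    // A scalar becomes a one-element array rather than `num_rows` copies.
    let array = value.into_array(1)?;
    let mask = as_boolean_array(&array)?;
    // A scalar mask is all true or all false, so one of the two fast paths
    // above always fires and the one-element array is never used as a merge
    // mask for a longer batch.
    assert_eq!(mask.true_count(), mask.len());
    Ok(())
}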
@@ -1159,23 +1266,7 @@ impl PhysicalExpr for CaseExpr {
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
// since all then results have the same data type, we can choose any one as the
// return data type except for the null.
let mut data_type = DataType::Null;
for i in 0..self.body.when_then_expr.len() {
data_type = self.body.when_then_expr[i].1.data_type(input_schema)?;
if !data_type.equals_datatype(&DataType::Null) {
break;
}
}
// if all then results are null, we use data type of else expr instead if possible.
if data_type.equals_datatype(&DataType::Null) {
if let Some(e) = &self.body.else_expr {
data_type = e.data_type(input_schema)?;
}
}

Ok(data_type)
self.body.data_type(input_schema)
}

fn nullable(&self, input_schema: &Schema) -> Result<bool> {
@@ -2154,7 +2245,7 @@ mod tests {
PartialResultIndex::try_new(2).unwrap(),
];

let merged = merge(&[a1, a2, a3], &indices).unwrap();
let merged = merge_n(&[a1, a2, a3], &indices).unwrap();
let merged = merged.as_string::<i32>();

assert_eq!(merged.len(), indices.len());