-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Avoid scatter operation in ExpressionOrExpression case evaluation method
#18444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ use arrow::array::*; | |
| use arrow::compute::kernels::zip::zip; | ||
| use arrow::compute::{ | ||
| is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, | ||
| SlicesIterator, | ||
| }; | ||
| use arrow::datatypes::{DataType, Schema, UInt32Type}; | ||
| use arrow::error::ArrowError; | ||
|
|
@@ -246,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool { | |
| } | ||
|
|
||
| /// Creates a [FilterPredicate] from a boolean array. | ||
| fn create_filter(predicate: &BooleanArray) -> FilterPredicate { | ||
| fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate { | ||
| let mut filter_builder = FilterBuilder::new(predicate); | ||
| // Always optimize the filter since we use them multiple times. | ||
| filter_builder = filter_builder.optimize(); | ||
| if optimize { | ||
| // Always optimize the filter since we use them multiple times. | ||
| filter_builder = filter_builder.optimize(); | ||
| } | ||
| filter_builder.build() | ||
| } | ||
|
|
||
|
|
@@ -290,6 +293,84 @@ fn filter_array( | |
| filter.filter(array) | ||
| } | ||
|
|
||
| fn merge( | ||
| mask: &BooleanArray, | ||
| truthy: ColumnarValue, | ||
| falsy: ColumnarValue, | ||
| ) -> std::result::Result<ArrayRef, ArrowError> { | ||
| let (truthy, truthy_is_scalar) = match truthy { | ||
| ColumnarValue::Array(a) => (a, false), | ||
| ColumnarValue::Scalar(s) => (s.to_array()?, true), | ||
| }; | ||
| let (falsy, falsy_is_scalar) = match falsy { | ||
| ColumnarValue::Array(a) => (a, false), | ||
| ColumnarValue::Scalar(s) => (s.to_array()?, true), | ||
| }; | ||
|
|
||
| if truthy_is_scalar && falsy_is_scalar { | ||
| return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy)); | ||
| } | ||
|
|
||
| let falsy = falsy.to_data(); | ||
| let truthy = truthy.to_data(); | ||
|
|
||
| let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len()); | ||
|
|
||
| // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to | ||
| // fill with falsy values | ||
|
|
||
| // keep track of how much is filled | ||
| let mut filled = 0; | ||
| let mut falsy_offset = 0; | ||
| let mut truthy_offset = 0; | ||
|
|
||
| SlicesIterator::new(mask).for_each(|(start, end)| { | ||
| // the gap needs to be filled with falsy values | ||
| if start > filled { | ||
| if falsy_is_scalar { | ||
| for _ in filled..start { | ||
| // Copy the first item from the 'falsy' array into the output buffer. | ||
| mutable.extend(1, 0, 1); | ||
| } | ||
| } else { | ||
| let falsy_length = start - filled; | ||
| let falsy_end = falsy_offset + falsy_length; | ||
| mutable.extend(1, falsy_offset, falsy_end); | ||
| falsy_offset = falsy_end; | ||
| } | ||
| } | ||
| // fill with truthy values | ||
| if truthy_is_scalar { | ||
| for _ in start..end { | ||
| // Copy the first item from the 'truthy' array into the output buffer. | ||
| mutable.extend(0, 0, 1); | ||
| } | ||
| } else { | ||
| let truthy_length = end - start; | ||
| let truthy_end = truthy_offset + truthy_length; | ||
| mutable.extend(0, truthy_offset, truthy_end); | ||
| truthy_offset = truthy_end; | ||
| } | ||
| filled = end; | ||
| }); | ||
| // the remaining part is falsy | ||
| if filled < mask.len() { | ||
| if falsy_is_scalar { | ||
| for _ in filled..mask.len() { | ||
| // Copy the first item from the 'falsy' array into the output buffer. | ||
| mutable.extend(1, 0, 1); | ||
| } | ||
| } else { | ||
| let falsy_length = mask.len() - filled; | ||
| let falsy_end = falsy_offset + falsy_length; | ||
| mutable.extend(1, falsy_offset, falsy_end); | ||
| } | ||
| } | ||
|
|
||
| let data = mutable.freeze(); | ||
| Ok(make_array(data)) | ||
| } | ||
|
|
||
| /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from | ||
| /// those values. | ||
| /// | ||
|
|
@@ -342,7 +423,7 @@ fn filter_array( | |
| /// └───────────┘ └─────────┘ └─────────┘ | ||
| /// values indices result | ||
| /// ``` | ||
| fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> { | ||
| fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> { | ||
| #[cfg(debug_assertions)] | ||
| for ix in indices { | ||
| if let Some(index) = ix.index() { | ||
|
|
@@ -647,7 +728,7 @@ impl ResultBuilder { | |
| } | ||
| Partial { arrays, indices } => { | ||
| // Merge partial results into a single array. | ||
| Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) | ||
| Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?)) | ||
| } | ||
| Complete(v) => { | ||
| // If we have a complete result, we can just return it. | ||
|
|
@@ -723,6 +804,26 @@ impl CaseExpr { | |
| } | ||
|
|
||
| impl CaseBody { | ||
| fn data_type(&self, input_schema: &Schema) -> Result<DataType> { | ||
| // since all then results have the same data type, we can choose any one as the | ||
| // return data type except for the null. | ||
| let mut data_type = DataType::Null; | ||
| for i in 0..self.when_then_expr.len() { | ||
| data_type = self.when_then_expr[i].1.data_type(input_schema)?; | ||
| if !data_type.equals_datatype(&DataType::Null) { | ||
| break; | ||
| } | ||
| } | ||
| // if all then results are null, we use data type of else expr instead if possible. | ||
| if data_type.equals_datatype(&DataType::Null) { | ||
| if let Some(e) = &self.else_expr { | ||
| data_type = e.data_type(input_schema)?; | ||
| } | ||
| } | ||
|
|
||
| Ok(data_type) | ||
| } | ||
|
|
||
| /// See [CaseExpr::case_when_with_expr]. | ||
| fn case_when_with_expr( | ||
| &self, | ||
|
|
@@ -767,7 +868,7 @@ impl CaseBody { | |
| result_builder.add_branch_result(&remainder_rows, nulls_value)?; | ||
| } else { | ||
| // Filter out the null rows and evaluate the else expression for those | ||
| let nulls_filter = create_filter(¬(&base_not_nulls)?); | ||
| let nulls_filter = create_filter(¬(&base_not_nulls)?, true); | ||
| let nulls_batch = | ||
| filter_record_batch(&remainder_batch, &nulls_filter)?; | ||
| let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; | ||
|
|
@@ -782,7 +883,7 @@ impl CaseBody { | |
| } | ||
|
|
||
| // Remove the null rows from the remainder batch | ||
| let not_null_filter = create_filter(&base_not_nulls); | ||
| let not_null_filter = create_filter(&base_not_nulls, true); | ||
| remainder_batch = | ||
| Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?); | ||
| remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?; | ||
|
|
@@ -802,8 +903,7 @@ impl CaseBody { | |
| compare_with_eq(&a, &base_values, base_value_is_nested) | ||
| } | ||
| ColumnarValue::Scalar(s) => { | ||
| let scalar = Scalar::new(s.to_array()?); | ||
| compare_with_eq(&scalar, &base_values, base_value_is_nested) | ||
| compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested) | ||
| } | ||
| }?; | ||
|
|
||
|
|
@@ -829,7 +929,7 @@ impl CaseBody { | |
| // for the current branch | ||
| // Still no need to call `prep_null_mask_filter` since `create_filter` will already do | ||
| // this unconditionally. | ||
| let then_filter = create_filter(&when_value); | ||
| let then_filter = create_filter(&when_value, true); | ||
| let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; | ||
| let then_rows = filter_array(&remainder_rows, &then_filter)?; | ||
|
|
||
|
|
@@ -852,7 +952,7 @@ impl CaseBody { | |
| not(&prep_null_mask_filter(&when_value)) | ||
| } | ||
| }?; | ||
| let next_filter = create_filter(&next_selection); | ||
| let next_filter = create_filter(&next_selection, true); | ||
| remainder_batch = | ||
| Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); | ||
| remainder_rows = filter_array(&remainder_rows, &next_filter)?; | ||
|
|
@@ -918,7 +1018,7 @@ impl CaseBody { | |
| // for the current branch | ||
| // Still no need to call `prep_null_mask_filter` since `create_filter` will already do | ||
| // this unconditionally. | ||
| let then_filter = create_filter(when_value); | ||
| let then_filter = create_filter(when_value, true); | ||
| let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; | ||
| let then_rows = filter_array(&remainder_rows, &then_filter)?; | ||
|
|
||
|
|
@@ -941,7 +1041,7 @@ impl CaseBody { | |
| not(&prep_null_mask_filter(when_value)) | ||
| } | ||
| }?; | ||
| let next_filter = create_filter(&next_selection); | ||
| let next_filter = create_filter(&next_selection, true); | ||
| remainder_batch = | ||
| Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); | ||
| remainder_rows = filter_array(&remainder_rows, &next_filter)?; | ||
|
|
@@ -964,24 +1064,38 @@ impl CaseBody { | |
| &self, | ||
| batch: &RecordBatch, | ||
| when_value: &BooleanArray, | ||
| return_type: &DataType, | ||
| ) -> Result<ColumnarValue> { | ||
| let then_value = self.when_then_expr[0] | ||
| .1 | ||
| .evaluate_selection(batch, when_value)? | ||
| .into_array(batch.num_rows())?; | ||
| let when_value = match when_value.null_count() { | ||
| 0 => Cow::Borrowed(when_value), | ||
| _ => { | ||
| // `prep_null_mask_filter` is required to ensure null is treated as false | ||
| Cow::Owned(prep_null_mask_filter(when_value)) | ||
| } | ||
| }; | ||
|
|
||
| let optimize_filter = batch.num_columns() > 1; | ||
|
||
|
|
||
| let when_filter = create_filter(&when_value, optimize_filter); | ||
| let then_batch = filter_record_batch(batch, &when_filter)?; | ||
| let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; | ||
|
|
||
| let else_selection = not(&when_value)?; | ||
| let else_filter = create_filter(&else_selection, optimize_filter); | ||
| let else_batch = filter_record_batch(batch, &else_filter)?; | ||
|
|
||
| // evaluate else expression on the values not covered by when_value | ||
| let remainder = not(when_value)?; | ||
| let e = self.else_expr.as_ref().unwrap(); | ||
| // keep `else_expr`'s data type and return type consistent | ||
| let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) | ||
| let e = self.else_expr.as_ref().unwrap(); | ||
| let return_type = self.data_type(&batch.schema())?; | ||
| let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) | ||
| .unwrap_or_else(|_| Arc::clone(e)); | ||
| let else_ = expr | ||
| .evaluate_selection(batch, &remainder)? | ||
| .into_array(batch.num_rows())?; | ||
|
|
||
| Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) | ||
| let else_value = else_expr.evaluate(&else_batch)?; | ||
|
|
||
| Ok(ColumnarValue::Array(merge( | ||
| &when_value, | ||
| then_value, | ||
| else_value, | ||
| )?)) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1113,41 +1227,34 @@ impl CaseExpr { | |
| batch: &RecordBatch, | ||
| projected: &ProjectedCaseBody, | ||
| ) -> Result<ColumnarValue> { | ||
| let return_type = self.data_type(&batch.schema())?; | ||
|
|
||
| // evaluate when condition on batch | ||
| let when_value = self.body.when_then_expr[0].0.evaluate(batch)?; | ||
| let when_value = when_value.into_array(batch.num_rows())?; | ||
| // `num_rows == 1` is intentional to avoid expanding scalars. | ||
| // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks | ||
| // below will avoid incorrectly using the scalar as a merge/zip mask. | ||
| let when_value = when_value.into_array(1)?; | ||
| let when_value = as_boolean_array(&when_value).map_err(|e| { | ||
| DataFusionError::Context( | ||
| "WHEN expression did not return a BooleanArray".to_string(), | ||
| Box::new(e), | ||
| ) | ||
| })?; | ||
|
|
||
| // For the true and false/null selection vectors, bypass `evaluate_selection` and merging | ||
| // results. This avoids materializing the array for the other branch which we will discard | ||
| // entirely anyway. | ||
| let true_count = when_value.true_count(); | ||
| if true_count == batch.num_rows() { | ||
| return self.body.when_then_expr[0].1.evaluate(batch); | ||
| if true_count == when_value.len() { | ||
| // All input rows are true, just call the 'then' expression | ||
| self.body.when_then_expr[0].1.evaluate(batch) | ||
| } else if true_count == 0 { | ||
| return self.body.else_expr.as_ref().unwrap().evaluate(batch); | ||
| } | ||
|
|
||
| // Treat 'NULL' as false value | ||
| let when_value = match when_value.null_count() { | ||
| 0 => Cow::Borrowed(when_value), | ||
| _ => Cow::Owned(prep_null_mask_filter(when_value)), | ||
| }; | ||
|
|
||
| if projected.projection.len() < batch.num_columns() { | ||
| // All input rows are false/null, just call the 'else' expression | ||
| self.body.else_expr.as_ref().unwrap().evaluate(batch) | ||
| } else if projected.projection.len() < batch.num_columns() { | ||
| // The case expressions do not use all the columns of the input batch. | ||
| // Project first to reduce time spent filtering. | ||
| let projected_batch = batch.project(&projected.projection)?; | ||
| projected | ||
| .body | ||
| .expr_or_expr(&projected_batch, &when_value, &return_type) | ||
| projected.body.expr_or_expr(&projected_batch, when_value) | ||
| } else { | ||
| self.body.expr_or_expr(batch, &when_value, &return_type) | ||
| // All columns are used in the case expressions, so there is no need to project. | ||
| self.body.expr_or_expr(batch, when_value) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1159,23 +1266,7 @@ impl PhysicalExpr for CaseExpr { | |
| } | ||
|
|
||
| fn data_type(&self, input_schema: &Schema) -> Result<DataType> { | ||
| // since all then results have the same data type, we can choose any one as the | ||
| // return data type except for the null. | ||
| let mut data_type = DataType::Null; | ||
| for i in 0..self.body.when_then_expr.len() { | ||
| data_type = self.body.when_then_expr[i].1.data_type(input_schema)?; | ||
| if !data_type.equals_datatype(&DataType::Null) { | ||
| break; | ||
| } | ||
| } | ||
| // if all then results are null, we use data type of else expr instead if possible. | ||
| if data_type.equals_datatype(&DataType::Null) { | ||
| if let Some(e) = &self.body.else_expr { | ||
| data_type = e.data_type(input_schema)?; | ||
| } | ||
| } | ||
|
|
||
| Ok(data_type) | ||
| self.body.data_type(input_schema) | ||
| } | ||
|
|
||
| fn nullable(&self, input_schema: &Schema) -> Result<bool> { | ||
|
|
@@ -2154,7 +2245,7 @@ mod tests { | |
| PartialResultIndex::try_new(2).unwrap(), | ||
| ]; | ||
|
|
||
| let merged = merge(&[a1, a2, a3], &indices).unwrap(); | ||
| let merged = merge_n(&[a1, a2, a3], &indices).unwrap(); | ||
| let merged = merged.as_string::<i32>(); | ||
|
|
||
| assert_eq!(merged.len(), indices.len()); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like an implementation of
zipto me (rather thanmerge). It seems like it would be better to use consistent terminology if this is indeed the caseThis looks very similar to the fancy new code that @rluvaton added to arrow for
zipwith scalars (will be released in arrow 57.1.0):If you agree it is the same, perhaps we can either avoid adding this method to DataFusion or else we can add a comment that says we can revert to using just
ziponce apache/arrow-rs#8653 is availableUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's almost the same as
zip, but different enough that it's necessary. Without this implementation you can't avoid the scatter step.I've added a test case to show the difference. The short version is that
merge([true, false, true], [A, C], [B])will get you[A, B, C]whilezipwould return an error statingall arrays should have the same length.I agree that these two merge kernels would be better off in
arrow-rswhich is why I made PR apache/arrow-rs#8753.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rluvaton's work on zip only covers the case of two scalar inputs BTW. That's why I chose to delegate to plain zip in that case. array/array, scalar/array and array/scalar still needs the specific logic here.
The subtle difference between this and zip is in
vs
where zip is using the slice indices from the mask directly, merge only uses the length of the slices and tracks the amount taken from truthy and falsy separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@alamb I realised this morning I had chosen a rather poor example in the
arrow-rsPR. I've updated it to illustrate the truthy/falsy length difference.