Skip to content

Commit ce86da7

Browse files
committed
Only optimize filter in case of multiple columns
1 parent 8f898ef commit ce86da7

File tree

1 file changed

+24
-21
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+24
-21
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
247247
}
248248

249249
/// Creates a [FilterPredicate] from a boolean array.
250-
fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
250+
fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
251251
let mut filter_builder = FilterBuilder::new(predicate);
252-
// Always optimize the filter since we use them multiple times.
253-
filter_builder = filter_builder.optimize();
252+
if optimize {
253+
// Always optimize the filter since we use them multiple times.
254+
filter_builder = filter_builder.optimize();
255+
}
254256
filter_builder.build()
255257
}
256258

@@ -846,7 +848,7 @@ impl CaseBody {
846848
result_builder.add_branch_result(&remainder_rows, nulls_value)?;
847849
} else {
848850
// Filter out the null rows and evaluate the else expression for those
849-
let nulls_filter = create_filter(&not(&base_not_nulls)?);
851+
let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
850852
let nulls_batch =
851853
filter_record_batch(&remainder_batch, &nulls_filter)?;
852854
let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
@@ -861,7 +863,7 @@ impl CaseBody {
861863
}
862864

863865
// Remove the null rows from the remainder batch
864-
let not_null_filter = create_filter(&base_not_nulls);
866+
let not_null_filter = create_filter(&base_not_nulls, true);
865867
remainder_batch =
866868
Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
867869
remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
@@ -907,7 +909,7 @@ impl CaseBody {
907909
// for the current branch
908910
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
909911
// this unconditionally.
910-
let then_filter = create_filter(&when_value);
912+
let then_filter = create_filter(&when_value, true);
911913
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
912914
let then_rows = filter_array(&remainder_rows, &then_filter)?;
913915

@@ -930,7 +932,7 @@ impl CaseBody {
930932
not(&prep_null_mask_filter(&when_value))
931933
}
932934
}?;
933-
let next_filter = create_filter(&next_selection);
935+
let next_filter = create_filter(&next_selection, true);
934936
remainder_batch =
935937
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
936938
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -996,7 +998,7 @@ impl CaseBody {
996998
// for the current branch
997999
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
9981000
// this unconditionally.
999-
let then_filter = create_filter(when_value);
1001+
let then_filter = create_filter(when_value, true);
10001002
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
10011003
let then_rows = filter_array(&remainder_rows, &then_filter)?;
10021004

@@ -1019,7 +1021,7 @@ impl CaseBody {
10191021
not(&prep_null_mask_filter(when_value))
10201022
}
10211023
}?;
1022-
let next_filter = create_filter(&next_selection);
1024+
let next_filter = create_filter(&next_selection, true);
10231025
remainder_batch =
10241026
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
10251027
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -1044,20 +1046,29 @@ impl CaseBody {
10441046
when_value: &BooleanArray,
10451047
return_type: &DataType,
10461048
) -> Result<ColumnarValue> {
1047-
let when_filter = create_filter(&when_value);
1049+
let when_value = match when_value.null_count() {
1050+
0 => Cow::Borrowed(when_value),
1051+
_ => {
1052+
// `prep_null_mask_filter` is required to ensure null is treated as false
1053+
Cow::Owned(prep_null_mask_filter(when_value))
1054+
}
1055+
};
1056+
1057+
let optimize_filter = batch.num_columns() > 1;
1058+
let when_filter = create_filter(&when_value, optimize_filter);
10481059
let then_batch = filter_record_batch(batch, &when_filter)?;
10491060
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
10501061

10511062
let else_selection = not(&when_value)?;
1052-
let else_filter = create_filter(&else_selection);
1063+
let else_filter = create_filter(&else_selection, optimize_filter);
10531064
let else_batch = filter_record_batch(batch, &else_filter)?;
10541065
let e = self.else_expr.as_ref().unwrap();
10551066
// keep `else_expr`'s data type and return type consistent
10561067
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
10571068
.unwrap_or_else(|_| Arc::clone(e));
10581069
let else_value = else_expr.evaluate(&else_batch)?;
10591070

1060-
Ok(ColumnarValue::Array(zip_unaligned(
1071+
Ok(ColumnarValue::Array(merge(
10611072
&when_value,
10621073
then_value,
10631074
else_value,
@@ -1215,14 +1226,6 @@ impl CaseExpr {
12151226
return self.body.else_expr.as_ref().unwrap().evaluate(batch);
12161227
}
12171228

1218-
let when_value = match when_value.null_count() {
1219-
0 => Cow::Borrowed(when_value),
1220-
_ => {
1221-
// `prep_null_mask_filter` is required to ensure null is treated as false
1222-
Cow::Owned(prep_null_mask_filter(when_value))
1223-
}
1224-
};
1225-
12261229
if projected.projection.len() < batch.num_columns() {
12271230
let projected_batch = batch.project(&projected.projection)?;
12281231
projected
@@ -2236,7 +2239,7 @@ mod tests {
22362239
PartialResultIndex::try_new(2).unwrap(),
22372240
];
22382241

2239-
let merged = merge(&[a1, a2, a3], &indices).unwrap();
2242+
let merged = merge_n(&[a1, a2, a3], &indices).unwrap();
22402243
let merged = merged.as_string::<i32>();
22412244

22422245
assert_eq!(merged.len(), indices.len());

0 commit comments

Comments
 (0)