@@ -247,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
247247}
248248
249249/// Creates a [FilterPredicate] from a boolean array.
250- fn create_filter ( predicate : & BooleanArray ) -> FilterPredicate {
250+ fn create_filter ( predicate : & BooleanArray , optimize : bool ) -> FilterPredicate {
251251 let mut filter_builder = FilterBuilder :: new ( predicate) ;
252- // Always optimize the filter since we use them multiple times.
253- filter_builder = filter_builder. optimize ( ) ;
252+ if optimize {
253+ // Always optimize the filter since we use them multiple times.
254+ filter_builder = filter_builder. optimize ( ) ;
255+ }
254256 filter_builder. build ( )
255257}
256258
@@ -846,7 +848,7 @@ impl CaseBody {
846848 result_builder. add_branch_result ( & remainder_rows, nulls_value) ?;
847849 } else {
848850 // Filter out the null rows and evaluate the else expression for those
849- let nulls_filter = create_filter ( & not ( & base_not_nulls) ?) ;
851+ let nulls_filter = create_filter ( & not ( & base_not_nulls) ?, true ) ;
850852 let nulls_batch =
851853 filter_record_batch ( & remainder_batch, & nulls_filter) ?;
852854 let nulls_rows = filter_array ( & remainder_rows, & nulls_filter) ?;
@@ -861,7 +863,7 @@ impl CaseBody {
861863 }
862864
863865 // Remove the null rows from the remainder batch
864- let not_null_filter = create_filter ( & base_not_nulls) ;
866+ let not_null_filter = create_filter ( & base_not_nulls, true ) ;
865867 remainder_batch =
866868 Cow :: Owned ( filter_record_batch ( & remainder_batch, & not_null_filter) ?) ;
867869 remainder_rows = filter_array ( & remainder_rows, & not_null_filter) ?;
@@ -907,7 +909,7 @@ impl CaseBody {
907909 // for the current branch
908910 // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
909911 // this unconditionally.
910- let then_filter = create_filter ( & when_value) ;
912+ let then_filter = create_filter ( & when_value, true ) ;
911913 let then_batch = filter_record_batch ( & remainder_batch, & then_filter) ?;
912914 let then_rows = filter_array ( & remainder_rows, & then_filter) ?;
913915
@@ -930,7 +932,7 @@ impl CaseBody {
930932 not ( & prep_null_mask_filter ( & when_value) )
931933 }
932934 } ?;
933- let next_filter = create_filter ( & next_selection) ;
935+ let next_filter = create_filter ( & next_selection, true ) ;
934936 remainder_batch =
935937 Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_filter) ?) ;
936938 remainder_rows = filter_array ( & remainder_rows, & next_filter) ?;
@@ -996,7 +998,7 @@ impl CaseBody {
996998 // for the current branch
997999 // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
9981000 // this unconditionally.
999- let then_filter = create_filter ( when_value) ;
1001+ let then_filter = create_filter ( when_value, true ) ;
10001002 let then_batch = filter_record_batch ( & remainder_batch, & then_filter) ?;
10011003 let then_rows = filter_array ( & remainder_rows, & then_filter) ?;
10021004
@@ -1019,7 +1021,7 @@ impl CaseBody {
10191021 not ( & prep_null_mask_filter ( when_value) )
10201022 }
10211023 } ?;
1022- let next_filter = create_filter ( & next_selection) ;
1024+ let next_filter = create_filter ( & next_selection, true ) ;
10231025 remainder_batch =
10241026 Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_filter) ?) ;
10251027 remainder_rows = filter_array ( & remainder_rows, & next_filter) ?;
@@ -1044,20 +1046,29 @@ impl CaseBody {
10441046 when_value : & BooleanArray ,
10451047 return_type : & DataType ,
10461048 ) -> Result < ColumnarValue > {
1047- let when_filter = create_filter ( & when_value) ;
1049+ let when_value = match when_value. null_count ( ) {
1050+ 0 => Cow :: Borrowed ( when_value) ,
1051+ _ => {
1052+ // `prep_null_mask_filter` is required to ensure null is treated as false
1053+ Cow :: Owned ( prep_null_mask_filter ( when_value) )
1054+ }
1055+ } ;
1056+
1057+ let optimize_filter = batch. num_columns ( ) > 1 ;
1058+ let when_filter = create_filter ( & when_value, optimize_filter) ;
10481059 let then_batch = filter_record_batch ( batch, & when_filter) ?;
10491060 let then_value = self . when_then_expr [ 0 ] . 1 . evaluate ( & then_batch) ?;
10501061
10511062 let else_selection = not ( & when_value) ?;
1052- let else_filter = create_filter ( & else_selection) ;
1063+ let else_filter = create_filter ( & else_selection, optimize_filter ) ;
10531064 let else_batch = filter_record_batch ( batch, & else_filter) ?;
10541065 let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
10551066 // keep `else_expr`'s data type and return type consistent
10561067 let else_expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
10571068 . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
10581069 let else_value = else_expr. evaluate ( & else_batch) ?;
10591070
1060- Ok ( ColumnarValue :: Array ( zip_unaligned (
1071+ Ok ( ColumnarValue :: Array ( merge (
10611072 & when_value,
10621073 then_value,
10631074 else_value,
@@ -1215,14 +1226,6 @@ impl CaseExpr {
12151226 return self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch) ;
12161227 }
12171228
1218- let when_value = match when_value. null_count ( ) {
1219- 0 => Cow :: Borrowed ( when_value) ,
1220- _ => {
1221- // `prep_null_mask_filter` is required to ensure null is treated as false
1222- Cow :: Owned ( prep_null_mask_filter ( when_value) )
1223- }
1224- } ;
1225-
12261229 if projected. projection . len ( ) < batch. num_columns ( ) {
12271230 let projected_batch = batch. project ( & projected. projection ) ?;
12281231 projected
@@ -2236,7 +2239,7 @@ mod tests {
22362239 PartialResultIndex :: try_new( 2 ) . unwrap( ) ,
22372240 ] ;
22382241
2239- let merged = merge ( & [ a1, a2, a3] , & indices) . unwrap ( ) ;
2242+ let merged = merge_n ( & [ a1, a2, a3] , & indices) . unwrap ( ) ;
22402243 let merged = merged. as_string :: < i32 > ( ) ;
22412244
22422245 assert_eq ! ( merged. len( ) , indices. len( ) ) ;
0 commit comments