@@ -804,6 +804,26 @@ impl CaseExpr {
804804}
805805
806806impl CaseBody {
807+ fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
808+ // since all then results have the same data type, we can choose any one as the
809+ // return data type except for the null.
810+ let mut data_type = DataType :: Null ;
811+ for i in 0 ..self . when_then_expr . len ( ) {
812+ data_type = self . when_then_expr [ i] . 1 . data_type ( input_schema) ?;
813+ if !data_type. equals_datatype ( & DataType :: Null ) {
814+ break ;
815+ }
816+ }
817+ // if all then results are null, we use data type of else expr instead if possible.
818+ if data_type. equals_datatype ( & DataType :: Null ) {
819+ if let Some ( e) = & self . else_expr {
820+ data_type = e. data_type ( input_schema) ?;
821+ }
822+ }
823+
824+ Ok ( data_type)
825+ }
826+
807827 /// See [CaseExpr::case_when_with_expr].
808828 fn case_when_with_expr (
809829 & self ,
@@ -1044,7 +1064,6 @@ impl CaseBody {
10441064 & self ,
10451065 batch : & RecordBatch ,
10461066 when_value : & BooleanArray ,
1047- return_type : & DataType ,
10481067 ) -> Result < ColumnarValue > {
10491068 let when_value = match when_value. null_count ( ) {
10501069 0 => Cow :: Borrowed ( when_value) ,
@@ -1055,17 +1074,21 @@ impl CaseBody {
10551074 } ;
10561075
10571076 let optimize_filter = batch. num_columns ( ) > 1 ;
1077+
10581078 let when_filter = create_filter ( & when_value, optimize_filter) ;
10591079 let then_batch = filter_record_batch ( batch, & when_filter) ?;
10601080 let then_value = self . when_then_expr [ 0 ] . 1 . evaluate ( & then_batch) ?;
10611081
10621082 let else_selection = not ( & when_value) ?;
10631083 let else_filter = create_filter ( & else_selection, optimize_filter) ;
10641084 let else_batch = filter_record_batch ( batch, & else_filter) ?;
1065- let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
1085+
10661086 // keep `else_expr`'s data type and return type consistent
1087+ let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
1088+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
10671089 let else_expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
10681090 . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
1091+
10691092 let else_value = else_expr. evaluate ( & else_batch) ?;
10701093
10711094 Ok ( ColumnarValue :: Array ( merge (
@@ -1204,35 +1227,34 @@ impl CaseExpr {
12041227 batch : & RecordBatch ,
12051228 projected : & ProjectedCaseBody ,
12061229 ) -> Result < ColumnarValue > {
1207- let return_type = self . data_type ( & batch. schema ( ) ) ?;
1208-
12091230 // evaluate when condition on batch
12101231 let when_value = self . body . when_then_expr [ 0 ] . 0 . evaluate ( batch) ?;
1211- let when_value = when_value. into_array ( batch. num_rows ( ) ) ?;
1232+ // `num_rows == 1` is intentional to avoid expanding scalars.
1233+ // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1234+ // below will avoid incorrectly using the scalar as a merge/zip mask.
1235+ let when_value = when_value. into_array ( 1 ) ?;
12121236 let when_value = as_boolean_array ( & when_value) . map_err ( |e| {
12131237 DataFusionError :: Context (
12141238 "WHEN expression did not return a BooleanArray" . to_string ( ) ,
12151239 Box :: new ( e) ,
12161240 )
12171241 } ) ?;
12181242
1219- // For the true and false/null selection vectors, bypass filtering and merging
1220- // results. This avoids materializing the array for the other branch which we will discard
1221- // entirely anyway.
12221243 let true_count = when_value. true_count ( ) ;
1223- if true_count == batch. num_rows ( ) {
1224- return self . body . when_then_expr [ 0 ] . 1 . evaluate ( batch) ;
1244+ if true_count == when_value. len ( ) {
1245+ // All input rows are true, just call the 'then' expression
1246+ self . body . when_then_expr [ 0 ] . 1 . evaluate ( batch)
12251247 } else if true_count == 0 {
1226- return self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch) ;
1227- }
1228-
1229- if projected. projection . len ( ) < batch. num_columns ( ) {
1248+ // All input rows are false/null, just call the 'else' expression
1249+ self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch)
1250+ } else if projected. projection . len ( ) < batch. num_columns ( ) {
1251+ // The case expressions do not use all the columns of the input batch.
1252+ // Project first to reduce time spent filtering.
12301253 let projected_batch = batch. project ( & projected. projection ) ?;
1231- projected
1232- . body
1233- . expr_or_expr ( & projected_batch, & when_value, & return_type)
1254+ projected. body . expr_or_expr ( & projected_batch, when_value)
12341255 } else {
1235- self . body . expr_or_expr ( batch, & when_value, & return_type)
1256+ // All columns are used in the case expressions, so there is no need to project.
1257+ self . body . expr_or_expr ( batch, when_value)
12361258 }
12371259 }
12381260}
@@ -1244,23 +1266,7 @@ impl PhysicalExpr for CaseExpr {
12441266 }
12451267
// Return data type of the CASE expression for the given input schema.
// After this change the method only delegates to `self.body.data_type`; the
// lines marked `-` below are the former inline computation that is removed here
// (first non-null THEN type, falling back to the ELSE type when one exists).
12461268 fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
1247- // since all then results have the same data type, we can choose any one as the
1248- // return data type except for the null.
1249- let mut data_type = DataType :: Null ;
1250- for i in 0 ..self . body . when_then_expr . len ( ) {
1251- data_type = self . body . when_then_expr [ i] . 1 . data_type ( input_schema) ?;
1252- if !data_type. equals_datatype ( & DataType :: Null ) {
1253- break ;
1254- }
1255- }
1256- // if all then results are null, we use data type of else expr instead if possible.
1257- if data_type. equals_datatype ( & DataType :: Null ) {
1258- if let Some ( e) = & self . body . else_expr {
1259- data_type = e. data_type ( input_schema) ?;
1260- }
1261- }
1262-
1263- Ok ( data_type)
1269+ self . body . data_type ( input_schema)
12641270 }
12651271
12661272 fn nullable ( & self , input_schema : & Schema ) -> Result < bool > {
0 commit comments