Skip to content

Commit 821e50d

Browse files
committed
Avoid expanding scalars
1 parent ce86da7 commit 821e50d

File tree

1 file changed

+41
-35
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+41
-35
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 41 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -804,6 +804,26 @@ impl CaseExpr {
804804
}
805805

806806
impl CaseBody {
807+
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
808+
// since all then results have the same data type, we can choose any one as the
809+
// return data type except for the null.
810+
let mut data_type = DataType::Null;
811+
for i in 0..self.when_then_expr.len() {
812+
data_type = self.when_then_expr[i].1.data_type(input_schema)?;
813+
if !data_type.equals_datatype(&DataType::Null) {
814+
break;
815+
}
816+
}
817+
// if all then results are null, we use data type of else expr instead if possible.
818+
if data_type.equals_datatype(&DataType::Null) {
819+
if let Some(e) = &self.else_expr {
820+
data_type = e.data_type(input_schema)?;
821+
}
822+
}
823+
824+
Ok(data_type)
825+
}
826+
807827
/// See [CaseExpr::case_when_with_expr].
808828
fn case_when_with_expr(
809829
&self,
@@ -1044,7 +1064,6 @@ impl CaseBody {
10441064
&self,
10451065
batch: &RecordBatch,
10461066
when_value: &BooleanArray,
1047-
return_type: &DataType,
10481067
) -> Result<ColumnarValue> {
10491068
let when_value = match when_value.null_count() {
10501069
0 => Cow::Borrowed(when_value),
@@ -1055,17 +1074,21 @@ impl CaseBody {
10551074
};
10561075

10571076
let optimize_filter = batch.num_columns() > 1;
1077+
10581078
let when_filter = create_filter(&when_value, optimize_filter);
10591079
let then_batch = filter_record_batch(batch, &when_filter)?;
10601080
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
10611081

10621082
let else_selection = not(&when_value)?;
10631083
let else_filter = create_filter(&else_selection, optimize_filter);
10641084
let else_batch = filter_record_batch(batch, &else_filter)?;
1065-
let e = self.else_expr.as_ref().unwrap();
1085+
10661086
// keep `else_expr`'s data type and return type consistent
1087+
let e = self.else_expr.as_ref().unwrap();
1088+
let return_type = self.data_type(&batch.schema())?;
10671089
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
10681090
.unwrap_or_else(|_| Arc::clone(e));
1091+
10691092
let else_value = else_expr.evaluate(&else_batch)?;
10701093

10711094
Ok(ColumnarValue::Array(merge(
@@ -1204,35 +1227,34 @@ impl CaseExpr {
12041227
batch: &RecordBatch,
12051228
projected: &ProjectedCaseBody,
12061229
) -> Result<ColumnarValue> {
1207-
let return_type = self.data_type(&batch.schema())?;
1208-
12091230
// evaluate when condition on batch
12101231
let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1211-
let when_value = when_value.into_array(batch.num_rows())?;
1232+
// `num_rows == 1` is intentional to avoid expanding scalars.
1233+
// If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1234+
// below will avoid incorrectly using the scalar as a merge/zip mask.
1235+
let when_value = when_value.into_array(1)?;
12121236
let when_value = as_boolean_array(&when_value).map_err(|e| {
12131237
DataFusionError::Context(
12141238
"WHEN expression did not return a BooleanArray".to_string(),
12151239
Box::new(e),
12161240
)
12171241
})?;
12181242

1219-
// For the true and false/null selection vectors, bypass filtering and merging
1220-
// results. This avoids materializing the array for the other branch which we will discard
1221-
// entirely anyway.
12221243
let true_count = when_value.true_count();
1223-
if true_count == batch.num_rows() {
1224-
return self.body.when_then_expr[0].1.evaluate(batch);
1244+
if true_count == when_value.len() {
1245+
// All input rows are true, just call the 'then' expression
1246+
self.body.when_then_expr[0].1.evaluate(batch)
12251247
} else if true_count == 0 {
1226-
return self.body.else_expr.as_ref().unwrap().evaluate(batch);
1227-
}
1228-
1229-
if projected.projection.len() < batch.num_columns() {
1248+
// All input rows are false/null, just call the 'else' expression
1249+
self.body.else_expr.as_ref().unwrap().evaluate(batch)
1250+
} else if projected.projection.len() < batch.num_columns() {
1251+
// The case expressions do not use all the columns of the input batch.
1252+
// Project first to reduce time spent filtering.
12301253
let projected_batch = batch.project(&projected.projection)?;
1231-
projected
1232-
.body
1233-
.expr_or_expr(&projected_batch, &when_value, &return_type)
1254+
projected.body.expr_or_expr(&projected_batch, when_value)
12341255
} else {
1235-
self.body.expr_or_expr(batch, &when_value, &return_type)
1256+
// All columns are used in the case expressions, so there is no need to project.
1257+
self.body.expr_or_expr(batch, when_value)
12361258
}
12371259
}
12381260
}
@@ -1244,23 +1266,7 @@ impl PhysicalExpr for CaseExpr {
12441266
}
12451267

12461268
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1247-
// since all then results have the same data type, we can choose any one as the
1248-
// return data type except for the null.
1249-
let mut data_type = DataType::Null;
1250-
for i in 0..self.body.when_then_expr.len() {
1251-
data_type = self.body.when_then_expr[i].1.data_type(input_schema)?;
1252-
if !data_type.equals_datatype(&DataType::Null) {
1253-
break;
1254-
}
1255-
}
1256-
// if all then results are null, we use data type of else expr instead if possible.
1257-
if data_type.equals_datatype(&DataType::Null) {
1258-
if let Some(e) = &self.body.else_expr {
1259-
data_type = e.data_type(input_schema)?;
1260-
}
1261-
}
1262-
1263-
Ok(data_type)
1269+
self.body.data_type(input_schema)
12641270
}
12651271

12661272
fn nullable(&self, input_schema: &Schema) -> Result<bool> {

0 commit comments

Comments
 (0)