Skip to content

Commit 592b924

Browse files
authored
feat(substrait): handle emit_kind when consuming Substrait plans (#13127)
* feat(substrait): handle emit_kind when consuming Substrait plans * cargo fmt * avoid projection flattening for volatile expressions * simplify application of apply_emit_kind
1 parent a34e237 commit 592b924

File tree

5 files changed

+483
-26
lines changed

5 files changed

+483
-26
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 174 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ use datafusion::arrow::array::{new_empty_array, AsArray};
5656
use datafusion::arrow::temporal_conversions::NANOSECONDS;
5757
use datafusion::common::scalar::ScalarStructBuilder;
5858
use datafusion::dataframe::DataFrame;
59+
use datafusion::logical_expr::builder::project;
5960
use datafusion::logical_expr::expr::InList;
6061
use datafusion::logical_expr::{
6162
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -80,6 +81,7 @@ use substrait::proto::expression::literal::{
8081
use substrait::proto::expression::subquery::SubqueryType;
8182
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
8283
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
84+
use substrait::proto::rel_common::{Emit, EmitKind};
8385
use substrait::proto::{
8486
aggregate_function::AggregationInvocation,
8587
expression::{
@@ -93,9 +95,9 @@ use substrait::proto::{
9395
join_rel, plan_rel, r#type,
9496
read_rel::ReadType,
9597
rel::RelType,
96-
set_rel,
98+
rel_common, set_rel,
9799
sort_field::{SortDirection, SortKind::*},
98-
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
100+
AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type,
99101
};
100102
use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};
101103

@@ -562,42 +564,51 @@ pub async fn from_substrait_rel(
562564
rel: &Rel,
563565
extensions: &Extensions,
564566
) -> Result<LogicalPlan> {
565-
match &rel.rel_type {
567+
let plan: Result<LogicalPlan> = match &rel.rel_type {
566568
Some(RelType::Project(p)) => {
567569
if let Some(input) = p.input.as_ref() {
568570
let mut input = LogicalPlanBuilder::from(
569571
from_substrait_rel(ctx, input, extensions).await?,
570572
);
571-
let mut names: HashSet<String> = HashSet::new();
572-
let mut exprs: Vec<Expr> = vec![];
573-
for e in &p.expressions {
574-
let x =
575-
from_substrait_rex(ctx, e, input.clone().schema(), extensions)
573+
let original_schema = input.schema().clone();
574+
575+
// Ensure that all expressions have a unique display name, so that
576+
// validate_unique_names does not fail when constructing the project.
577+
let mut name_tracker = NameTracker::new();
578+
579+
// By default, a Substrait Project emits all input fields followed by all expressions.
580+
// We build the explicit expressions first, and then the input expressions to avoid
581+
// adding aliases to the explicit expressions (as part of ensuring unique names).
582+
//
583+
// This is helpful for plan visualization and tests, because when DataFusion produces
584+
// Substrait Projects it adds an output mapping that excludes all input columns
585+
// leaving only explicit expressions.
586+
587+
let mut explicit_exprs: Vec<Expr> = vec![];
588+
for expr in &p.expressions {
589+
let e =
590+
from_substrait_rex(ctx, expr, input.clone().schema(), extensions)
576591
.await?;
577592
// if the expression is WindowFunction, wrap in a Window relation
578-
if let Expr::WindowFunction(_) = &x {
593+
if let Expr::WindowFunction(_) = &e {
579594
// Adding the same expression here and in the project below
580595
// works because the project's builder uses columnize_expr(..)
581596
// to transform it into a column reference
582-
input = input.window(vec![x.clone()])?
597+
input = input.window(vec![e.clone()])?
583598
}
584-
// Ensure the expression has a unique display name, so that project's
585-
// validate_unique_names doesn't fail
586-
let name = x.schema_name().to_string();
587-
let mut new_name = name.clone();
588-
let mut i = 0;
589-
while names.contains(&new_name) {
590-
new_name = format!("{}__temp__{}", name, i);
591-
i += 1;
592-
}
593-
if new_name != name {
594-
exprs.push(x.alias(new_name.clone()));
595-
} else {
596-
exprs.push(x);
597-
}
598-
names.insert(new_name);
599+
explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
599600
}
600-
input.project(exprs)?.build()
601+
602+
let mut final_exprs: Vec<Expr> = vec![];
603+
for index in 0..original_schema.fields().len() {
604+
let e = Expr::Column(Column::from(
605+
original_schema.qualified_field(index),
606+
));
607+
final_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
608+
}
609+
final_exprs.append(&mut explicit_exprs);
610+
611+
input.project(final_exprs)?.build()
601612
} else {
602613
not_impl_err!("Projection without an input is not supported")
603614
}
@@ -1074,6 +1085,143 @@ pub async fn from_substrait_rel(
10741085
}))
10751086
}
10761087
_ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type),
1088+
};
1089+
apply_emit_kind(retrieve_rel_common(rel), plan?)
1090+
}
1091+
1092+
fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> {
1093+
match rel.rel_type.as_ref() {
1094+
None => None,
1095+
Some(rt) => match rt {
1096+
RelType::Read(r) => r.common.as_ref(),
1097+
RelType::Filter(f) => f.common.as_ref(),
1098+
RelType::Fetch(f) => f.common.as_ref(),
1099+
RelType::Aggregate(a) => a.common.as_ref(),
1100+
RelType::Sort(s) => s.common.as_ref(),
1101+
RelType::Join(j) => j.common.as_ref(),
1102+
RelType::Project(p) => p.common.as_ref(),
1103+
RelType::Set(s) => s.common.as_ref(),
1104+
RelType::ExtensionSingle(e) => e.common.as_ref(),
1105+
RelType::ExtensionMulti(e) => e.common.as_ref(),
1106+
RelType::ExtensionLeaf(e) => e.common.as_ref(),
1107+
RelType::Cross(c) => c.common.as_ref(),
1108+
RelType::Reference(_) => None,
1109+
RelType::Write(w) => w.common.as_ref(),
1110+
RelType::Ddl(d) => d.common.as_ref(),
1111+
RelType::HashJoin(j) => j.common.as_ref(),
1112+
RelType::MergeJoin(j) => j.common.as_ref(),
1113+
RelType::NestedLoopJoin(j) => j.common.as_ref(),
1114+
RelType::Window(w) => w.common.as_ref(),
1115+
RelType::Exchange(e) => e.common.as_ref(),
1116+
RelType::Expand(e) => e.common.as_ref(),
1117+
},
1118+
}
1119+
}
1120+
1121+
/// Extracts the [`EmitKind`] from an optional [`RelCommon`].
///
/// Falls back to `EmitKind::Direct` when no `RelCommon` is supplied or when
/// its `emit_kind` field is unset.
fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind {
    match rel_common.and_then(|common| common.emit_kind.as_ref()) {
        Some(emit_kind) => emit_kind.clone(),
        // the default EmitKind is Direct if it is not set explicitly
        None => EmitKind::Direct(rel_common::Direct {}),
    }
}
1128+
1129+
/// Reports whether any expression of `proj` is volatile.
///
/// Expressions are inspected in order and inspection stops asking once a
/// volatile expression has been found.
///
/// # Errors
/// Propagates an error from `is_volatile` if one is hit before a volatile
/// expression is found.
fn contains_volatile_expr(proj: &Projection) -> Result<bool> {
    proj.expr
        .iter()
        .try_fold(false, |found, candidate| -> Result<bool> {
            // `||` short-circuits, so `is_volatile` is no longer called once
            // a volatile expression has been seen.
            Ok(found || candidate.is_volatile()?)
        })
}
1137+
1138+
fn apply_emit_kind(
1139+
rel_common: Option<&RelCommon>,
1140+
plan: LogicalPlan,
1141+
) -> Result<LogicalPlan> {
1142+
match retrieve_emit_kind(rel_common) {
1143+
EmitKind::Direct(_) => Ok(plan),
1144+
EmitKind::Emit(Emit { output_mapping }) => {
1145+
// It is valid to reference the same field multiple times in the Emit
1146+
// In this case, we need to provide unique names to avoid collisions
1147+
let mut name_tracker = NameTracker::new();
1148+
match plan {
1149+
// To avoid adding a projection on top of a projection, we apply special case
1150+
// handling to flatten Substrait Emits. This is only applicable if none of the
1151+
// expressions in the projection are volatile. This is to avoid issues like
1152+
// converting a single call of the random() function into multiple calls due to
1153+
// duplicate fields in the output_mapping.
1154+
LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj)? => {
1155+
let mut exprs: Vec<Expr> = vec![];
1156+
for field in output_mapping {
1157+
let expr = proj.expr
1158+
.get(field as usize)
1159+
.ok_or_else(|| substrait_datafusion_err!(
1160+
"Emit output field {} cannot be resolved in input schema {}",
1161+
field, proj.input.schema().clone()
1162+
))?;
1163+
exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?);
1164+
}
1165+
1166+
let input = Arc::unwrap_or_clone(proj.input);
1167+
project(input, exprs)
1168+
}
1169+
// Otherwise we just handle the output_mapping as a projection
1170+
_ => {
1171+
let input_schema = plan.schema();
1172+
1173+
let mut exprs: Vec<Expr> = vec![];
1174+
for index in output_mapping.into_iter() {
1175+
let column = Expr::Column(Column::from(
1176+
input_schema.qualified_field(index as usize),
1177+
));
1178+
let expr = name_tracker.get_uniquely_named_expr(column)?;
1179+
exprs.push(expr);
1180+
}
1181+
1182+
project(plan, exprs)
1183+
}
1184+
}
1185+
}
1186+
}
1187+
}
1188+
1189+
struct NameTracker {
1190+
seen_names: HashSet<String>,
1191+
}
1192+
1193+
enum NameTrackerStatus {
1194+
NeverSeen,
1195+
SeenBefore,
1196+
}
1197+
1198+
impl NameTracker {
1199+
fn new() -> Self {
1200+
NameTracker {
1201+
seen_names: HashSet::default(),
1202+
}
1203+
}
1204+
fn get_unique_name(&mut self, name: String) -> (String, NameTrackerStatus) {
1205+
match self.seen_names.insert(name.clone()) {
1206+
true => (name, NameTrackerStatus::NeverSeen),
1207+
false => {
1208+
let mut counter = 0;
1209+
loop {
1210+
let candidate_name = format!("{}__temp__{}", name, counter);
1211+
if self.seen_names.insert(candidate_name.clone()) {
1212+
return (candidate_name, NameTrackerStatus::SeenBefore);
1213+
}
1214+
counter += 1;
1215+
}
1216+
}
1217+
}
1218+
}
1219+
1220+
fn get_uniquely_named_expr(&mut self, expr: Expr) -> Result<Expr> {
1221+
match self.get_unique_name(expr.name_for_alias()?) {
1222+
(_, NameTrackerStatus::NeverSeen) => Ok(expr),
1223+
(name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)),
1224+
}
10771225
}
10781226
}
10791227

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Tests for Emit Kind usage
19+
20+
#[cfg(test)]
mod tests {
    use crate::utils::test::{add_plan_schemas_to_ctx, read_json};

    use datafusion::common::Result;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext};
    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
    use datafusion_substrait::logical_plan::producer::to_substrait_plan;

    /// Loads the Substrait JSON plan at `path`, registers its schemas on a
    /// fresh context, consumes it and returns the resulting plan rendered as
    /// text.
    async fn consumed_plan_text(path: &str) -> Result<String> {
        let proto_plan = read_json(path);
        let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
        let logical_plan = from_substrait_plan(&ctx, &proto_plan).await?;
        Ok(logical_plan.to_string())
    }

    /// Builds a context with default features and `tests/testdata/data.csv`
    /// registered as table `data`.
    async fn make_context() -> Result<SessionContext> {
        let ctx = SessionContext::new_with_state(
            SessionStateBuilder::new()
                .with_config(SessionConfig::default())
                .with_default_features()
                .build(),
        );
        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::default())
            .await?;
        Ok(ctx)
    }

    /// A Project with a direct emit keeps the input fields followed by the
    /// explicit expression.
    #[tokio::test]
    async fn project_respects_direct_emit_kind() -> Result<()> {
        let rendered = consumed_plan_text(
            "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json",
        )
        .await?;

        assert_eq!(
            rendered,
            "Projection: DATA.A AS a, DATA.B AS b, DATA.A + Int64(1) AS add1\
            \n TableScan: DATA"
        );
        Ok(())
    }

    /// An Emit on a non-project relation is consumed as a projection on top
    /// of that relation.
    #[tokio::test]
    async fn handle_emit_as_project() -> Result<()> {
        let rendered = consumed_plan_text(
            "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json",
        )
        .await?;

        assert_eq!(
            rendered,
            // Note that duplicate references in the remap are aliased
            "Projection: DATA.B, DATA.A AS A1, DATA.A AS DATA.A__temp__0 AS A2\
            \n Filter: DATA.B = Int64(2)\
            \n TableScan: DATA"
        );
        Ok(())
    }

    /// Round-trips a projection containing `random()` and checks the
    /// consumed plan keeps two separate projections.
    #[tokio::test]
    async fn handle_emit_as_project_with_volatile_expr() -> Result<()> {
        let ctx = make_context().await?;

        let original = ctx
            .sql("SELECT random() AS c1, a + 1 AS c2 FROM data")
            .await?
            .into_unoptimized_plan();
        assert_eq!(
            original.to_string(),
            "Projection: random() AS c1, data.a + Int64(1) AS c2\
            \n TableScan: data"
        );

        let proto = to_substrait_plan(&original, &ctx)?;
        let round_tripped = from_substrait_plan(&ctx, &proto).await?;
        // note how the Projections are not flattened
        assert_eq!(
            round_tripped.to_string(),
            "Projection: random() AS c1, data.a + Int64(1) AS c2\
            \n Projection: data.a, data.b, data.c, data.d, data.e, data.f, random(), data.a + Int64(1)\
            \n TableScan: data"
        );
        Ok(())
    }

    /// Round-trips a projection without volatile expressions and checks the
    /// consumed plan renders exactly like the original.
    #[tokio::test]
    async fn handle_emit_as_project_without_volatile_exprs() -> Result<()> {
        let ctx = make_context().await?;

        let original = ctx
            .sql("SELECT a + 1, b + 2 FROM data")
            .await?
            .into_unoptimized_plan();
        assert_eq!(
            original.to_string(),
            "Projection: data.a + Int64(1), data.b + Int64(2)\
            \n TableScan: data"
        );

        let proto = to_substrait_plan(&original, &ctx)?;
        let round_tripped = from_substrait_plan(&ctx, &proto).await?;

        assert_eq!(original.to_string(), round_tripped.to_string());

        Ok(())
    }
}

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
mod consumer_integration;
19+
mod emit_kind_tests;
1920
mod function_test;
2021
mod logical_plans;
2122
mod roundtrip_logical_plan;

0 commit comments

Comments
 (0)