diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index ddc9cc923e6..53af1664c85 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSh
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.utils.GlutenSuiteUtils
 
 class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
@@ -62,12 +63,53 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
       .write
       .format("parquet")
       .saveAsTable("tmp3")
+    // ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
+    // tmp4/tmp5 declare DECIMAL(20, 0) pointing to the same ORC files,
+    // so the reader must handle a precision/scale mismatch.
+    spark
+      .range(100)
+      .selectExpr(
+        "cast(id as decimal(38, 18)) as c1",
+        "cast(id % 3 as int) as c2",
+        "cast(id % 9 as timestamp) as c3")
+      .write
+      .format("orc")
+      .saveAsTable("tmp4_wide")
+    spark
+      .range(100)
+      .selectExpr(
+        "cast(id as decimal(38, 18)) as c1",
+        "cast(id % 3 as int) as c2",
+        "cast(id % 5 as timestamp) as c3")
+      .write
+      .format("orc")
+      .saveAsTable("tmp5_wide")
+    val loc4 = spark
+      .sql("DESCRIBE FORMATTED tmp4_wide")
+      .filter("col_name = 'Location'")
+      .select("data_type")
+      .collect()(0)
+      .getString(0)
+    val loc5 = spark
+      .sql("DESCRIBE FORMATTED tmp5_wide")
+      .filter("col_name = 'Location'")
+      .select("data_type")
+      .collect()(0)
+      .getString(0)
+    spark.sql(
+      s"CREATE TABLE tmp4 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc4'")
+    spark.sql(
+      s"CREATE TABLE tmp5 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc5'")
   }
 
   override protected def afterAll(): Unit = {
     spark.sql("drop table tmp1")
     spark.sql("drop table tmp2")
     spark.sql("drop table tmp3")
+    spark.sql("drop table tmp4_wide")
+    spark.sql("drop table tmp5_wide")
+    spark.sql("drop table tmp4")
+    spark.sql("drop table tmp5")
     super.afterAll()
   }
 
@@ -382,4 +424,63 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
       }
     }
   }
+
+  test("For decimal-key joins, if one side falls back to Spark, force fallback the other side") {
+    // ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
+    // The metastore tables tmp4/tmp5 declare DECIMAL(20, 0) and point to the
+    // same ORC files, so the reader must handle a precision/scale mismatch.
+    // Selecting only c2 (INT) -> native FileSourceScanExecTransformer.
+    // Selecting c3 (TIMESTAMP) in addition -> native validation fails ->
+    // vanilla FileSourceScanExec.
+
+    // Builds the decimal-key join of `left` and `right`. `hint` is the body of an optional
+    // join-strategy hint; c3 is projected for a side only when its flag is set. The column
+    // aliases do not depend on the table names, so results of different table pairs can be
+    // compared directly.
+    def joinSql(
+        left: String,
+        right: String,
+        hint: Option[String],
+        leftC3: Boolean,
+        rightC3: Boolean): String = {
+      val hintClause = hint.map(h => s"/*+ $h */ ").getOrElse("")
+      val columns = Seq(
+        Some(s"$left.c2 AS 4c2"),
+        if (leftC3) Some(s"$left.c3 AS 4c3") else None,
+        Some(s"$right.c2 AS 5c2"),
+        if (rightC3) Some(s"$right.c3 AS 5c3") else None
+      ).flatten.mkString(", ")
+      s"SELECT $hintClause$columns FROM $left JOIN $right ON $left.c1 = $right.c1"
+    }
+
+    // (tmp4 selects c3, tmp5 selects c3): c3 on both sides, on tmp4 only, on tmp5 only.
+    val c3Projections = Seq((true, true), (true, false), (false, true))
+
+    // Runs every projection over tmp4/tmp5 (with the given hint) and compares the result with
+    // the same un-hinted join over tmp4_wide/tmp5_wide, whose declared schema matches the files.
+    def checkJoins(hint: Option[String]): Unit = {
+      c3Projections.foreach {
+        case (leftC3, rightC3) =>
+          checkAnswer(
+            spark.sql(joinSql("tmp4", "tmp5", hint, leftC3, rightC3)),
+            spark.sql(joinSql("tmp4_wide", "tmp5_wide", None, leftC3, rightC3)))
+      }
+    }
+
+    // -- SortMergeJoin ------------------------------------------------------------------
+    withSQLConf(
+      GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
+      GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
+      checkJoins(Some("MERGE(tmp4)"))
+    }
+
+    // -- ShuffledHashJoin ---------------------------------------------------------------
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      checkJoins(Some("SHUFFLE_HASH(tmp4)"))
+    }
+
+    // -- BroadcastHashJoin --------------------------------------------------------------
+    // No hint and no conf override: the join strategy is left to the session defaults.
+    checkJoins(None)
+  }
 }