FallbackSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSh
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.utils.GlutenSuiteUtils

import scala.collection.mutable.ArrayBuffer
@@ -66,12 +67,53 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
.write
.format("parquet")
.saveAsTable("tmp3")
// ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
// The tables tmp4/tmp5 declare DECIMAL(20, 0) and point to the same ORC files,
// so the reader must handle a precision/scale mismatch.
spark
.range(100)
.selectExpr(
"cast(id as decimal(38, 18)) as c1",
"cast(id % 3 as int) as c2",
"cast(id % 9 as timestamp) as c3")
.write
.format("orc")
.saveAsTable("tmp4_wide")
spark
.range(100)
.selectExpr(
"cast(id as decimal(38, 18)) as c1",
"cast(id % 3 as int) as c2",
"cast(id % 5 as timestamp) as c3")
.write
.format("orc")
.saveAsTable("tmp5_wide")
val loc4 = spark
.sql("DESCRIBE FORMATTED tmp4_wide")
.filter("col_name = 'Location'")
.select("data_type")
.collect()(0)
.getString(0)
val loc5 = spark
.sql("DESCRIBE FORMATTED tmp5_wide")
.filter("col_name = 'Location'")
.select("data_type")
.collect()(0)
.getString(0)
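// Editor's note (a sketch, not part of the original change): the table location
// could also be read straight from the catalog, avoiding the DESCRIBE FORMATTED
// round trip -- assuming the default database and an import of
// org.apache.spark.sql.catalyst.TableIdentifier:
//   val loc4 = spark.sessionState.catalog
//     .getTableMetadata(TableIdentifier("tmp4_wide"))
//     .location
//     .toString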
spark.sql(
s"CREATE TABLE tmp4 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc4'")
spark.sql(
s"CREATE TABLE tmp5 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc5'")
}

override protected def afterAll(): Unit = {
spark.sql("drop table tmp1")
spark.sql("drop table tmp2")
spark.sql("drop table tmp3")
spark.sql("drop table tmp4_wide")
spark.sql("drop table tmp5_wide")
spark.sql("drop table tmp4")
spark.sql("drop table tmp5")

super.afterAll()
}
@@ -420,4 +462,127 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
spark.sparkContext.removeSparkListener(listener)
}
}

test("For decimal-key joins, if one side falls back to Spark, force fallback the other side") {
// ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
// The metastore tables tmp4/tmp5 declare DECIMAL(20, 0) and point to the
// same ORC files, so the reader must handle a precision/scale mismatch.
// Selecting only c2 (INT) -> native FileSourceScanExecTransformer.
// Selecting c3 (TIMESTAMP) in addition -> native validation fails ->
// vanilla FileSourceScanExec.
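// Editor's sketch (not in the original change) of how the forced fallback could
// additionally be asserted on the physical plan; FileSourceScanExecTransformer is
// assumed to live in org.apache.gluten.execution, and collectWithSubqueries comes
// from the AdaptiveSparkPlanHelper mixin:
//   val df = spark.sql(sql1)
//   df.collect()
//   val nativeScans = collectWithSubqueries(df.queryExecution.executedPlan) {
//     case t: org.apache.gluten.execution.FileSourceScanExecTransformer => t
//   }
//   assert(nativeScans.isEmpty) // both scans fell back to vanilla FileSourceScanExec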

// -- SortMergeJoin ------------------------------------------------------------------

val sql1 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql1),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

val sql2 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql2),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

val sql3 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql3),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

// -- ShuffledHashJoin ---------------------------------------------------------------

val sql4 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql4),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

val sql5 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql5),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

val sql6 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql6),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}

// -- BroadcastHashJoin --------------------------------------------------------------

val sql7 = "SELECT tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
checkAnswer(
spark.sql(sql7),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)

val sql8 = "SELECT tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
"tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
checkAnswer(
spark.sql(sql8),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
"tmp5_wide.c2 AS 5c2 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)

val sql9 = "SELECT tmp4.c2 AS 4c2, " +
"tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
checkAnswer(
spark.sql(sql9),
spark.sql(
"SELECT tmp4_wide.c2 AS 4c2, " +
"tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
"FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
)
}
}
GlutenHiveSQLQuerySuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveTableScanExecTransformer}
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.SQLConf

class GlutenHiveSQLQuerySuite extends GlutenHiveSQLQuerySuiteBase {

@@ -167,4 +168,164 @@ class GlutenHiveSQLQuerySuite extends GlutenHiveSQLQuerySuiteBase {
}
}
}

testGluten(
"GLUTEN-11980: For decimal-key joins, " +
"if one side falls back to Spark, force the other side to fall back") {
withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "false") {
withTable("htmp1_wide", "htmp2_wide", "htmp1", "htmp2") {
// ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
// The metastore tables htmp1/htmp2 declare DECIMAL(20, 0) and point to the
// same ORC files, so the reader must handle a precision/scale mismatch.
// Selecting only c2 (INT) -> native HiveTableScanExecTransformer.
// Selecting c3 (TIMESTAMP) in addition -> native validation fails ->
// vanilla HiveTableScanExec.
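// Editor's sketch (not in the original change): the already-imported
// HiveTableScanExecTransformer allows a plan-shape assertion as well; with AQE
// enabled, a subquery-aware traversal may be needed instead of plain collect.
// E.g. for sql1 defined below:
//   val plan = spark.sql(sql1).queryExecution.executedPlan
//   assert(plan.collect { case t: HiveTableScanExecTransformer => t }.isEmpty)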
sql("CREATE TABLE htmp1_wide (c1 DECIMAL(38, 18), c2 INT, c3 TIMESTAMP) STORED AS ORC")
sql("CREATE TABLE htmp2_wide (c1 DECIMAL(38, 18), c2 INT, c3 TIMESTAMP) STORED AS ORC")
sql("INSERT INTO htmp1_wide " +
"SELECT cast(id AS DECIMAL(38, 18)), id % 3, cast(id % 9 AS TIMESTAMP) " +
"FROM range(1, 101)")
sql("INSERT INTO htmp2_wide " +
"SELECT cast(id AS DECIMAL(38, 18)), id % 3, cast(id % 5 AS TIMESTAMP) " +
"FROM range(1, 101)")
val loc1 = sql("DESCRIBE FORMATTED htmp1_wide")
.filter("col_name = 'Location'")
.select("data_type")
.collect()(0)
.getString(0)
val loc2 = sql("DESCRIBE FORMATTED htmp2_wide")
.filter("col_name = 'Location'")
.select("data_type")
.collect()(0)
.getString(0)
sql("CREATE TABLE htmp1 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) " +
s"STORED AS ORC LOCATION '$loc1'")
sql("CREATE TABLE htmp2 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) " +
s"STORED AS ORC LOCATION '$loc2'")

// -- SortMergeJoin ------------------------------------------------------------------

val sql1 =
"SELECT /*+ MERGE(htmp1) */ htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql1),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

val sql2 =
"SELECT /*+ MERGE(htmp1) */ htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql2),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

val sql3 =
"SELECT /*+ MERGE(htmp1) */ htmp1.c2 AS 1c2, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(
GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
checkAnswer(
spark.sql(sql3),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

// -- ShuffledHashJoin ---------------------------------------------------------------

val sql4 =
"SELECT /*+ SHUFFLE_HASH(htmp1) */ htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql4),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

val sql5 =
"SELECT /*+ SHUFFLE_HASH(htmp1) */ htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql5),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

val sql6 =
"SELECT /*+ SHUFFLE_HASH(htmp1) */ htmp1.c2 AS 1c2, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkAnswer(
spark.sql(sql6),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}

// -- BroadcastHashJoin --------------------------------------------------------------

val sql7 =
"SELECT htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
checkAnswer(
spark.sql(sql7),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)

val sql8 =
"SELECT htmp1.c2 AS 1c2, htmp1.c3 AS 1c3, " +
"htmp2.c2 AS 2c2 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
checkAnswer(
spark.sql(sql8),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, htmp1_wide.c3 AS 1c3, " +
"htmp2_wide.c2 AS 2c2 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)

val sql9 =
"SELECT htmp1.c2 AS 1c2, " +
"htmp2.c2 AS 2c2, htmp2.c3 AS 2c3 FROM htmp1 JOIN htmp2 ON htmp1.c1 = htmp2.c1"
checkAnswer(
spark.sql(sql9),
spark.sql(
"SELECT htmp1_wide.c2 AS 1c2, " +
"htmp2_wide.c2 AS 2c2, htmp2_wide.c3 AS 2c3 " +
"FROM htmp1_wide JOIN htmp2_wide ON htmp1_wide.c1 = htmp2_wide.c1")
)
}
}
}
}
GlutenHiveSQLQuerySuiteBase.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution

import org.apache.gluten.execution.TransformSupport

-import org.apache.spark.SparkConf
+import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.internal.config
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.{DataFrame, GlutenSQLTestsTrait, SparkSession}
@@ -61,6 +61,16 @@ abstract class GlutenHiveSQLQuerySuiteBase extends GlutenSQLTestsTrait {
}
}

override def afterEach(): Unit = {
// Clear any file handles left open by Hive ORC's SplitGenerator background threads.
// OrcInputFormat$SplitGenerator.populateAndCacheStripeDetails() opens ORC readers
// via OrcFile.createReader() in background FutureTasks that are never explicitly closed
// (Hive bug HIVE-17183), leaking handles into DebugFilesystem.openStreams and causing
// SharedSparkSessionBase.afterEach() to abort the suite via assertNoOpenStreams().
DebugFilesystem.clearOpenStreams()
super.afterEach()
}
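// Editor's note (context assumed from Spark's SharedSparkSessionBase, not shown
// in this diff): DebugFilesystem only observes these leaks because the test
// session installs it as the implementation of the file:// scheme, e.g.
//   conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)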

protected def defaultSparkConf: SparkConf = {
val conf = new SparkConf()
.set("spark.master", "local[1]")