[GLUTEN-8836][CH] Support partition values with escape char (#8840)
lwz9103 authored Mar 5, 2025
1 parent a1db382 commit be909b6
Showing 13 changed files with 453 additions and 109 deletions.
@@ -19,6 +19,10 @@ package org.apache.gluten.execution
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.types.StructType

import org.apache.hadoop.fs.Path

import java.net.URI

case class MergeTreePartRange(
name: String,
dirName: String,
@@ -32,7 +36,7 @@ case class MergeTreePartRange(
}
}

case class MergeTreePartSplit(
case class MergeTreePartSplit private (
name: String,
dirName: String,
targetNode: String,
@@ -44,6 +48,22 @@ case class MergeTreePartSplit(
}
}

object MergeTreePartSplit {
def apply(
name: String,
dirName: String,
targetNode: String,
start: Long,
length: Long,
bytesOnDisk: Long
): MergeTreePartSplit = {
// Ref to org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath
val uriDecodeName = new Path(new URI(name)).toString
val uriDecodeDirName = new Path(new URI(dirName)).toString
new MergeTreePartSplit(uriDecodeName, uriDecodeDirName, targetNode, start, length, bytesOnDisk)
}
}

case class GlutenMergeTreePartition(
index: Int,
engine: String,
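
For reference, a minimal sketch (not part of this commit) of the encode/decode pair used above, assuming Hadoop Path/URI semantics; the directory name is hypothetical:

    import org.apache.hadoop.fs.Path
    import java.net.URI

    // Encoding: characters that are illegal in a URI are percent-escaped,
    // which is what AddFileTags below now stores in the Delta log.
    val raw = "part_col=a b"                           // hypothetical partition directory name
    val encoded = new Path(raw).toUri.toString         // "part_col=a%20b"

    // Decoding: MergeTreePartSplit.apply recovers the original name,
    // mirroring org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath.
    val decoded = new Path(new URI(encoded)).toString  // "part_col=a b"
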
@@ -52,7 +52,7 @@ trait MergeTreeFileCommitProtocol extends FileCommitProtocol {
dir: Option[String],
ext: String): String = {

val partitionStr = dir.map(p => new Path(p).toUri.toString)
val partitionStr = dir.map(p => new Path(p).toString)
val bucketIdStr = ext.split("\\.").headOption.filter(_.startsWith("_")).map(_.substring(1))
val split = taskContext.getTaskAttemptID.getTaskID.getId

@@ -152,7 +152,8 @@ object AddFileTags {
rootNode.put("nullCount", "")
// Add the `stats` into delta meta log
val metricsStats = mapper.writeValueAsString(rootNode)
AddFile(name, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
val uriName = new Path(name).toUri.toString
AddFile(uriName, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
}

def addFileToAddMergeTreeParts(addFile: AddFile): AddMergeTreeParts = {
@@ -23,12 +23,14 @@ import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData

import org.apache.spark.SparkConf
import org.apache.spark.gluten.NativeWriteChecker
import org.apache.spark.sql.Row
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.types._

import scala.reflect.runtime.universe.TypeTag
import java.io.File
import java.sql.Date

class GlutenClickHouseNativeWriteTableSuite
extends GlutenClickHouseWholeStageTransformerSuite
@@ -67,12 +69,6 @@ class GlutenClickHouseNativeWriteTableSuite
.setMaster("local[1]")
}

private def getWarehouseDir = {
// test non-ascii path, by the way
// scalastyle:off nonascii
basePath + "/中文/spark-warehouse"
}

private val table_name_template = "hive_%s_test"
private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"

@@ -81,58 +77,7 @@ class GlutenClickHouseNativeWriteTableSuite
super.afterAll()
}

def getColumnName(s: String): String = {
s.replaceAll("\\(", "_").replaceAll("\\)", "_")
}

import collection.immutable.ListMap

import java.io.File

def compareSource(original_table: String, table_name: String, fields: Seq[String]): Unit = {
val rowsFromOriginTable =
spark.sql(s"select ${fields.mkString(",")} from $original_table").collect()
val dfFromWriteTable =
spark.sql(
s"select " +
s"${fields
.map(getColumnName)
.mkString(",")} " +
s"from $table_name")
checkAnswer(dfFromWriteTable, rowsFromOriginTable)
}
def writeAndCheckRead(
original_table: String,
table_name: String,
fields: Seq[String],
checkNative: Boolean = true)(write: Seq[String] => Unit): Unit =
withDestinationTable(table_name) {
withNativeWriteCheck(checkNative) {
write(fields)
}
compareSource(original_table, table_name, fields)
}

def recursiveListFiles(f: File): Array[File] = {
val these = f.listFiles
these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
}

def getSignature(format: String, filesOfNativeWriter: Array[File]): Array[(Long, Long)] = {
filesOfNativeWriter.map(
f => {
val df = if (format.equals("parquet")) {
spark.read.parquet(f.getAbsolutePath)
} else {
spark.read.orc(f.getAbsolutePath)
}
(
df.count(),
df.agg(("int_field", "sum")).collect().apply(0).apply(0).asInstanceOf[Long]
)
})
}

private val fields_ = ListMap(
("string_field", "string"),
("int_field", "int"),
@@ -146,22 +91,6 @@ class GlutenClickHouseNativeWriteTableSuite
("date_field", "date")
)

def nativeWrite2(
f: String => (String, String, String),
extraCheck: (String, String) => Unit = null,
checkNative: Boolean = true): Unit = nativeWrite {
format =>
val (table_name, table_create_sql, insert_sql) = f(format)
withDestinationTable(table_name, Option(table_create_sql)) {
checkInsertQuery(insert_sql, checkNative)
Option(extraCheck).foreach(_(table_name, format))
}
}

def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)(
block: => Unit): Unit =
withSource(spark.createDataFrame(data), viewName, pairs: _*)(block)

private lazy val supplierSchema = StructType.apply(
Seq(
StructField.apply("s_suppkey", LongType, nullable = true),
@@ -618,18 +547,7 @@ class GlutenClickHouseNativeWriteTableSuite
.saveAsTable(table_name_vanilla)
}
}
val sigsOfNativeWriter =
getSignature(
format,
recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
.filter(_.getName.endsWith(s".$format"))).sorted
val sigsOfVanillaWriter =
getSignature(
format,
recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
.filter(_.getName.endsWith(s".$format"))).sorted

assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
}
}
}
@@ -680,18 +598,7 @@ class GlutenClickHouseNativeWriteTableSuite
.bucketBy(10, "byte_field", "string_field")
.saveAsTable(table_name_vanilla)
}
val sigsOfNativeWriter =
getSignature(
format,
recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
.filter(_.getName.endsWith(s".$format"))).sorted
val sigsOfVanillaWriter =
getSignature(
format,
recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
.filter(_.getName.endsWith(s".$format"))).sorted

assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
}
}
}
@@ -754,6 +661,63 @@ class GlutenClickHouseNativeWriteTableSuite
}
}

test("test partitioned with escaped characters") {

val schema = StructType(
Seq(
StructField.apply("id", IntegerType, nullable = true),
StructField.apply("escape", StringType, nullable = true),
StructField.apply("bucket/col", StringType, nullable = true),
StructField.apply("part=col1", DateType, nullable = true),
StructField.apply("part_col2", StringType, nullable = true)
))

val data: Seq[Row] = Seq(
Row(1, "=", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
Row(2, "/", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
Row(3, "#", "00000", Date.valueOf("2024-01-01"), "2024#01:01"),
Row(4, ":", "00001", Date.valueOf("2024-01-02"), "2024#01:01"),
Row(5, "\\", "00001", Date.valueOf("2024-01-02"), "2024\\01\u000101"),
Row(6, "\u0001", "000001", Date.valueOf("2024-01-02"), "2024\\01\u000101"),
Row(7, "", "000002", null, null)
)

val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
df.createOrReplaceTempView("origin_table")
spark.sql("select * from origin_table").show()

nativeWrite {
format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
writeAndCheckRead("origin_table", table_name, schema.fieldNames.map(f => s"`$f`")) {
_ =>
spark
.table("origin_table")
.write
.format(format)
.partitionBy("part=col1", "part_col2")
.bucketBy(2, "bucket/col")
.saveAsTable(table_name)
}

val table_name_vanilla = table_name_vanilla_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name_vanilla")
withSQLConf((GlutenConfig.NATIVE_WRITER_ENABLED.key, "false")) {
withNativeWriteCheck(checkNative = false) {
spark
.table("origin_table")
.write
.format(format)
.partitionBy("part=col1", "part_col2")
.bucketBy(2, "bucket/col")
.saveAsTable(table_name_vanilla)
}
compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(id)")
}
}
}
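
For the directory layout this test exercises, a minimal sketch (not part of this commit), assuming Spark escapes both the partition column name and its value with ExternalCatalogUtils.escapePathName when building the directory name:

    import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

    // Values taken from the rows above; '=' and '/' are percent-escaped so the
    // partition value stays inside a single directory level.
    val dir = s"${escapePathName("part=col1")}=${escapePathName("2024-01-01")}/" +
      s"${escapePathName("part_col2")}=${escapePathName("2024=01/01")}"
    // dir == "part%3Dcol1=2024-01-01/part_col2=2024%3D01%2F01"
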

test("test bucketed by constant") {
nativeWrite {
format =>
@@ -21,12 +21,13 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.{FileSourceScanExecTransformer, GlutenClickHouseTPCHAbstractSuite}

import org.apache.spark.SparkConf
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.delta.catalog.ClickHouseTableV2
import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types._

import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
@@ -359,6 +360,54 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite
spark.sql("drop table lineitem_mergetree_partition_hdfs")
}

test("test partition values with escape chars") {

val schema = StructType(
Seq(
StructField.apply("id", IntegerType, nullable = true),
StructField.apply("escape", StringType, nullable = true)
))

// scalastyle:off nonascii
val data: Seq[Row] = Seq(
Row(1, "="),
Row(2, "/"),
Row(3, "#"),
Row(4, ":"),
Row(5, "\\"),
Row(6, "\u0001"),
Row(7, "中文"),
Row(8, " "),
Row(9, "a b")
)
// scalastyle:on nonascii

val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
df.createOrReplaceTempView("origin_table")

// spark.conf.set("spark.gluten.enabled", "false")
spark.sql(s"""
|DROP TABLE IF EXISTS partition_escape;
|""".stripMargin)

spark.sql(s"""
|CREATE TABLE IF NOT EXISTS partition_escape
|(
| c1 int,
| c2 string
|)
|USING clickhouse
|PARTITIONED BY (c2)
|TBLPROPERTIES (storage_policy='__hdfs_main',
| orderByKey='c1',
| primaryKey='c1')
|LOCATION '$HDFS_URL/test/partition_escape'
|""".stripMargin)

spark.sql("insert into partition_escape select * from origin_table")
spark.sql("select * from partition_escape").show()
}
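
A possible follow-up check (not part of this commit): filtering on one of the escaped partition values should still resolve to the right part directory, for example:

    // Hypothetical query against the table created above.
    spark.sql("select count(*) from partition_escape where c2 = '/'").show()
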

testSparkVersionLE33("test mergetree write with bucket table") {
spark.sql(s"""
|DROP TABLE IF EXISTS lineitem_mergetree_bucket_hdfs;