diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
index e9ea3e475346..697a48a03cfa 100644
--- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
+++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.files

 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.backendsapi.velox.VeloxBatchType
+import org.apache.gluten.columnarbatch.VeloxColumnarBatches
 import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution._
 import org.apache.gluten.execution.datasource.GlutenFormatFactory
@@ -45,6 +46,7 @@ import org.apache.spark.sql.execution.datasources.FileFormatWriter._
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SerializableConfiguration, Utils}

 import org.apache.hadoop.conf.Configuration
@@ -562,6 +564,14 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
       }
     }.toArray

+    private val reservePartitionColumns: Boolean =
+      description.partitionColumns.exists {
+        pcol =>
+          description.dataColumns.exists {
+            dcol => dcol.name == pcol.name && dcol.exprId == pcol.exprId
+          }
+      }
+
     private def beforeWrite(record: InternalRow): Unit = {
       val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
       val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
@@ -590,43 +600,87 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
       record match {
         case carrierRow: BatchCarrierRow =>
           carrierRow match {
-            case placeholderRow: PlaceholderRow =>
+            case _: PlaceholderRow =>
             // Do nothing.
             case terminalRow: TerminalRow =>
-              val numRows = terminalRow.batch().numRows()
-              if (numRows > 0) {
-                val blockStripes = GlutenFormatFactory.rowSplitter
-                  .splitBlockByPartitionAndBucket(
-                    terminalRow.batch(),
-                    partitionColIndice,
-                    isBucketed)
-                val iter = blockStripes.iterator()
-                while (iter.hasNext) {
-                  val blockStripe = iter.next()
-                  val headingRow = blockStripe.getHeadingRow
-                  beforeWrite(headingRow)
-                  val currentColumnBatch = blockStripe.getColumnarBatch
-                  val numRowsOfCurrentColumnarBatch = currentColumnBatch.numRows()
-                  assert(numRowsOfCurrentColumnarBatch > 0)
-                  val currentTerminalRow = terminalRow.withNewBatch(currentColumnBatch)
-                  currentWriter.write(currentTerminalRow)
-                  statsTrackers.foreach {
-                    tracker =>
-                      tracker.newRow(currentWriter.path, currentTerminalRow)
-                      for (_ <- 0 until numRowsOfCurrentColumnarBatch - 1) {
-                        tracker.newRow(currentWriter.path, new PlaceholderRow())
-                      }
-                  }
-                  currentColumnBatch.close()
-                }
-                blockStripes.release()
-                recordsInFile += numRows
-              }
+              writePartitionedBatch(terminalRow)
           }
         case _ =>
           beforeWrite(record)
           writeRecord(record)
       }
     }
+
+    private def writeCurrentBatch(terminalRow: TerminalRow, rowCount: Int): Unit = {
+      assert(rowCount > 0)
+      currentWriter.write(terminalRow)
+      statsTrackers.foreach(_.newRow(currentWriter.path, terminalRow))
+      recordsInFile += rowCount
+    }
+
+    private def writeCurrentBatchWithMaxRecords(
+        terminalRow: TerminalRow,
+        columnBatch: ColumnarBatch): Unit = {
+      val numRows = columnBatch.numRows()
+      var offset = 0
+      while (offset < numRows) {
+        val rowsRemaining = numRows - offset
+        val rowsToWrite = if (description.maxRecordsPerFile > 0) {
+          if (recordsInFile >= description.maxRecordsPerFile) {
+            renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId)
+          }
+          math.min(rowsRemaining.toLong, description.maxRecordsPerFile - recordsInFile).toInt
+        } else {
+          rowsRemaining
+        }
+
+        assert(rowsToWrite > 0)
+        val batchToWrite =
+          if (offset == 0 && rowsToWrite == numRows) {
+            columnBatch
+          } else {
+            VeloxColumnarBatches.slice(columnBatch, offset, rowsToWrite)
+          }
+        try {
+          writeCurrentBatch(terminalRow.withNewBatch(batchToWrite), rowsToWrite)
+        } finally {
+          if (batchToWrite ne columnBatch) {
+            batchToWrite.close()
+          }
+        }
+        offset += rowsToWrite
+      }
+    }
+
+    private def writePartitionStripe(terminalRow: TerminalRow, blockStripe: BlockStripe): Unit = {
+      beforeWrite(blockStripe.getHeadingRow)
+      val currentColumnBatch = blockStripe.getColumnarBatch
+      try {
+        assert(currentColumnBatch.numRows() > 0)
+        writeCurrentBatchWithMaxRecords(terminalRow, currentColumnBatch)
+      } finally {
+        currentColumnBatch.close()
+      }
+    }
+
+    private def writePartitionedBatch(terminalRow: TerminalRow): Unit = {
+      val numRows = terminalRow.batch().numRows()
+      if (numRows > 0) {
+        val blockStripes = GlutenFormatFactory.rowSplitter
+          .splitBlockByPartitionAndBucket(
+            terminalRow.batch(),
+            partitionColIndice,
+            isBucketed,
+            reservePartitionColumns)
+        try {
+          val iter = blockStripes.iterator()
+          while (iter.hasNext) {
+            writePartitionStripe(terminalRow, iter.next())
+          }
+        } finally {
+          blockStripes.release()
+        }
+      }
+    }
   }
 }
diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
index 70ca53679203..2c8405a254e8 100644
--- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
+++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
@@ -152,6 +152,7 @@ object GlutenDeltaJobStatsTracker extends Logging {
   }
   private val statsAttrs = aggregates.flatMap(_.aggregateFunction.aggBufferAttributes)
   private val statsResultAttrs = aggregates.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+  private val dataColIndices = dataCols.indices.toArray
   private val veloxAggTask: ColumnarBatchOutIterator = {
     val inputNode = StatisticsInputNode(Seq(dummyKeyAttr), dataCols)
     val aggOp = SortAggregateExec(
@@ -261,6 +262,16 @@ object GlutenDeltaJobStatsTracker extends Logging {
       case t: TerminalRow =>
        val valueBatch = t.batch()
        val numRows = valueBatch.numRows()
+        val statsValueBatch = if (valueBatch.numCols() == dataCols.size) {
+          valueBatch
+        } else {
+          assert(
+            valueBatch.numCols() > dataCols.size,
+            s"Delta stats input has ${valueBatch.numCols()} columns, " +
+              s"but the stats schema needs ${dataCols.size} columns."
+          )
+          ColumnarBatches.select(BackendsApiManager.getBackendName, valueBatch, dataColIndices)
+        }
        val dummyKeyVec = ArrowWritableColumnVector
          .allocateColumns(numRows, new StructType().add(dummyKeyAttr.name, IntegerType))
          .head
@@ -269,10 +280,15 @@ object GlutenDeltaJobStatsTracker extends Logging {
          ColumnarBatches.offload(
            ArrowBufferAllocators.contextInstance(),
            new ColumnarBatch(Array[ColumnVector](dummyKeyVec), numRows)))
-        val compositeBatch = VeloxColumnarBatches.compose(dummyKeyBatch, valueBatch)
-        dummyKeyBatch.close()
-        valueBatch.close()
-        inputBatchQueue.put(Some(compositeBatch))
+        try {
+          val compositeBatch = VeloxColumnarBatches.compose(dummyKeyBatch, statsValueBatch)
+          inputBatchQueue.put(Some(compositeBatch))
+        } finally {
+          dummyKeyBatch.close()
+          if (statsValueBatch ne valueBatch) {
+            statsValueBatch.close()
+          }
+        }
     }
   }
diff --git a/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala b/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
index f609a6130b84..3090e2e2127f 100644
--- a/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
+++ b/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.files

 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.backendsapi.velox.VeloxBatchType
+import org.apache.gluten.columnarbatch.VeloxColumnarBatches
 import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution._
 import org.apache.gluten.execution.datasource.GlutenFormatFactory
@@ -46,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.FileFormatWriter._
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SerializableConfiguration, Utils}

 import org.apache.hadoop.conf.Configuration
@@ -557,6 +559,14 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
       }
     }.toArray

+    private val reservePartitionColumns: Boolean =
+      description.partitionColumns.exists {
+        pcol =>
+          description.dataColumns.exists {
+            dcol => dcol.name == pcol.name && dcol.exprId == pcol.exprId
+          }
+      }
+
     private def beforeWrite(record: InternalRow): Unit = {
       val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
       val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
@@ -583,42 +593,88 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
       record match {
         case carrierRow: BatchCarrierRow =>
           carrierRow match {
-            case placeholderRow: PlaceholderRow =>
+            case _: PlaceholderRow =>
             // Do nothing.
             case terminalRow: TerminalRow =>
-              val numRows = terminalRow.batch().numRows()
-              if (numRows > 0) {
-                val blockStripes = GlutenFormatFactory.rowSplitter
-                  .splitBlockByPartitionAndBucket(terminalRow.batch(), partitionColIndice,
-                    isBucketed)
-                val iter = blockStripes.iterator()
-                while (iter.hasNext) {
-                  val blockStripe = iter.next()
-                  val headingRow = blockStripe.getHeadingRow
-                  beforeWrite(headingRow)
-                  val currentColumnBatch = blockStripe.getColumnarBatch
-                  val numRowsOfCurrentColumnarBatch = currentColumnBatch.numRows()
-                  assert(numRowsOfCurrentColumnarBatch > 0)
-                  val currentTerminalRow = terminalRow.withNewBatch(currentColumnBatch)
-                  currentWriter.write(currentTerminalRow)
-                  statsTrackers.foreach {
-                    tracker =>
-                      tracker.newRow(currentWriter.path, currentTerminalRow)
-                      for (_ <- 0 until numRowsOfCurrentColumnarBatch - 1) {
-                        tracker.newRow(currentWriter.path, new PlaceholderRow())
-                      }
-                  }
-                  currentColumnBatch.close()
-                }
-                blockStripes.release()
-                recordsInFile += numRows
-              }
+              writePartitionedBatch(terminalRow)
           }
         case _ =>
           beforeWrite(record)
           writeRecord(record)
       }
     }
+
+    private def writeCurrentBatch(terminalRow: TerminalRow, rowCount: Int): Unit = {
+      assert(rowCount > 0)
+      currentWriter.write(terminalRow)
+      statsTrackers.foreach(_.newRow(currentWriter.path, terminalRow))
+      recordsInFile += rowCount
+    }
+
+    private def writeCurrentBatchWithMaxRecords(
+        terminalRow: TerminalRow,
+        columnBatch: ColumnarBatch): Unit = {
+      val numRows = columnBatch.numRows()
+      var offset = 0
+      while (offset < numRows) {
+        val rowsRemaining = numRows - offset
+        val rowsToWrite = if (description.maxRecordsPerFile > 0) {
+          if (recordsInFile >= description.maxRecordsPerFile) {
+            renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId)
+          }
+          math.min(rowsRemaining.toLong, description.maxRecordsPerFile - recordsInFile).toInt
+        } else {
+          rowsRemaining
+        }
+
+        assert(rowsToWrite > 0)
+        val batchToWrite =
+          if (offset == 0 && rowsToWrite == numRows) {
+            columnBatch
+          } else {
+            VeloxColumnarBatches.slice(columnBatch, offset, rowsToWrite)
+          }
+        try {
+          writeCurrentBatch(terminalRow.withNewBatch(batchToWrite), rowsToWrite)
+        } finally {
+          if (batchToWrite ne columnBatch) {
+            batchToWrite.close()
+          }
+        }
+        offset += rowsToWrite
+      }
+    }
+
+    private def writePartitionStripe(terminalRow: TerminalRow, blockStripe: BlockStripe): Unit = {
+      beforeWrite(blockStripe.getHeadingRow)
+      val currentColumnBatch = blockStripe.getColumnarBatch
+      try {
+        assert(currentColumnBatch.numRows() > 0)
+        writeCurrentBatchWithMaxRecords(terminalRow, currentColumnBatch)
+      } finally {
+        currentColumnBatch.close()
+      }
+    }
+
+    private def writePartitionedBatch(terminalRow: TerminalRow): Unit = {
+      val numRows = terminalRow.batch().numRows()
+      if (numRows > 0) {
+        val blockStripes = GlutenFormatFactory.rowSplitter
+          .splitBlockByPartitionAndBucket(
+            terminalRow.batch(),
+            partitionColIndice,
+            isBucketed,
+            reservePartitionColumns)
+        try {
+          val iter = blockStripes.iterator()
+          while (iter.hasNext) {
+            writePartitionStripe(terminalRow, iter.next())
+          }
+        } finally {
+          blockStripes.release()
+        }
+      }
+    }
   }
 }
 // spotless:on
diff --git
a/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala b/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
index d046263b599b..740b46c2e6b2 100644
--- a/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
+++ b/backends-velox/src-delta40/main/scala/org/apache/spark/sql/delta/stats/GlutenDeltaJobStatsTracker.scala
@@ -156,6 +156,7 @@ object GlutenDeltaJobStatsTracker extends Logging {
   }
   private val statsAttrs = aggregates.flatMap(_.aggregateFunction.aggBufferAttributes)
   private val statsResultAttrs = aggregates.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+  private val dataColIndices = dataCols.indices.toArray
   private val veloxAggTask: ColumnarBatchOutIterator = {
     val inputNode = StatisticsInputNode(Seq(dummyKeyAttr), dataCols)
     val aggOp = SortAggregateExec(
@@ -265,6 +266,16 @@ object GlutenDeltaJobStatsTracker extends Logging {
       case t: TerminalRow =>
        val valueBatch = t.batch()
        val numRows = valueBatch.numRows()
+        val statsValueBatch = if (valueBatch.numCols() == dataCols.size) {
+          valueBatch
+        } else {
+          assert(
+            valueBatch.numCols() > dataCols.size,
+            s"Delta stats input has ${valueBatch.numCols()} columns, " +
+              s"but the stats schema needs ${dataCols.size} columns."
+          )
+          ColumnarBatches.select(BackendsApiManager.getBackendName, valueBatch, dataColIndices)
+        }
        val dummyKeyVec = ArrowWritableColumnVector
          .allocateColumns(numRows, new StructType().add(dummyKeyAttr.name, IntegerType))
          .head
@@ -273,10 +284,15 @@ object GlutenDeltaJobStatsTracker extends Logging {
          ColumnarBatches.offload(
            ArrowBufferAllocators.contextInstance(),
            new ColumnarBatch(Array[ColumnVector](dummyKeyVec), numRows)))
-        val compositeBatch = VeloxColumnarBatches.compose(dummyKeyBatch, valueBatch)
-        dummyKeyBatch.close()
-        valueBatch.close()
-        inputBatchQueue.put(Some(compositeBatch))
+        try {
+          val compositeBatch = VeloxColumnarBatches.compose(dummyKeyBatch, statsValueBatch)
+          inputBatchQueue.put(Some(compositeBatch))
+        } finally {
+          dummyKeyBatch.close()
+          if (statsValueBatch ne valueBatch) {
+            statsValueBatch.close()
+          }
+        }
     }
   }
diff --git a/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala b/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala
index bca0a66d1ad6..694339967ec4 100644
--- a/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala
+++ b/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.delta.actions.AddFile
 import org.apache.spark.sql.delta.commands.optimize.OptimizeMetrics
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+import org.apache.spark.sql.delta.util.JsonUtils
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.ExecutedCommandExec
@@ -45,14 +46,22 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest {
     .exists(_.toLowerCase(java.util.Locale.ROOT).contains("mac"))

   private def withNativeWriteOffloadConf[T](f: => T): T = {
+    withNativeWriteOffloadConf(collectStats = false)(f)
+  }
+
+  private def withNativeWriteOffloadConf[T](collectStats: Boolean)(f: => T): T = {
     val confs = Seq(
       SQLConf.ANSI_ENABLED.key -> "false",
       SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
       GlutenConfig.GLUTEN_ANSI_FALLBACK_ENABLED.key -> "false",
-      DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false"
+      DeltaSQLConf.DELTA_COLLECT_STATS.key -> collectStats.toString
     ) ++ (if (isMac) {
-            Seq(GlutenConfig.NATIVE_VALIDATION_ENABLED.key -> "false")
+            Seq(
+              GlutenConfig.NATIVE_VALIDATION_ENABLED.key -> "false",
+              GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false",
+              GlutenConfig.COLUMNAR_BATCHSCAN_ENABLED.key -> "false"
+            )
           } else {
             Seq.empty
           })
@@ -71,8 +80,9 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest {
       s"${GlutenConfig.GLUTEN_ANSI_FALLBACK_ENABLED.key} should be false in native write tests"
     )
     assert(
-      !spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS),
-      s"${DeltaSQLConf.DELTA_COLLECT_STATS.key} should be false in native write tests")
+      spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS) == collectStats,
+      s"${DeltaSQLConf.DELTA_COLLECT_STATS.key} should be $collectStats in native write tests"
+    )
     if (isMac) {
       assert(
         !spark.sessionState.conf
@@ -80,6 +90,18 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest {
           .toBoolean,
         s"${GlutenConfig.NATIVE_VALIDATION_ENABLED.key} should be false on macOS"
       )
+      assert(
+        !spark.sessionState.conf
+          .getConfString(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key)
+          .toBoolean,
+        s"${GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key} should be false on macOS"
+      )
+      assert(
+        !spark.sessionState.conf
+          .getConfString(GlutenConfig.COLUMNAR_BATCHSCAN_ENABLED.key)
+          .toBoolean,
+        s"${GlutenConfig.COLUMNAR_BATCHSCAN_ENABLED.key} should be false on macOS"
+      )
     }
     f
   }
@@ -208,6 +230,18 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest {
     }
   }

+  private def readStatsJson(stats: String): Map[String, Any] = {
+    assert(stats != null && stats.nonEmpty, "Expected Delta AddFile stats to be recorded")
+    JsonUtils.fromJson[Map[String, Any]](stats)
+  }
+
+  private def numRecords(stats: Map[String, Any]): Long = {
+    stats("numRecords") match {
+      case number: Number => number.longValue()
+      case other => other.toString.toLong
+    }
+  }
+
   test("native delta delete command should be offloaded") {
     withNativeWriteOffloadConf {
       withTempDir {
@@ -346,6 +380,96 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest {
     }
   }

+  test("native delta optimized partitioned write should collect stats and honor file layout") {
+    withNativeWriteOffloadConf(collectStats = true) {
+      withSQLConf(DeltaSQLConf.DELTA_OPTIMIZE_WRITE_ENABLED.key -> "true") {
+        withTempDir {
+          dir =>
+            val path = dir.getCanonicalPath
+            val maxRecordsPerFile = 4L
+            val input = spark
+              .range(0, 40, 1, 4)
+              .selectExpr(
+                "id",
+                "concat('v', cast(id as string)) as value",
+                "cast(id % 4 as int) as part")
+
+            val plans = collectExecutedPlans {
+              input.write
+                .format("delta")
+                .partitionBy("part")
+                .option("maxRecordsPerFile", maxRecordsPerFile.toString)
+                .mode("overwrite")
+                .save(path)
+            }
+
+            assertContainsNativeWriteCommand(
+              plans,
+              "optimized partitioned DataFrameWriter.save(overwrite) with stats")
+            assert(spark.read.format("delta").load(path).collect().toSet == input.collect().toSet)
+
+            val addFiles = DeltaLog.forTable(spark, path).update().allFiles.collect()
+            assert(addFiles.nonEmpty, "Expected Delta write to add files")
+            val fileStats = addFiles.map(add => add -> readStatsJson(add.stats))
+            assert(fileStats.map { case (_, stat) => numRecords(stat) }.sum == 40)
+            fileStats.foreach {
+              case (_, stat) =>
+                val fileNumRecords = numRecords(stat)
+                assert(
+                  fileNumRecords <= maxRecordsPerFile,
+                  s"Expected at most $maxRecordsPerFile rows per file, got $fileNumRecords")
$fileNumRecords") + assert(stat.contains("minValues"), s"Missing minValues in stats: $stat") + assert(stat.contains("maxValues"), s"Missing maxValues in stats: $stat") + assert(stat.contains("nullCount"), s"Missing nullCount in stats: $stat") + } + val recordsByPartition = fileStats + .groupBy { case (add, _) => add.partitionValues("part") } + .map { + case (partition, files) => + partition -> files.map { case (_, stat) => numRecords(stat) }.sum + } + assert(recordsByPartition == Map("0" -> 10L, "1" -> 10L, "2" -> 10L, "3" -> 10L)) + } + } + } + } + + test("native delta Iceberg-compatible partitioned write should collect stats") { + withNativeWriteOffloadConf(collectStats = true) { + withSQLConf(DeltaSQLConf.DELTA_OPTIMIZE_WRITE_ENABLED.key -> "true") { + withTempDir { + dir => + val path = dir.getCanonicalPath + val input = spark + .range(0, 12, 1, 3) + .selectExpr("id", "cast(id % 3 as int) as part") + + val plans = collectExecutedPlans { + input.write + .format("delta") + .option(DeltaConfigs.COLUMN_MAPPING_MODE.key, "name") + .option(DeltaConfigs.ICEBERG_COMPAT_V1_ENABLED.key, "true") + .partitionBy("part") + .mode("overwrite") + .save(path) + } + + assertContainsNativeWriteCommand( + plans, + "Iceberg-compatible partitioned DataFrameWriter.save(overwrite) with stats") + assert(spark.read.format("delta").load(path).collect().toSet == input.collect().toSet) + + val snapshot = DeltaLog.forTable(spark, path).update() + assert(IcebergCompatV1.isEnabled(snapshot.metadata)) + val addFiles = snapshot.allFiles.collect() + assert(addFiles.nonEmpty, "Expected Delta write to add files") + val stats = addFiles.map(add => readStatsJson(add.stats)) + assert(stats.map(numRecords).sum == 12) + } + } + } + } + test("native delta optimize command should be offloaded") { withNativeWriteOffloadConf { withTempDir { diff --git a/backends-velox/src/main/java/org/apache/gluten/datasource/VeloxDataSourceJniWrapper.java b/backends-velox/src/main/java/org/apache/gluten/datasource/VeloxDataSourceJniWrapper.java index 23f071aff1f8..176da92d5214 100644 --- a/backends-velox/src/main/java/org/apache/gluten/datasource/VeloxDataSourceJniWrapper.java +++ b/backends-velox/src/main/java/org/apache/gluten/datasource/VeloxDataSourceJniWrapper.java @@ -54,5 +54,8 @@ public long init(String filePath, long cSchema, Map options) { public native void writeBatch(long dsHandle, long batchHandle); public native BlockStripes splitBlockByPartitionAndBucket( - long blockAddress, int[] partitionColIndice, boolean hasBucket); + long blockAddress, + int[] partitionColIndice, + boolean hasBucket, + boolean reservePartitionColumns); } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/velox/VeloxFormatWriterInjects.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/velox/VeloxFormatWriterInjects.scala index 06e9d91c0fa6..96b01098f324 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/velox/VeloxFormatWriterInjects.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/velox/VeloxFormatWriterInjects.scala @@ -119,6 +119,10 @@ class VeloxRowSplitter extends GlutenRowSplitter { val datasourceJniWrapper = VeloxDataSourceJniWrapper.create(runtime) new VeloxBlockStripes( datasourceJniWrapper - .splitBlockByPartitionAndBucket(handler, partitionColIndice, hasBucket)) + .splitBlockByPartitionAndBucket( + handler, + partitionColIndice, + hasBucket, + reservePartitionColumns)) } } diff --git 
a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index e30413d6d357..d80947cc226d 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -583,7 +583,8 @@ Java_org_apache_gluten_datasource_VeloxDataSourceJniWrapper_splitBlockByPartitio
     jobject wrapper,
     jlong batchHandle,
     jintArray partitionColIndices,
-    jboolean hasBucket) {
+    jboolean hasBucket,
+    jboolean reservePartitionColumns) {
   JNI_METHOD_START

   GLUTEN_CHECK(!hasBucket, "Bucketing not supported by splitBlockByPartitionAndBucket");
@@ -600,13 +601,16 @@ Java_org_apache_gluten_datasource_VeloxDataSourceJniWrapper_splitBlockByPartitio
     partitionColIndicesVec.emplace_back(partitionColumnIndex);
   }

-  std::vector<int32_t> dataColIndicesVec;
+  std::vector<int32_t> outputColIndicesVec;
   for (int i = 0; i < batch->numColumns(); ++i) {
     if (std::find(partitionColIndicesVec.begin(), partitionColIndicesVec.end(), i) == partitionColIndicesVec.end()) {
-      // The column is not a partition column. Add it to the data column vector.
-      dataColIndicesVec.emplace_back(i);
+      // Write data columns first, matching Spark's dynamic partition writer output schema.
+      outputColIndicesVec.emplace_back(i);
     }
   }
+  if (reservePartitionColumns) {
+    outputColIndicesVec.insert(outputColIndicesVec.end(), partitionColIndicesVec.begin(), partitionColIndicesVec.end());
+  }

   auto pool = dynamic_cast<VeloxMemoryManager*>(ctx->memoryManager())->getLeafMemoryPool();
   const auto veloxBatch = VeloxColumnarBatch::from(pool.get(), batch);
@@ -656,9 +660,8 @@ Java_org_apache_gluten_datasource_VeloxDataSourceJniWrapper_splitBlockByPartitio
         : exec::wrap(partitionSize, partitionRows[partitionId], inputRowVector);
     const std::shared_ptr<VeloxColumnarBatch> partitionBatch = std::make_shared<VeloxColumnarBatch>(rowVector);
-    const std::shared_ptr<ColumnarBatch> partitionBatchWithoutPartitionColumns =
-        partitionBatch->select(pool.get(), dataColIndicesVec);
-    partitionBatchHandles[partitionId] = ctx->saveObject(partitionBatchWithoutPartitionColumns);
+    const std::shared_ptr<ColumnarBatch> outputBatch = partitionBatch->select(pool.get(), outputColIndicesVec);
+    partitionBatchHandles[partitionId] = ctx->saveObject(outputBatch);
     const auto headingRow = partitionBatch->toUnsafeRow(0);
     const auto headingRowBytes = headingRow.data();
     const auto headingRowNumBytes = headingRow.size();