@@ -20,7 +20,7 @@ import java.lang.Math.toIntExact

import scala.collection.JavaConverters._

import ai.rapids.cudf.{ColumnVector => CudfColumnVector, OrderByArg, Scalar, Table}
import ai.rapids.cudf.{ColumnVector => CudfColumnVector, Table}
import com.nvidia.spark.rapids.{GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, RapidsHostColumnVector, SpillableColumnarBatch, SpillPriorities}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
@@ -54,10 +54,37 @@ class GpuIcebergPartitioner(val spec: PartitionSpec,
private val partitionExprs: Seq[GpuExpression] = spec.fields().asScala.map(getPartitionExpr).toSeq

private val keyColNum: Int = spec.fields().size()
private val inputColNum: Int = dataSparkType.fields.length

// key column indices in the table: [key columns, input columns]
private val keyColIndices: Array[Int] = (0 until keyColNum).toArray
private val keySortOrders: Array[OrderByArg] = (0 until keyColNum)
.map(OrderByArg.asc(_, true))
.toArray
// input column indices in the table: [key columns, input columns]
private val inputColumnIndices: Array[Int] = (keyColNum until (keyColNum + inputColNum)).toArray

/**
* Make a new table: [key columns, input columns]
*/
private def makeKeysAndInputTable(spillableInput: SpillableColumnarBatch): Table = {
withResource(spillableInput.getColumnarBatch()) { inputBatch =>
// compute key columns
val keyCols = partitionExprs.safeMap(_.columnarEval(inputBatch))

// combine key columns and input columns into a new table
withResource(keyCols) { _ =>
withResource(GpuColumnVector.from(inputBatch)) { inputTable =>
val numCols = keyCols.size + inputTable.getNumberOfColumns
val cols = new Array[CudfColumnVector](numCols)
for (i <- keyCols.indices) {
cols(i) = keyCols(i).getBase
}
for (i <- 0 until inputTable.getNumberOfColumns) {
cols(i + keyCols.size) = inputTable.getColumn(i)
}
new Table(cols:_*)
}
}
}
}
Comment on lines +67 to +87
Contributor

logic: the Table created on line 83 holds column references taken from keyCols and inputTable, but those source objects are closed by their withResource wrappers (lines 73, 74) before the table is returned, so the returned table would contain dangling references to freed GPU memory.

The columns need to be incRefCount() before being added to the new table, or the table construction needs to happen before the source resources are closed.
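
For context, the ownership reasoning in this comment (and the one further down) rests on the two Arm helpers imported at the top of the file. A simplified sketch of their contracts, assuming the usual try/finally shape (the real com.nvidia.spark.rapids.Arm also provides overloads for options and collections):

object ArmSketch {
  // withResource: the block only borrows the resource; it is closed on
  // every exit path, so references escaping the block without their own
  // incRefCount() are dangling.
  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V =
    try block(r) finally r.close()

  // closeOnExcept: the resource is closed only if the block throws; on
  // success the caller keeps ownership and must close it later.
  def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V =
    try block(r) catch {
      case t: Throwable =>
        r.close()
        throw t
    }
}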


/**
* Partition the `input` columnar batch using Iceberg's partition spec.
@@ -70,94 +97,41 @@ class GpuIcebergPartitioner(val spec: PartitionSpec,
return Seq.empty
}

val numRows = input.numRows()

val spillableInput = closeOnExcept(input) { _ =>
SpillableColumnarBatch(input, ACTIVE_ON_DECK_PRIORITY)
}

val (partitionKeys, partitions) = withRetryNoSplit(spillableInput) { scb =>
val parts = withResource(scb.getColumnarBatch()) { inputBatch =>
partitionExprs.safeMap(_.columnarEval(inputBatch))
}
val keysTable = withResource(parts) { _ =>
val arr = new Array[CudfColumnVector](partitionExprs.size)
for (i <- partitionExprs.indices) {
arr(i) = parts(i).getBase
}
new Table(arr:_*)
}

val sortedKeyTableWithRowIdx = withResource(keysTable) { _ =>
withResource(Scalar.fromInt(0)) { zero =>
withResource(CudfColumnVector.sequence(zero, numRows)) { rowIdxCol =>
val totalColCount = keysTable.getNumberOfColumns + 1
val allCols = new Array[CudfColumnVector](totalColCount)

for (i <- 0 until keysTable.getNumberOfColumns) {
allCols(i) = keysTable.getColumn(i)
}
allCols(keysTable.getNumberOfColumns) = rowIdxCol

withResource(new Table(allCols: _*)) { allColsTable =>
allColsTable.orderBy(keySortOrders: _*)
}
}
}
}

val (sortedPartitionKeys, splitIds, rowIdxCol) = withResource(sortedKeyTableWithRowIdx) { _ =>
val uniqueKeysTable = sortedKeyTableWithRowIdx.groupBy(keyColIndices: _*)
.aggregate()

val sortedUniqueKeysTable = withResource(uniqueKeysTable) { _ =>
uniqueKeysTable.orderBy(keySortOrders: _*)
}

val (sortedPartitionKeys, splitIds) = withResource(sortedUniqueKeysTable) { _ =>
val partitionKeys = toPartitionKeys(spec.partitionType(),
partitionSparkType,
sortedUniqueKeysTable)

val splitIdsCv = sortedKeyTableWithRowIdx.upperBound(
sortedUniqueKeysTable,
keySortOrders: _*)

val splitIds = withResource(splitIdsCv) { _ =>
GpuColumnVector.toIntArray(splitIdsCv)
}

(partitionKeys, splitIds)
}
withRetryNoSplit(spillableInput) { scb =>
// make table: [key columns, input columns]
val keysAndInputTable = makeKeysAndInputTable(scb)
Contributor

logic: the keysAndInputTable variable is assigned outside withResource but used inside the closure on lines 115-116; if an exception occurs between lines 110 and 114, the table won't be properly closed.
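
A minimal sketch of one way to address this, per the comment above (an illustration, not the committed code): create the table directly inside the withResource call, leaving no window between creation and ownership transfer:

// Hypothetical rearrangement; all names come from this PR's diff.
val splitRet = withResource(makeKeysAndInputTable(scb)) { keysAndInputTable =>
  // split the input columns by the key columns; the result does not
  // contain the key columns themselves
  keysAndInputTable.groupBy(keyColIndices: _*)
    .contiguousSplitGroupsAndGenUniqKeys(inputColumnIndices)
}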


val rowIdxCol = sortedKeyTableWithRowIdx.getColumn(keyColNum).incRefCount()
(sortedPartitionKeys, splitIds, rowIdxCol)
// split the input columns by the key columns,
// note: the result does not contain the key columns
val splitRet = withResource(keysAndInputTable) { _ =>
keysAndInputTable.groupBy(keyColIndices: _*)
.contiguousSplitGroupsAndGenUniqKeys(inputColumnIndices)
}

withResource(rowIdxCol) { _ =>
val inputTable = withResource(scb.getColumnarBatch()) { inputBatch =>
GpuColumnVector.from(inputBatch)
}
// generate results
withResource(splitRet) { _ =>
// generate the partition keys on the host side
val partitionKeys = toPartitionKeys(spec.partitionType(),
partitionSparkType,
splitRet.getUniqKeyTable)

val sortedDataTable = withResource(inputTable) { _ =>
inputTable.gather(rowIdxCol)
}
// release unique table to save GPU memory
splitRet.closeUniqKeyTable()

val partitions = withResource(sortedDataTable) { _ =>
sortedDataTable.contiguousSplit(splitIds: _*)
}
// get the partitions
val partitions = splitRet.getGroups

(sortedPartitionKeys, partitions)
// combine the partition keys and partitioned tables
partitionKeys.zip(partitions).map { case (partKey, partition) =>
ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType, SpillPriorities
.ACTIVE_BATCHING_PRIORITY), partKey)
}.toSeq
}
}

withResource(partitions) { _ =>
partitionKeys.zip(partitions).map { case (partKey, partition) =>
ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType, SpillPriorities
.ACTIVE_BATCHING_PRIORITY), partKey)
}.toSeq
}

}

private def getPartitionExpr(field: PartitionField)
@@ -208,4 +182,4 @@ object GpuIcebergPartitioner {
}).toArray
}
}
}
}