@@ -20,7 +20,7 @@ import java.lang.Math.toIntExact
 
 import scala.collection.JavaConverters._
 
-import ai.rapids.cudf.{ColumnVector => CudfColumnVector, OrderByArg, Scalar, Table}
+import ai.rapids.cudf.{ColumnVector => CudfColumnVector, Table}
 import com.nvidia.spark.rapids.{GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, RapidsHostColumnVector, SpillableColumnarBatch, SpillPriorities}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
@@ -54,10 +54,37 @@ class GpuIcebergPartitioner(val spec: PartitionSpec,
   private val partitionExprs: Seq[GpuExpression] = spec.fields().asScala.map(getPartitionExpr).toSeq
 
   private val keyColNum: Int = spec.fields().size()
+  private val inputColNum: Int = dataSparkType.fields.length
+
+  // key column indices in the table: [key columns, input columns]
   private val keyColIndices: Array[Int] = (0 until keyColNum).toArray
-  private val keySortOrders: Array[OrderByArg] = (0 until keyColNum)
-    .map(OrderByArg.asc(_, true))
-    .toArray
+  // input column indices in the table: [key columns, input columns]
+  private val inputColumnIndices: Array[Int] = (keyColNum until (keyColNum + inputColNum)).toArray
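+  // e.g. with 2 partition keys and 3 input columns:
+  // keyColIndices = [0, 1], inputColumnIndices = [2, 3, 4]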
+
+  /**
+   * Make a new table laid out as [key columns, input columns], where the key
+   * columns are computed by evaluating the partition expressions on the input.
+   */
+  private def makeKeysAndInputTable(spillableInput: SpillableColumnarBatch): Table = {
+    withResource(spillableInput.getColumnarBatch()) { inputBatch =>
+      // compute the key columns
+      val keyCols = partitionExprs.safeMap(_.columnarEval(inputBatch))
+
+      // combine the key columns and input columns into a new table
+      withResource(keyCols) { _ =>
+        withResource(GpuColumnVector.from(inputBatch)) { inputTable =>
+          val numCols = keyCols.size + inputTable.getNumberOfColumns
+          val cols = new Array[CudfColumnVector](numCols)
+          for (i <- keyCols.indices) {
+            cols(i) = keyCols(i).getBase
+          }
+          for (i <- 0 until inputTable.getNumberOfColumns) {
+            cols(i + keyCols.size) = inputTable.getColumn(i)
+          }
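+          // the Table constructor increments each column's reference count, so
+          // the enclosing withResource blocks can safely close keyCols and
+          // inputTable once the table has been built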
+          new Table(cols: _*)
+        }
+      }
+    }
+  }
 
   /**
    * Partition the `input` columnar batch using Iceberg's partition spec.
@@ -70,94 +97,41 @@ class GpuIcebergPartitioner(val spec: PartitionSpec,
       return Seq.empty
     }
 
-    val numRows = input.numRows()
-
     val spillableInput = closeOnExcept(input) { _ =>
       SpillableColumnarBatch(input, ACTIVE_ON_DECK_PRIORITY)
     }
 
-    val (partitionKeys, partitions) = withRetryNoSplit(spillableInput) { scb =>
-      val parts = withResource(scb.getColumnarBatch()) { inputBatch =>
-        partitionExprs.safeMap(_.columnarEval(inputBatch))
-      }
-      val keysTable = withResource(parts) { _ =>
-        val arr = new Array[CudfColumnVector](partitionExprs.size)
-        for (i <- partitionExprs.indices) {
-          arr(i) = parts(i).getBase
-        }
-        new Table(arr: _*)
-      }
-
-      val sortedKeyTableWithRowIdx = withResource(keysTable) { _ =>
-        withResource(Scalar.fromInt(0)) { zero =>
-          withResource(CudfColumnVector.sequence(zero, numRows)) { rowIdxCol =>
-            val totalColCount = keysTable.getNumberOfColumns + 1
-            val allCols = new Array[CudfColumnVector](totalColCount)
-
-            for (i <- 0 until keysTable.getNumberOfColumns) {
-              allCols(i) = keysTable.getColumn(i)
-            }
-            allCols(keysTable.getNumberOfColumns) = rowIdxCol
-
-            withResource(new Table(allCols: _*)) { allColsTable =>
-              allColsTable.orderBy(keySortOrders: _*)
-            }
-          }
-        }
-      }
-
-      val (sortedPartitionKeys, splitIds, rowIdxCol) = withResource(sortedKeyTableWithRowIdx) { _ =>
-        val uniqueKeysTable = sortedKeyTableWithRowIdx.groupBy(keyColIndices: _*)
-          .aggregate()
-
-        val sortedUniqueKeysTable = withResource(uniqueKeysTable) { _ =>
-          uniqueKeysTable.orderBy(keySortOrders: _*)
-        }
-
-        val (sortedPartitionKeys, splitIds) = withResource(sortedUniqueKeysTable) { _ =>
-          val partitionKeys = toPartitionKeys(spec.partitionType(),
-            partitionSparkType,
-            sortedUniqueKeysTable)
-
-          val splitIdsCv = sortedKeyTableWithRowIdx.upperBound(
-            sortedUniqueKeysTable,
-            keySortOrders: _*)
-
-          val splitIds = withResource(splitIdsCv) { _ =>
-            GpuColumnVector.toIntArray(splitIdsCv)
-          }
-
-          (partitionKeys, splitIds)
-        }
+    withRetryNoSplit(spillableInput) { scb =>
+      // make a table laid out as: [key columns, input columns]
+      val keysAndInputTable = makeKeysAndInputTable(scb)
 
-      val rowIdxCol = sortedKeyTableWithRowIdx.getColumn(keyColNum).incRefCount()
-      (sortedPartitionKeys, splitIds, rowIdxCol)
+      // split the input columns into groups by the key columns;
+      // note: the result does not contain the key columns
+      val splitRet = withResource(keysAndInputTable) { _ =>
+        keysAndInputTable.groupBy(keyColIndices: _*)
+          .contiguousSplitGroupsAndGenUniqKeys(inputColumnIndices)
       }
 
-      withResource(rowIdxCol) { _ =>
-        val inputTable = withResource(scb.getColumnarBatch()) { inputBatch =>
-          GpuColumnVector.from(inputBatch)
-        }
+      // generate the results
+      withResource(splitRet) { _ =>
+        // generate the partition keys on the host side
+        val partitionKeys = toPartitionKeys(spec.partitionType(),
+          partitionSparkType,
+          splitRet.getUniqKeyTable)
 
-        val sortedDataTable = withResource(inputTable) { _ =>
-          inputTable.gather(rowIdxCol)
-        }
+        // release the unique-keys table early to save GPU memory
+        splitRet.closeUniqKeyTable()
 
-        val partitions = withResource(sortedDataTable) { _ =>
-          sortedDataTable.contiguousSplit(splitIds: _*)
-        }
+        // get the partitions
+        val partitions = splitRet.getGroups
 
-        (sortedPartitionKeys, partitions)
+        // combine the partition keys with the partitioned tables
+        partitionKeys.zip(partitions).map { case (partKey, partition) =>
+          ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType, SpillPriorities
+            .ACTIVE_BATCHING_PRIORITY), partKey)
+        }.toSeq
       }
     }
-
-    withResource(partitions) { _ =>
-      partitionKeys.zip(partitions).map { case (partKey, partition) =>
-        ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType, SpillPriorities
-          .ACTIVE_BATCHING_PRIORITY), partKey)
-      }.toSeq
-    }
-
   }
 
   private def getPartitionExpr(field: PartitionField)
@@ -208,4 +182,4 @@ object GpuIcebergPartitioner {
     }).toArray
   }
 }
-}
+}
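
For context, a minimal sketch (not part of this change) of the cudf group-by call the new code builds on; `table` is a hypothetical two-column table whose column 0 is the grouping key and column 1 the payload, mirroring the accessors used in the diff:

    // split `table` into one contiguous sub-table per distinct key and also
    // produce a table of the distinct keys themselves, in matching order
    val ret = table.groupBy(0).contiguousSplitGroupsAndGenUniqKeys(Array(1))
    withResource(ret) { r =>
      val keys = r.getUniqKeyTable   // one row per distinct key
      val groups = r.getGroups       // payload column only, one table per key
      assert(keys.getRowCount == groups.length)
    }

As in the diff, closeUniqKeyTable() can release the unique-keys table early once the keys have been copied to the host.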