Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix collection_ops_tests for Spark 4.0 [databricks] #11414

Merged
merged 10 commits into from
Oct 12, 2024
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/collection_ops_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *

from src.main.python.spark_session import is_before_spark_400
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
Expand Down Expand Up @@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() else "Unsuccessful try to create array with"
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Optional

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
Expand Down Expand Up @@ -1535,7 +1535,8 @@ object GpuSequenceUtil {
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
step: ColumnVector,
functionName: String): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
Expand All @@ -1557,7 +1558,11 @@ object GpuSequenceUtil {
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
require(isAllValidTrue(allValid),
GetSequenceSize.TOO_LONG_SEQUENCE(maxSizeScalar.getLong.asInstanceOf[Int],
functionName))
}
}
}
// cast to int and return
Expand Down Expand Up @@ -1597,7 +1602,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSize(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String) = {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
// For these Spark versions, the sequence length and function name
// do not appear in the exception message.
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}

/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
import org.apache.spark.sql.rapids.shims.SequenceSizeError
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String =
SequenceSizeError.getTooLongSequenceErrorString(sequenceLength, functionName)
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
{"spark": "342"}
{"spark": "343"}
{"spark": "351"}
{"spark": "352"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object SequenceSizeError {
def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
// The errant function's name does not feature in the exception message
// prior to Spark 4.0. Neither does the attempted allocation size.
"Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object SequenceSizeError {
def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
.getMessage
}
}