Skip to content

Commit

Permalink
Addressed test_cast_date_integral_and_fp.
Browse files Browse the repository at this point in the history
  • Loading branch information
mythrocks committed Oct 2, 2024
1 parent 80e97da commit eeeb4a0
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,28 @@ def test_cast_long_to_decimal_overflow():
DecimalType(30, 3),
DecimalType(5, -3),
DecimalType(3, 0)], ids=idfn)
def test_cast_floating_point_to_decimal(data_gen, to_type):
def test_cast_floating_point_to_decimal_ansi_off(data_gen, to_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a'), f.col('a').cast(to_type)),
conf={'spark.rapids.sql.castFloatToDecimal.enabled': 'true'})
conf=copy_and_update(
ansi_disabled_conf,
{'spark.rapids.sql.castFloatToDecimal.enabled': 'true'}))


@pytest.mark.skip("https://github.com/NVIDIA/spark-rapids/issues/11550")
@pytest.mark.parametrize('data_gen', [FloatGen(special_cases=_float_special_cases)])
@pytest.mark.parametrize('to_type', [DecimalType(7, 1)])
def test_cast_floating_point_to_decimal_ansi_on(data_gen, to_type):
assert_gpu_and_cpu_error(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a'),
f.col('a').cast(to_type)).collect(),
conf=copy_and_update(
ansi_enabled_conf,
{'spark.rapids.sql.castFloatToDecimal.enabled': 'true'}),
error_message="[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")


# casting these types to string should be passed
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
Expand Down Expand Up @@ -838,7 +855,14 @@ def test_cast_fallback_not_UTC(from_gen, to_type):
{"spark.sql.session.timeZone": "+08",
"spark.rapids.sql.castStringToTimestamp.enabled": "true"})

def test_cast_date_integral_and_fp():

def test_cast_date_integral_and_fp_ansi_off():
"""
This tests that a date column can be cast to different numeric/floating-point types.
This needs to be tested with ANSI disabled, because some of these conversions are
not ANSI-compliant.
"""
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, date_gen).selectExpr(
"cast(a as boolean)", "cast(a as byte)", "cast(a as short)", "cast(a as int)", "cast(a as long)", "cast(a as float)", "cast(a as double)"))
"cast(a as boolean)", "cast(a as byte)", "cast(a as short)", "cast(a as int)", "cast(a as long)", "cast(a as float)", "cast(a as double)"),
conf=ansi_disabled_conf)

0 comments on commit eeeb4a0

Please sign in to comment.