Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman committed Oct 11, 2024
1 parent 6f0b43d commit df32c7e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
8 changes: 6 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,8 @@ def test_hash_groupby_typed_imperative_agg_without_gpu_implementation_fallback()
@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
@pytest.mark.parametrize('data_gen', _init_list, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_multiple_mode_query(data_gen, conf):
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
def test_hash_multiple_mode_query(data_gen, conf, shuffle_split):
print_params(data_gen)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
Expand All @@ -1132,7 +1133,10 @@ def test_hash_multiple_mode_query(data_gen, conf):
f.max('a'),
f.sumDistinct('b'),
f.countDistinct('c')
), conf=conf)
),
conf=copy_and_update(
conf,
{'spark.rapids.shuffle.splitRetryRead.enabled': shuffle_split}))


@approximate_float
Expand Down
8 changes: 6 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,25 @@ def test_hash_join_ridealong_non_sized(data_gen, join_type, sub_part_enabled):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn)
@pytest.mark.parametrize('join_type', all_symmetric_sized_join_types, ids=idfn)
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_hash_join_ridealong_symmetric(data_gen, join_type):
def test_hash_join_ridealong_symmetric(data_gen, join_type, shuffle_split):
confs = {
"spark.rapids.sql.join.useShuffledSymmetricHashJoin": "true",
"spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split,
}
hash_join_ridealong(data_gen, join_type, confs)

@validate_execs_in_gpu_plan('GpuShuffledAsymmetricHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn)
@pytest.mark.parametrize('join_type', all_asymmetric_sized_join_types, ids=idfn)
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_hash_join_ridealong_asymmetric(data_gen, join_type):
def test_hash_join_ridealong_asymmetric(data_gen, join_type, shuffle_split):
confs = {
"spark.rapids.sql.join.useShuffledAsymmetricHashJoin": "true",
"spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split,
}
hash_join_ridealong(data_gen, join_type, confs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ object GpuShuffleCoalesceUtils {
reader.prefetchHeadOnHost()
}
}
println("===> use GpuShuffleCoalesce Reader")
reader.asIterator
} else {
val hostIter = new HostShuffleCoalesceIterator(iter, targetSize, metricsMap)
Expand Down

0 comments on commit df32c7e

Please sign in to comment.