Skip to content

Commit

Permalink
Cherry picking bfloat16 radix sort test for 4.5 (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
stanleytsang-amd authored Oct 7, 2021
1 parent 33c1436 commit fa2d3b8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
15 changes: 6 additions & 9 deletions test/rocprim/test_device_radix_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ typedef ::testing::Types<
params<double, int, true>,
params<float, int>,
params<rocprim::half, long long>,
//TODO: Disable bfloat16 test until we get a better bfloat16 implemetation for host side
//params<rocprim::bfloat16, long long>,
params<rocprim::bfloat16, long long>,
params<int8_t, int8_t>,
params<uint8_t, uint8_t>,
params<rocprim::half, rocprim::half>,
//TODO: Disable bfloat16 test until we get a better bfloat16 implemetation for host side
//params<rocprim::bfloat16, rocprim::bfloat16>,
params<rocprim::bfloat16, rocprim::bfloat16>,
params<int, test_utils::custom_test_type<float>>,

// start_bit and end_bit
Expand All @@ -80,9 +78,8 @@ typedef ::testing::Types<
params<unsigned int, double, true, 4, 21>,
params<unsigned int, rocprim::half, true, 0, 15>,
params<unsigned short, rocprim::half, false, 3, 22>,
//TODO: Disable bfloat16 test until we get a better bfloat16 implemetation for host side
//params<unsigned int, rocprim::bfloat16, true, 0, 15>,
//params<unsigned short, rocprim::bfloat16, false, 3, 22>,
params<unsigned int, rocprim::bfloat16, true, 0, 12>,
params<unsigned short, rocprim::bfloat16, false, 3, 11>,
params<unsigned long long, char, false, 8, 20>,
params<unsigned short, test_utils::custom_test_type<double>, false, 8, 11>,

Expand Down Expand Up @@ -292,7 +289,7 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs)
}

std::vector<value_type> values_input(size);
std::iota(values_input.begin(), values_input.end(), 0);
test_utils::iota(values_input.begin(), values_input.end(), 0);

key_type * d_keys_input;
key_type * d_keys_output;
Expand Down Expand Up @@ -599,7 +596,7 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairsDoubleBuffer)
}

std::vector<value_type> values_input(size);
std::iota(values_input.begin(), values_input.end(), 0);
test_utils::iota(values_input.begin(), values_input.end(), 0);

key_type * d_keys_input;
key_type * d_keys_output;
Expand Down
10 changes: 10 additions & 0 deletions test/rocprim/test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,16 @@ void assert_eq(const rocprim::bfloat16& result, const rocprim::bfloat16& expecte
ASSERT_EQ(bfloat16_to_native(result), bfloat16_to_native(expected));
}

//TODO: Use custom iota until the follwing PR merge: https://github.com/ROCm-Developer-Tools/HIP/pull/2303
template<class ForwardIt, class T>
void iota(ForwardIt first, ForwardIt last, T value)
{
using value_type = typename std::iterator_traits<ForwardIt>::value_type;
while(first != last) {
*first++ = static_cast<value_type>(value);
++value;
}
}

} // end test_utils namespace

Expand Down

0 comments on commit fa2d3b8

Please sign in to comment.