diff --git a/include/boost/compute/algorithm/detail/radix_sort.hpp b/include/boost/compute/algorithm/detail/radix_sort.hpp index 8e6d5f9c0..53b1205c7 100644 --- a/include/boost/compute/algorithm/detail/radix_sort.hpp +++ b/include/boost/compute/algorithm/detail/radix_sort.hpp @@ -17,6 +17,9 @@ #include #include +#include +#include + #include #include #include @@ -305,9 +308,12 @@ inline void radix_sort_impl(const buffer_iterator first, options << " -DASC"; } + // get type definition if it is a custom struct + std::string custom_type_def = boost::compute::type_definition() + "\n"; + // load radix sort program program radix_sort_program = cache->get_or_build( - cache_key, options.str(), radix_sort_source, context + cache_key, options.str(), custom_type_def + radix_sort_source, context ); kernel count_kernel(radix_sort_program, "count"); diff --git a/include/boost/compute/type_traits/type_definition.hpp b/include/boost/compute/type_traits/type_definition.hpp index de9095fbd..3dcc4607f 100644 --- a/include/boost/compute/type_traits/type_definition.hpp +++ b/include/boost/compute/type_traits/type_definition.hpp @@ -18,7 +18,10 @@ namespace compute { namespace detail { template -struct type_definition_trait; +struct type_definition_trait +{ + static std::string value() { return std::string(); } +}; } // end detail namespace diff --git a/test/test_sort_by_key.cpp b/test/test_sort_by_key.cpp index ab99fd2fe..978d81d36 100644 --- a/test/test_sort_by_key.cpp +++ b/test/test_sort_by_key.cpp @@ -15,6 +15,16 @@ #include #include #include +#include + +struct custom_struct +{ + boost::compute::int_ x; + boost::compute::int_ y; + boost::compute::float2_ zw; +}; + +BOOST_COMPUTE_ADAPT_STRUCT(custom_struct, custom_struct, (x, y, zw)) #include "check_macros.hpp" #include "context_setup.hpp" @@ -69,15 +79,15 @@ BOOST_AUTO_TEST_CASE(sort_int_2) BOOST_AUTO_TEST_CASE(sort_char_by_int) { int keys_data[] = { 6, 2, 1, 3, 4, 7, 5, 0 }; - char values_data[] = { 'g', 'c', 'b', 'd', 'e', 'h', 'f', 'a' }; + compute::char_ values_data[] = { 'g', 'c', 'b', 'd', 'e', 'h', 'f', 'a' }; compute::vector keys(keys_data, keys_data + 8, queue); - compute::vector values(values_data, values_data + 8, queue); + compute::vector values(values_data, values_data + 8, queue); compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue); CHECK_RANGE_EQUAL(int, 8, keys, (0, 1, 2, 3, 4, 5, 6, 7)); - CHECK_RANGE_EQUAL(char, 8, values, ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h')); + CHECK_RANGE_EQUAL(compute::char_, 8, values, ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h')); } BOOST_AUTO_TEST_CASE(sort_int_and_float) @@ -132,4 +142,66 @@ BOOST_AUTO_TEST_CASE(sort_int_and_float_custom_comparison_func) BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), queue) == true); } +BOOST_AUTO_TEST_CASE(sort_int_and_float2) +{ + using boost::compute::int_; + using boost::compute::float2_; + + int n = 1024; + std::vector host_keys(n); + std::vector host_values(n); + for(int i = 0; i < n; i++){ + host_keys[i] = n - i; + host_values[i] = float2_((n - i) / 2.f); + } + + BOOST_COMPUTE_FUNCTION(bool, sort_float2, (float2_ a, float2_ b), + { + return a.x < b.x; + }); + + compute::vector keys(host_keys.begin(), host_keys.end(), queue); + compute::vector values(host_values.begin(), host_values.end(), queue); + + BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == false); + BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_float2, queue) == false); + + compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue); + + BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == true); + BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_float2, queue) == true); +} + +BOOST_AUTO_TEST_CASE(sort_custom_struct_by_int) +{ + using boost::compute::int_; + using boost::compute::float2_; + + int_ n = 1024; + std::vector host_keys(n); + std::vector host_values(n); + for(int_ i = 0; i < n; i++){ + host_keys[i] = n - i; + host_values[i].x = n - i; + host_values[i].y = n - i; + host_values[i].zw = float2_((n - i) / 0.5f); + } + + BOOST_COMPUTE_FUNCTION(bool, sort_custom_struct, (custom_struct a, custom_struct b), + { + return a.x < b.x; + }); + + compute::vector keys(host_keys.begin(), host_keys.end(), queue); + compute::vector values(host_values.begin(), host_values.end(), queue); + + BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == false); + BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_custom_struct, queue) == false); + + compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue); + + BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == true); + BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_custom_struct, queue) == true); +} + BOOST_AUTO_TEST_SUITE_END()