Skip to content

Commit 3feaafd

Browse files
authored
[SYCLomatic] Migrate thrust::merge_by_key (#756)
Signed-off-by: Lu, John <[email protected]>
1 parent 7b1c02e commit 3feaafd

File tree

5 files changed

+231
-1
lines changed

5 files changed

+231
-1
lines changed

clang/lib/DPCT/APINamesThrust.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ thrustFactory("thrust::gather_if",
6262
{{7,PolicyState::HasPolicy,5,"dpct::gather_if", HelperFeatureEnum::DplExtrasAlgorithm_gather_if},
6363
{6,PolicyState::NoPolicy ,5,"dpct::gather_if", HelperFeatureEnum::DplExtrasAlgorithm_gather_if}}),
6464

65+
// thrust::merge_by_key
// Four rows, all rewriting to "dpct::merge" with the DplExtrasAlgorithm_merge
// helper feature: two HasPolicy rows and two NoPolicy rows.
// NOTE(review): the leading number presumably is the call's argument count
// (with and without a trailing comparator) -- confirm against thrustFactory.
66+
thrustFactory("thrust::merge_by_key",
67+
{{10,PolicyState::HasPolicy,8,"dpct::merge", HelperFeatureEnum::DplExtrasAlgorithm_merge},
68+
{ 9,PolicyState::HasPolicy,8,"dpct::merge", HelperFeatureEnum::DplExtrasAlgorithm_merge},
69+
{ 9,PolicyState::NoPolicy ,8,"dpct::merge", HelperFeatureEnum::DplExtrasAlgorithm_merge},
70+
{ 8,PolicyState::NoPolicy ,8,"dpct::merge", HelperFeatureEnum::DplExtrasAlgorithm_merge}}),
71+
6572
// thrust::inclusive_scan
6673
thrustFactory("thrust::inclusive_scan",
6774
{{5,PolicyState::HasPolicy,3,"oneapi::dpl::inclusive_scan", HelperFeatureEnum::no_feature_helper },

clang/lib/DPCT/APINames_thrust.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ ENTRY(thrust::max_element, thrust::max_element, true, NO_FLAG, P4, "Successful")
101101
ENTRY(thrust::max, thrust::max, true, NO_FLAG, P4, "Successful")
102102
ENTRY(thrust::min, thrust::min, true, NO_FLAG, P4, "Successful")
103103
ENTRY(thrust::merge, thrust::merge, true, NO_FLAG, P4, "Successful")
104-
ENTRY(thrust::merge_by_key, thrust::merge_by_key, false, NO_FLAG, P4, "comment")
104+
ENTRY(thrust::merge_by_key, thrust::merge_by_key, true, NO_FLAG, P4, "Successful")
105105
ENTRY(thrust::min_element, thrust::min_element, true, NO_FLAG, P4, "Successful")
106106
ENTRY(thrust::minmax_element, thrust::minmax_element, true, NO_FLAG, P4, "comment")
107107
ENTRY(thrust::mismatch, thrust::mismatch, false, NO_FLAG, P4, "comment")
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-12.0
2+
// UNSUPPORTED: v8.0, v12.0
3+
// RUN: dpct --format-range=none --usm-level=none --use-custom-helper=api -out-root %T/DplExtrasAlgorithm/api_test23_out %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only -std=c++17 -fsized-deallocation
4+
// RUN: grep "IsCalled" %T/DplExtrasAlgorithm/api_test23_out/MainSourceFiles.yaml | wc -l > %T/DplExtrasAlgorithm/api_test23_out/count.txt
5+
// RUN: FileCheck --input-file %T/DplExtrasAlgorithm/api_test23_out/count.txt --match-full-lines %s
6+
// RUN: rm -rf %T/DplExtrasAlgorithm/api_test23_out
7+
8+
// CHECK: 43
9+
// TEST_FEATURE: DplExtrasAlgorithm_merge
10+
11+
#include <thrust/merge.h>
12+
#include <thrust/device_vector.h>
13+
14+
// Driver for the helper-feature test: it issues a single thrust::merge_by_key
// call (no execution policy, no comparator) on device_vector iterators.
// NOTE(review): by argument position AD/BD look like the two key inputs,
// CD/DD the two value inputs, and ED/FD the merged key/value outputs -- confirm.
int main() {
15+
thrust::device_vector<int> AD(4);
16+
thrust::device_vector<int> BD(4);
17+
thrust::device_vector<int> CD(4);
18+
thrust::device_vector<int> DD(4);
19+
thrust::device_vector<int> ED(8);
20+
thrust::device_vector<int> FD(8);
21+
22+
thrust::merge_by_key( AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
23+
return 0;
24+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// UNSUPPORTED: cuda-8.0
2+
// UNSUPPORTED: v8.0
3+
// RUN: dpct -out-root %T/thrust_merge_by_key %s --cuda-include-path="%cuda-path/include" --usm-level=none
4+
// RUN: FileCheck --input-file %T/thrust_merge_by_key/thrust_merge_by_key.dp.cpp --match-full-lines %s
5+
6+
#include <thrust/device_vector.h>
7+
#include <thrust/host_vector.h>
8+
#include <thrust/execution_policy.h>
9+
10+
// Migration test body for thrust::merge_by_key. Each group below pairs the
// expected migrated output (the FileCheck directives) with the Thrust calls
// that produce it. Within a group the calls must stay adjacent, with no extra
// lines between them, because the directives expect the migrated statements
// on consecutive output lines.
int main(void) {
11+
12+
thrust::device_vector<int> AD(4);
13+
thrust::device_vector<int> BD(4);
14+
thrust::device_vector<int> CD(4);
15+
thrust::device_vector<int> DD(4);
16+
thrust::device_vector<int> ED(8);
17+
thrust::device_vector<int> FD(8);
18+
19+
thrust::host_vector<int> AH(4);
20+
thrust::host_vector<int> BH(4);
21+
thrust::host_vector<int> CH(4);
22+
thrust::host_vector<int> DH(4);
23+
thrust::host_vector<int> EH(8);
24+
thrust::host_vector<int> FH(8);
25+
26+
27+
// Raw pointers for the pointer-based variants: h_ptr is filled by std::malloc
// and d_ptr by cudaMalloc. Neither buffer is released in this function.
int *h_ptr;
28+
int *d_ptr;
29+
30+
h_ptr = (int*)std::malloc(20 * sizeof(int));
31+
cudaMalloc(&d_ptr, 20 * sizeof(int));
32+
33+
/*******************************************************************************************
34+
1. Test merge_by_key
35+
2. Test four VERSIONs (with/without exec argument) AND (with/without comparator)
36+
3. Test each VERSION with (device_vector/host_vector/malloc-ed memory/cudaMalloc-ed memory)
37+
*******************************************************************************************/
38+
39+
/*********** merge_by_key ***********************************************************************************************************************************************/
40+
41+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
42+
// CHECK:dpct::merge(oneapi::dpl::execution::seq, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
43+
// CHECK-NEXT:dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
44+
// CHECK-NEXT:if (dpct::is_device_ptr(h_ptr + 4)) {
45+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(h_ptr), dpct::device_pointer<int>(h_ptr + 4), dpct::device_pointer<>(BH.begin()), dpct::device_pointer<>(BH.end()), dpct::device_pointer<>(CH.begin()), dpct::device_pointer<>(DH.begin()), dpct::device_pointer<>(EH.begin()), dpct::device_pointer<>(FH.begin()));
46+
// CHECK-NEXT:} else {
47+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, h_ptr, h_ptr + 4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
48+
// CHECK-NEXT:};
49+
// VERSION first1 last1 first2 last2 val1 val2 keys values
50+
thrust::merge_by_key( AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
51+
thrust::merge_by_key( AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
52+
thrust::merge_by_key( h_ptr, h_ptr+4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
53+
// The policy-less call on a raw device pointer (d_ptr) stays commented out: this overload is not supported with thrust.
54+
// thrust::merge_by_key( d_ptr, d_ptr+4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
55+
56+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
57+
// CHECK:dpct::merge(oneapi::dpl::execution::seq, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), std::greater<int>());
58+
// CHECK-NEXT:dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), std::greater<int>());
59+
// CHECK-NEXT:if (dpct::is_device_ptr(h_ptr + 4)) {
60+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(h_ptr), dpct::device_pointer<int>(h_ptr + 4), dpct::device_pointer<>(BH.begin()), dpct::device_pointer<>(BH.end()), dpct::device_pointer<>(CH.begin()), dpct::device_pointer<>(DH.begin()), dpct::device_pointer<>(EH.begin()), dpct::device_pointer<>(FH.begin()), std::greater<int>());
61+
// CHECK-NEXT:} else {
62+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, h_ptr, h_ptr + 4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), std::greater<int>());
63+
// CHECK-NEXT:};
64+
// VERSION first1 last1 first2 last2 val1 val2 keys values comparator
65+
thrust::merge_by_key( AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), thrust::greater<int>());
66+
thrust::merge_by_key( AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), thrust::greater<int>());
67+
thrust::merge_by_key( h_ptr, h_ptr+4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), thrust::greater<int>());
68+
#ifdef ADD_BUG
69+
// Built only when ADD_BUG is defined, because this call fails with nvcc.
70+
thrust::merge_by_key( d_ptr, d_ptr+4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), thrust::greater<int>());
71+
#endif
72+
73+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
74+
// CHECK:dpct::merge(oneapi::dpl::execution::seq, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
75+
// CHECK-NEXT:dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
76+
// CHECK-NEXT:if (dpct::is_device_ptr(h_ptr)) {
77+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(h_ptr), dpct::device_pointer<int>(h_ptr + 4), dpct::device_pointer<>(BH.begin()), dpct::device_pointer<>(BH.end()), dpct::device_pointer<>(CH.begin()), dpct::device_pointer<>(DH.begin()), dpct::device_pointer<>(EH.begin()), dpct::device_pointer<>(FH.begin()));
78+
// CHECK-NEXT:} else {
79+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, h_ptr, h_ptr + 4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
80+
// CHECK-NEXT:};
81+
// CHECK-NEXT:if (dpct::is_device_ptr(d_ptr)) {
82+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(d_ptr), dpct::device_pointer<int>(d_ptr + 4), dpct::device_pointer<>(BD.begin()), dpct::device_pointer<>(BD.end()), dpct::device_pointer<>(CD.begin()), dpct::device_pointer<>(DD.begin()), dpct::device_pointer<>(ED.begin()), dpct::device_pointer<>(FD.begin()));
83+
// CHECK-NEXT:} else {
84+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, d_ptr, d_ptr + 4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
85+
// CHECK-NEXT:};
86+
// VERSION first1 last1 first2 last2 val1 val2 keys values
87+
thrust::merge_by_key(thrust::host, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
88+
thrust::merge_by_key(thrust::device, AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
89+
thrust::merge_by_key(thrust::host, h_ptr, h_ptr+4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin());
90+
thrust::merge_by_key(thrust::device, d_ptr, d_ptr+4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin());
91+
92+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
93+
// CHECK:dpct::merge(oneapi::dpl::execution::seq, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), std::greater<int>());
94+
// CHECK-NEXT:dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), std::greater<int>());
95+
// CHECK-NEXT:if (dpct::is_device_ptr(h_ptr)) {
96+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(h_ptr), dpct::device_pointer<int>(h_ptr + 4), dpct::device_pointer<>(BH.begin()), dpct::device_pointer<>(BH.end()), dpct::device_pointer<>(CH.begin()), dpct::device_pointer<>(DH.begin()), dpct::device_pointer<>(EH.begin()), dpct::device_pointer<>(FH.begin()), std::greater<int>());
97+
// CHECK-NEXT:} else {
98+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, h_ptr, h_ptr + 4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), std::greater<int>());
99+
// CHECK-NEXT:};
100+
// CHECK-NEXT:if (dpct::is_device_ptr(d_ptr)) {
101+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(d_ptr), dpct::device_pointer<int>(d_ptr + 4), dpct::device_pointer<>(BD.begin()), dpct::device_pointer<>(BD.end()), dpct::device_pointer<>(CD.begin()), dpct::device_pointer<>(DD.begin()), dpct::device_pointer<>(ED.begin()), dpct::device_pointer<>(FD.begin()), std::greater<int>());
102+
// CHECK-NEXT:} else {
103+
// CHECK-NEXT: dpct::merge(oneapi::dpl::execution::seq, d_ptr, d_ptr + 4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), std::greater<int>());
104+
// CHECK-NEXT:};
105+
// VERSION first1 last1 first2 last2 val1 val2 keys values comparator
106+
thrust::merge_by_key(thrust::host, AH.begin(), AH.end(), BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), thrust::greater<int>());
107+
thrust::merge_by_key(thrust::device, AD.begin(), AD.end(), BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), thrust::greater<int>());
108+
thrust::merge_by_key(thrust::host, h_ptr, h_ptr+4, BH.begin(), BH.end(), CH.begin(), DH.begin(), EH.begin(), FH.begin(), thrust::greater<int>());
109+
thrust::merge_by_key(thrust::device, d_ptr, d_ptr+4, BD.begin(), BD.end(), CD.begin(), DD.begin(), ED.begin(), FD.begin(), thrust::greater<int>());
110+
111+
return 0;
112+
}

0 commit comments

Comments
 (0)