Skip to content

Commit db19b77

Browse files
authored
[SYCLomatic] Enable migration of thrust::find_if and thrust::find_if_not (#777)
Signed-off-by: chenwei.sun <[email protected]>
1 parent 3feaafd commit db19b77

File tree

4 files changed

+98
-2
lines changed

4 files changed

+98
-2
lines changed

clang/lib/DPCT/APINamesThrust.inc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ thrustFactory("thrust::any_of",
4545
{{4,PolicyState::HasPolicy,2,"oneapi::dpl::any_of", HelperFeatureEnum::no_feature_helper },
4646
{3,PolicyState::NoPolicy ,2,"oneapi::dpl::any_of", HelperFeatureEnum::no_feature_helper }}),
4747

48+
// thrust::find_if
49+
thrustFactory("thrust::find_if",
50+
{{4,PolicyState::HasPolicy,2,"oneapi::dpl::find_if", HelperFeatureEnum::no_feature_helper },
51+
{3,PolicyState::NoPolicy ,2,"oneapi::dpl::find_if", HelperFeatureEnum::no_feature_helper }}),
52+
53+
// thrust::find_if_not
54+
thrustFactory("thrust::find_if_not",
55+
{{4,PolicyState::HasPolicy,2,"oneapi::dpl::find_if_not", HelperFeatureEnum::no_feature_helper },
56+
{3,PolicyState::NoPolicy ,2,"oneapi::dpl::find_if_not", HelperFeatureEnum::no_feature_helper }}),
57+
4858
// thrust::replace
4959
thrustFactory("thrust::replace",
5060
{{5,PolicyState::HasPolicy,2,"oneapi::dpl::replace", HelperFeatureEnum::no_feature_helper },

clang/lib/DPCT/APINames_thrust.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ ENTRY(thrust::exp, thrust::exp, true, NO_FLAG, P4, "Successful")
6767
ENTRY(thrust::fill, thrust::fill, true, NO_FLAG, P4, "Successful")
6868
ENTRY(thrust::fill_n, thrust::fill_n, true, NO_FLAG, P4, "Successful")
6969
ENTRY(thrust::find, thrust::find, true, NO_FLAG, P4, "Successful")
70-
ENTRY(thrust::find_if, thrust::find_if, false, NO_FLAG, P4, "comment")
71-
ENTRY(thrust::find_if_not, thrust::find_if_not, false, NO_FLAG, P4, "comment")
70+
ENTRY(thrust::find_if, thrust::find_if, true, NO_FLAG, P4, "comment")
71+
ENTRY(thrust::find_if_not, thrust::find_if_not, true, NO_FLAG, P4, "comment")
7272
ENTRY(thrust::for_each, thrust::for_each, true, NO_FLAG, P4, "Successful")
7373
ENTRY(thrust::for_each_n, thrust::for_each_n, true, NO_FLAG, P4, "comment")
7474
ENTRY(thrust::free, thrust::free, false, NO_FLAG, P4, "comment")

clang/test/dpct/thrust-algo-raw-ptr-noneusm.cu

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <thrust/tabulate.h>
1717
#include <thrust/functional.h>
1818
#include <thrust/remove.h>
19+
#include <thrust/find.h>
1920

2021
// for cuda 12.0
2122
#include <thrust/iterator/constant_iterator.h>
@@ -562,3 +563,44 @@ void remvoe_test() {
562563
thrust::remove(thrust::host, data, data + N, 1);
563564
thrust::remove(data, data + N, 1);
564565
}
566+
567+
struct greater_than_four {
568+
__host__ __device__ bool operator()(int x) const { return x > 4; }
569+
};
570+
571+
void find_if_test() {
572+
const int N = 4;
573+
int data[4] = {0,5, 3, 7};
574+
575+
//CHECK: if (dpct::is_device_ptr(data + 3)) {
576+
//CHECK-NEXT: oneapi::dpl::find_if(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(data), dpct::device_pointer<int>(data + 3), greater_than_four());
577+
//CHECK-NEXT: } else {
578+
//CHECK-NEXT: oneapi::dpl::find_if(oneapi::dpl::execution::seq, data, data + 3, greater_than_four());
579+
//CHECK-NEXT: };
580+
//CHECK-NEXT: if (dpct::is_device_ptr(data)) {
581+
//CHECK-NEXT: oneapi::dpl::find_if(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(data), dpct::device_pointer<int>(data + 3), greater_than_four());
582+
//CHECK-NEXT: } else {
583+
//CHECK-NEXT: oneapi::dpl::find_if(oneapi::dpl::execution::seq, data, data + 3, greater_than_four());
584+
//CHECK-NEXT: };
585+
thrust::find_if(data, data+3, greater_than_four());
586+
thrust::find_if(thrust::host, data, data+3, greater_than_four());
587+
}
588+
589+
void find_if_not_test() {
590+
const int N = 4;
591+
int data[4] = {0,5, 3, 7};
592+
593+
//CHECK: if (dpct::is_device_ptr(data + 3)) {
594+
//CHECK-NEXT: oneapi::dpl::find_if_not(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(data), dpct::device_pointer<int>(data + 3), greater_than_four());
595+
//CHECK-NEXT: } else {
596+
//CHECK-NEXT: oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, data, data + 3, greater_than_four());
597+
//CHECK-NEXT: };
598+
//CHECK-NEXT: if (dpct::is_device_ptr(data)) {
599+
//CHECK-NEXT: oneapi::dpl::find_if_not(oneapi::dpl::execution::make_device_policy(q_ct1), dpct::device_pointer<int>(data), dpct::device_pointer<int>(data + 3), greater_than_four());
600+
//CHECK-NEXT: } else {
601+
//CHECK-NEXT: oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, data, data + 3, greater_than_four());
602+
//CHECK-NEXT: };
603+
thrust::find_if_not(data, data+3, greater_than_four());
604+
thrust::find_if_not(thrust::host, data, data+3, greater_than_four());
605+
}
606+

clang/test/dpct/thrust-algo.cu

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,3 +931,47 @@ void remvoe_test() {
931931
thrust::remove(host_data.begin(), host_data.begin() + N, 1);
932932
thrust::remove(device_data.begin(), device_data.begin() + N, 1);
933933
}
934+
935+
struct greater_than_four {
936+
__host__ __device__ bool operator()(int x) const { return x > 4; }
937+
};
938+
939+
void find_if_test() {
940+
const int N = 4;
941+
int data[4] = {0,5, 3, 7};
942+
thrust::device_vector<int> device_data(data, data + N);
943+
thrust::host_vector<int> host_data(data, data + N);
944+
945+
// CHECK:oneapi::dpl::find_if(oneapi::dpl::execution::seq, data, data+3, greater_than_four());
946+
// CHECK-NEXT:oneapi::dpl::find_if(oneapi::dpl::execution::make_device_policy(q_ct1), device_data.begin(), device_data.end(), greater_than_four());
947+
// CHECK-NEXT:oneapi::dpl::find_if(oneapi::dpl::execution::seq, host_data.begin(), host_data.end(), greater_than_four());
948+
// CHECK-NEXT:oneapi::dpl::find_if(oneapi::dpl::execution::seq, data, data+3, greater_than_four());
949+
// CHECK-NEXT:oneapi::dpl::find_if(oneapi::dpl::execution::make_device_policy(q_ct1), device_data.begin(), device_data.end(), greater_than_four());
950+
// CHECK-NEXT:oneapi::dpl::find_if(oneapi::dpl::execution::seq, host_data.begin(), host_data.end(), greater_than_four());
951+
thrust::find_if(data, data+3, greater_than_four());
952+
thrust::find_if(device_data.begin(), device_data.end(), greater_than_four());
953+
thrust::find_if(host_data.begin(), host_data.end(), greater_than_four());
954+
thrust::find_if(thrust::host, data, data+3, greater_than_four());
955+
thrust::find_if(thrust::device, device_data.begin(), device_data.end(), greater_than_four());
956+
thrust::find_if(thrust::host, host_data.begin(), host_data.end(), greater_than_four());
957+
}
958+
959+
void find_if_not_test() {
960+
const int N = 4;
961+
int data[4] = {0,5, 3, 7};
962+
thrust::device_vector<int> device_data(data, data + N);
963+
thrust::host_vector<int> host_data(data, data + N);
964+
965+
// CHECK:oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, data, data+3, greater_than_four());
966+
// CHECK-NEXT:oneapi::dpl::find_if_not(oneapi::dpl::execution::make_device_policy(q_ct1), device_data.begin(), device_data.end(), greater_than_four());
967+
// CHECK-NEXT:oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, host_data.begin(), host_data.end(), greater_than_four());
968+
// CHECK-NEXT:oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, data, data+3, greater_than_four());
969+
// CHECK-NEXT:oneapi::dpl::find_if_not(oneapi::dpl::execution::make_device_policy(q_ct1), device_data.begin(), device_data.end(), greater_than_four());
970+
// CHECK-NEXT:oneapi::dpl::find_if_not(oneapi::dpl::execution::seq, host_data.begin(), host_data.end(), greater_than_four());
971+
thrust::find_if_not(data, data+3, greater_than_four());
972+
thrust::find_if_not(device_data.begin(), device_data.end(), greater_than_four());
973+
thrust::find_if_not(host_data.begin(), host_data.end(), greater_than_four());
974+
thrust::find_if_not(thrust::host, data, data+3, greater_than_four());
975+
thrust::find_if_not(thrust::device, device_data.begin(), device_data.end(), greater_than_four());
976+
thrust::find_if_not(thrust::host, host_data.begin(), host_data.end(), greater_than_four());
977+
}

0 commit comments

Comments
 (0)