Skip to content

Commit 36c4127

Browse files
authored
[SYCLomatic] Support migration of thrust::null_type (#837)
Signed-off-by: chenwei.sun <[email protected]>
1 parent 2c93ba1 commit 36c4127

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

clang/lib/DPCT/APINamesTemplateType.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,5 @@ TYPE_REWRITE_ENTRY(
171171

172172
TYPE_REWRITE_ENTRY("thrust::identity",
173173
TYPE_FACTORY(STR("oneapi::dpl::identity")))
174+
175+
TYPE_REWRITE_ENTRY("thrust::null_type", TYPE_FACTORY(STR("dpct::null_type")))

clang/lib/DPCT/ThrustAPIMigration.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ void ThrustTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
297297
return hasAnyName("thrust::greater_equal", "thrust::less_equal",
298298
"thrust::logical_and", "thrust::bit_and",
299299
"thrust::bit_or", "thrust::minimum", "thrust::bit_xor",
300-
"thrust::modulus", "thrust::greater", "thrust::identity");
300+
"thrust::modulus", "thrust::greater", "thrust::identity",
301+
"thrust::null_type");
301302
};
302303
MF.addMatcher(typeLoc(loc(hasCanonicalType(qualType(
303304
hasDeclaration(namedDecl(ThrustTypeHasNames()))))))

clang/test/dpct/thrust-placeholders.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <thrust/memory.h>
1212
#include <thrust/transform.h>
1313
#include <thrust/functional.h>
14+
#include <thrust/tuple.h>
1415
#include <cuda_runtime.h>
1516

1617
// CHECK: // COMMENT_A
@@ -34,6 +35,11 @@ __global__ void kernel(float *src, float *dst) {
3435
auto tmp2 = _1 + logf(2.0) + thrust::placeholders::_2;
3536
}
3637

38+
// CHECK: template <typename T1, typename T2 = dpct::null_type>
39+
// CHECK-NEXT:void print_tuple(const std::tuple<T1, T2> &t) {}
40+
template <typename T1, typename T2 = thrust::null_type>
41+
void print_tuple(const thrust::tuple<T1, T2> &t) {}
42+
3743
int main() {
3844
float x[] = {1.0f, 2.0f, 3.0f, 4.0f};
3945
float y[] = {2.0f, 1.0f, 1.0f, 1.0f};

0 commit comments

Comments
 (0)