Skip to content

Commit e976315

Browse files
authored
[SYCLomatic] Support migration of thrust random types (#1002)
Signed-off-by: chenwei.sun <[email protected]>
1 parent 70aea77 commit e976315

File tree

4 files changed

+107
-5
lines changed

4 files changed

+107
-5
lines changed

clang/lib/DPCT/APINamesTemplateType.inc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,16 @@ TYPE_REWRITE_ENTRY("cooperative_groups::__v1::bit_or",
200200
TYPE_FACTORY(STR(MapNames::getClNamespace() + "bit_or"),
201201
TEMPLATE_ARG(0)))
202202

203+
TYPE_REWRITE_ENTRY("thrust::random::default_random_engine",
204+
TYPE_FACTORY(STR("oneapi::dpl::default_engine")))
205+
206+
TYPE_REWRITE_ENTRY("thrust::random::uniform_real_distribution",
207+
TYPE_FACTORY(STR("oneapi::dpl::uniform_real_distribution"), TEMPLATE_ARG(0)))
208+
209+
TYPE_REWRITE_ENTRY("thrust::random::normal_distribution",
210+
TYPE_FACTORY(STR("oneapi::dpl::normal_distribution"), TEMPLATE_ARG(0)))
211+
212+
TYPE_REWRITE_ENTRY("thrust::random::uniform_int_distribution",
213+
TYPE_FACTORY(STR("oneapi::dpl::uniform_int_distribution"), TEMPLATE_ARG(0)))
214+
203215
// clang-format on

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,8 @@ void IncludesCallbacks::InclusionDirective(
11231123
} else if (FileName.compare(StringRef("thrust/uninitialized_copy.h")) ==
11241124
0) {
11251125
DpctGlobalInfo::getInstance().insertHeader(HashLoc, HT_DPL_Memory);
1126+
} else if (FileName.compare(StringRef("thrust/random.h")) == 0) {
1127+
DpctGlobalInfo::getInstance().insertHeader(HashLoc, HT_DPL_Random);
11261128
} else {
11271129
if(FileName.compare(StringRef("thrust/functional.h")) == 0)
11281130
DpctGlobalInfo::getInstance().insertHeader(HashLoc, HT_Functional);
@@ -14644,17 +14646,37 @@ void TemplateSpecializationTypeLocRule::registerMatcher(
1464414646
ast_matchers::MatchFinder &MF) {
1464514647
auto TargetTypeName = [&]() {
1464614648
return hasAnyName("thrust::not_equal_to", "thrust::constant_iterator",
14647-
"thrust::system::cuda::experimental::pinned_allocator");
14649+
"thrust::system::cuda::experimental::pinned_allocator",
14650+
"thrust::random::default_random_engine",
14651+
"thrust::random::uniform_real_distribution",
14652+
"thrust::random::normal_distribution",
14653+
"thrust::random::linear_congruential_engine",
14654+
"thrust::random::uniform_int_distribution");
1464814655
};
1464914656

14650-
MF.addMatcher(typeLoc(
14651-
loc(qualType(hasDeclaration(namedDecl(TargetTypeName())))))
14652-
.bind("loc"),
14653-
this);
14657+
MF.addMatcher(
14658+
typeLoc(loc(qualType(hasDeclaration(namedDecl(TargetTypeName())))))
14659+
.bind("loc"),
14660+
this);
14661+
14662+
MF.addMatcher(declRefExpr().bind("declRefExpr"), this);
1465414663
}
1465514664

1465614665
void TemplateSpecializationTypeLocRule::runRule(
1465714666
const ast_matchers::MatchFinder::MatchResult &Result) {
14667+
14668+
const DeclRefExpr *DRE = getNodeAsType<DeclRefExpr>(Result, "declRefExpr");
14669+
if (DRE) {
14670+
std::string TypeName = DpctGlobalInfo::getTypeName(DRE->getType());
14671+
std::string Name = DRE->getNameInfo().getName().getAsString();
14672+
if (TypeName.find("thrust::random::linear_congruential_engine") !=
14673+
std::string::npos &&
14674+
Name == "max") {
14675+
emplaceTransformation(
14676+
new ReplaceStmt(DRE, "oneapi::dpl::default_engine::max()"));
14677+
}
14678+
}
14679+
1465814680
if (auto TL = getNodeAsType<TypeLoc>(Result, "loc")) {
1465914681
ExprAnalysis EA;
1466014682
EA.analyze(*TL);

clang/lib/DPCT/HeaderTypes.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ ONEDPL_HEADER(Execution, "<oneapi/dpl/execution>")
5252
ONEDPL_HEADER(Iterator, "<oneapi/dpl/iterator>")
5353
ONEDPL_HEADER(Async, "<oneapi/dpl/async>")
5454
ONEDPL_HEADER(Memory, "<oneapi/dpl/memory>")
55+
ONEDPL_HEADER(Random, "<oneapi/dpl/random>")
5556

5657
MKL_HEADER(Mkl, "<oneapi/mkl.hpp>")
5758
MKL_HEADER(RNG, "<oneapi/mkl/rng/device.hpp>")
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// UNSUPPORTED: cuda-8.0
2+
// UNSUPPORTED: v8.0
3+
// RUN: dpct --format-range=none -out-root %T/thrust-random-type %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck --input-file %T/thrust-random-type/thrust-random-type.dp.cpp --match-full-lines %s
5+
6+
7+
#include <iostream>
8+
#include <thrust/device_vector.h>
9+
#include <thrust/iterator/counting_iterator.h>
10+
#include <thrust/random.h>
11+
#include <thrust/random/linear_congruential_engine.h>
12+
#include <thrust/random/uniform_real_distribution.h>
13+
#include <thrust/transform.h>
14+
15+
// CHECK:struct random_1 {
16+
// CHECK-NEXT: float operator()(const unsigned int n) const {
17+
// CHECK-NEXT: oneapi::dpl::default_engine rng;
18+
// CHECK-NEXT: oneapi::dpl::uniform_real_distribution<float> dist(1.0f, 2.0f);
19+
// CHECK-NEXT: rng.discard(n);
20+
// CHECK-NEXT: return dist(rng);
21+
// CHECK-NEXT: }
22+
// CHECK-NEXT:};
23+
struct random_1 {
24+
__host__ __device__ float operator()(const unsigned int n) const {
25+
thrust::default_random_engine rng;
26+
thrust::uniform_real_distribution<float> dist(1.0f, 2.0f);
27+
rng.discard(n);
28+
return dist(rng);
29+
}
30+
};
31+
32+
33+
// CHECK:struct random_2 {
34+
// CHECK-NEXT: float operator()(const unsigned int n) const {
35+
// CHECK-NEXT: oneapi::dpl::default_engine rng;
36+
// CHECK-NEXT: rng.discard(n);
37+
// CHECK-NEXT: return (float)rng() / oneapi::dpl::default_engine::max();
38+
// CHECK-NEXT: }
39+
// CHECK-NEXT:};
40+
struct random_2 {
41+
__device__ float operator()(const unsigned int n) {
42+
thrust::default_random_engine rng;
43+
rng.discard(n);
44+
return (float)rng() / thrust::default_random_engine::max;
45+
}
46+
};
47+
48+
void test(void) {
49+
{
50+
const int N = 20;
51+
// CHECK: dpct::device_vector<float> numbers(N);
52+
// CHECK-NEXT: oneapi::dpl::counting_iterator<unsigned int> index_sequence_begin(0);
53+
// CHECK-NEXT: std::transform(oneapi::dpl::execution::seq, index_sequence_begin, index_sequence_begin + N, numbers.begin(), random_1());
54+
// CHECK-NEXT: std::transform(oneapi::dpl::execution::seq, index_sequence_begin, index_sequence_begin + N, numbers.begin(), random_2());
55+
thrust::device_vector<float> numbers(N);
56+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
57+
thrust::transform(index_sequence_begin, index_sequence_begin + N, numbers.begin(), random_1());
58+
thrust::transform(index_sequence_begin, index_sequence_begin + N, numbers.begin(), random_2());
59+
}
60+
61+
{
62+
// CHECK: oneapi::dpl::uniform_int_distribution<int> dist1(-5, 10);
63+
// CHECK-NEXT: oneapi::dpl::normal_distribution<float> dist2(1.0f, 2.0f);
64+
thrust::uniform_int_distribution<int> dist1(-5, 10);
65+
thrust::normal_distribution<float> dist2(1.0f, 2.0f);
66+
}
67+
}

0 commit comments

Comments
 (0)