Skip to content

Commit 1a32fd2

Browse files
authored
[SYCLomatic] Enable cooperative_groups::reduce function (#863)
Signed-off-by: Chen, Sheng S <[email protected]>
1 parent b32836f commit 1a32fd2

File tree

7 files changed

+92
-3
lines changed

7 files changed

+92
-3
lines changed

clang/lib/DPCT/APINamesCooperativeGroups.inc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,21 @@ CONDITIONAL_FACTORY_ENTRY(
6969
UNSUPPORT_FACTORY_ENTRY("sync",
7070
Diagnostics::API_NOT_MIGRATED,
7171
ARG("sync"))))))
72+
73+
/*
74+
--- cg::reduce(group, val, op) ---
75+
76+
cg::reduce(cg::thread_block_tile<32>, val, op)
77+
=> sycl::reduce_over_group(item.get_sub_group(), val, op)
78+
*/
79+
CONDITIONAL_FACTORY_ENTRY(
80+
argHasThreadBlockTileType(0, 32),
81+
CALL_FACTORY_ENTRY("reduce", CALL(MapNames::getClNamespace() + "reduce_over_group",
82+
SUBGROUP, ARG(1), ARG(2))),
83+
UNSUPPORT_FACTORY_ENTRY("reduce",
84+
Diagnostics::API_NOT_MIGRATED,
85+
ARG("reduce")))
86+
7287
/*
7388
--- cg::thread_rank(X) ---
7489
cg::thread_rank(cg::thread_block)

clang/lib/DPCT/APINamesTemplateType.inc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,27 @@ TYPE_REWRITE_ENTRY("thrust::identity",
173173
TYPE_FACTORY(STR("oneapi::dpl::identity")))
174174

175175
TYPE_REWRITE_ENTRY("thrust::null_type", TYPE_FACTORY(STR("dpct::null_type")))
176+
177+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::plus",
178+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "plus"),
179+
TEMPLATE_ARG(0)))
180+
181+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::less",
182+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "minimum"),
183+
TEMPLATE_ARG(0)))
184+
185+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::greater",
186+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "maximum"),
187+
TEMPLATE_ARG(0)))
188+
189+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::bit_and",
190+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "bit_and"),
191+
TEMPLATE_ARG(0)))
192+
193+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::bit_xor",
194+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "bit_xor"),
195+
TEMPLATE_ARG(0)))
196+
197+
TYPE_REWRITE_ENTRY("cooperative_groups::__v1::bit_or",
198+
TYPE_FACTORY(STR(MapNames::getClNamespace() + "bit_or"),
199+
TEMPLATE_ARG(0)))

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11933,6 +11933,11 @@ void CooperativeGroupsFunctionRule::runRule(
1193311933
}
1193411934
}
1193511935
}
11936+
} else if (FuncName == "reduce") {
11937+
RUW.NeedReport = false;
11938+
ExprAnalysis EA(CE);
11939+
emplaceTransformation(EA.getReplacement());
11940+
EA.applyAllSubExprRepl();
1193611941
}
1193711942
}
1193811943

clang/lib/DPCT/ExprAnalysis.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,8 @@ void ExprAnalysis::analyzeExpr(const CXXTemporaryObjectExpr *Temp) {
573573
Temp->getType().getCanonicalType());
574574
if ((StringRef(TypeName).startswith("cub::") &&
575575
CubTypeRule::CanMappingToSyclType(TypeName)) ||
576-
StringRef(TypeName).startswith("thrust::")) {
576+
StringRef(TypeName).startswith("thrust::") ||
577+
StringRef(TypeName).startswith("cooperative_groups::")) {
577578
analyzeType(Temp->getTypeSourceInfo()->getTypeLoc());
578579
}
579580
analyzeExpr(static_cast<const CXXConstructExpr *>(Temp));
@@ -1034,8 +1035,11 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE) {
10341035
}
10351036
case TypeLoc::TemplateSpecialization: {
10361037
llvm::raw_string_ostream OS(TyName);
1038+
TyName.clear();
10371039
auto &TSTL = TYPELOC_CAST(TemplateSpecializationTypeLoc);
1038-
TSTL.getTypePtr()->getTemplateName().print(OS, Context.getPrintingPolicy());
1040+
auto PP = Context.getPrintingPolicy();
1041+
PP.PrintCanonicalTypes = 1;
1042+
TSTL.getTypePtr()->getTemplateName().print(OS, PP, TemplateName::Qualified::Fully);
10391043
if (!TypeLocRewriterFactoryBase::TypeLocRewriterMap)
10401044
return;
10411045
auto Itr = TypeLocRewriterFactoryBase::TypeLocRewriterMap->find(OS.str());

clang/lib/DPCT/MapNames.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4428,4 +4428,4 @@ const std::vector<std::string> MemoryDataTypeRule::RemoveMember{
44284428

44294429
const std::unordered_set<std::string> MapNames::CooperativeGroupsAPISet{
44304430
"this_thread_block", "sync", "tiled_partition",
4431-
"thread_rank", "size", "shfl_down"};
4431+
"thread_rank", "size", "shfl_down", "reduce"};

clang/lib/Headers/__clang_cuda_runtime_wrapper.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,4 +559,15 @@ extern "C" unsigned __cudaPushCallConfiguration(dim3 gridDim, dim3 blockDim,
559559
#pragma pop_macro("__clang_major__")
560560
#endif // SYCLomatic_CUSTOMIZATION
561561

562+
#if defined(SYCLomatic_CUSTOMIZATION)
563+
// Fix the parsing error that occurs when the source includes <cooperative_groups/reduce.h>.
564+
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
565+
# if CUDA_VERSION >= 11800
566+
#include <crt/sm_90_rt.h>
567+
#endif
568+
#if CUDA_VERSION >= 11000
569+
#include <crt/sm_80_rt.h>
570+
#endif
571+
#endif // SYCLomatic_CUSTOMIZATION
572+
562573
#endif // __CLANG_CUDA_RUNTIME_WRAPPER_H__
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2, cuda-11.0
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2, v11.0
3+
// RUN: dpct --format-range=none -out-root %T/cooperative_groups_reduce %s --cuda-include-path="%cuda-path/include" --extra-arg="-std=c++14"
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/cooperative_groups_reduce/cooperative_groups_reduce.dp.cpp
5+
6+
7+
#include <cooperative_groups.h>
8+
#include <cooperative_groups/reduce.h>
9+
10+
namespace cg = cooperative_groups;
11+
12+
__device__ void testReduce(double *sdata, const cg::thread_block &cta) {
13+
const unsigned int tid = cta.thread_rank();
14+
cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
15+
int *idata;
16+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), sdata[tid], sycl::plus<double>());
17+
cg::reduce(tile32, sdata[tid], cg::plus<double>());
18+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), sdata[tid], sycl::minimum<double>());
19+
cg::reduce(tile32, sdata[tid], cg::less<double>());
20+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), sdata[tid], sycl::maximum<double>());
21+
cg::reduce(tile32, sdata[tid], cg::greater<double>());
22+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), idata[tid], sycl::bit_and<int>());
23+
cg::reduce(tile32, idata[tid], cg::bit_and<int>());
24+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), idata[tid], sycl::bit_xor<int>());
25+
cg::reduce(tile32, idata[tid], cg::bit_xor<int>());
26+
// CHECK: sycl::reduce_over_group(item_ct1.get_sub_group(), idata[tid], sycl::bit_or<int>());
27+
cg::reduce(tile32, idata[tid], cg::bit_or<int>());
28+
cg::sync(cta);
29+
30+
}

0 commit comments

Comments
 (0)