Skip to content

Commit 17ac089

Browse files
author
Yihan Wang
authored
[SYCLomatic] Support migration of cub::DeviceSegmentedRadixSort::SortKeys/SortKeysDescending (#868)
Signed-off-by: Wang, Yihan <[email protected]>
1 parent 737a0d5 commit 17ac089

File tree

4 files changed

+439
-18
lines changed

4 files changed

+439
-18
lines changed

clang/lib/DPCT/APINamesCUB.inc

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,196 @@ CASE_FACTORY_ENTRY(
12461246
QUEUESTR),
12471247
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), LITERAL("true"))))))))))))
12481248

1249+
// cub::DeviceSegmentedRadixSort::SortKeys
1250+
CASE_FACTORY_ENTRY(
1251+
CASE(CheckCubRedundantFunctionCall(),
1252+
REMOVE_API_FACTORY_ENTRY("cub::DeviceSegmentedRadixSort::SortKeys")),
1253+
OTHERWISE(FEATURE_REQUEST_FACTORY(
1254+
HelperFeatureEnum::DplExtrasAlgorithm_sort_keys,
1255+
HEADER_INSERT_FACTORY(
1256+
HeaderType::HT_DPCT_DPL_Utils,
1257+
REMOVE_CUB_TEMP_STORAGE_FACTORY(CASE_FACTORY_ENTRY(
1258+
CASE(CheckArgType(2, "cub::DoubleBuffer"),
1259+
CASE_FACTORY_ENTRY(
1260+
CASE(makeCheckAnd(
1261+
CheckArgCount(10, std::greater_equal<>(),
1262+
/* IncludeDefaultArg */ false),
1263+
makeCheckNot(CheckArgIsDefaultCudaStream(9))),
1264+
CALL_FACTORY_ENTRY(
1265+
"cub::DeviceSegmentedRadixSort::SortKeys",
1266+
CALL("dpct::segmented_sort_keys",
1267+
CALL("oneapi::dpl::execution::device_"
1268+
"policy",
1269+
STREAM(9)),
1270+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1271+
LITERAL("false"), LITERAL("true"),
1272+
ARG(7), ARG(8)))),
1273+
CASE(CheckArgCount(9, std::greater_equal<>(),
1274+
/* IncludeDefaultArg */ false),
1275+
CALL_FACTORY_ENTRY(
1276+
"cub::DeviceSegmentedRadixSort::SortKeys",
1277+
CALL("dpct::segmented_sort_keys",
1278+
CALL("oneapi::dpl::execution::device_"
1279+
"policy",
1280+
QUEUESTR),
1281+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1282+
LITERAL("false"), LITERAL("true"),
1283+
ARG(7), ARG(8)))),
1284+
CASE(CheckArgCount(8, std::greater_equal<>(),
1285+
/* IncludeDefaultArg */ false),
1286+
CALL_FACTORY_ENTRY(
1287+
"cub::DeviceSegmentedRadixSort::SortKeys",
1288+
CALL("dpct::segmented_sort_keys",
1289+
CALL("oneapi::dpl::execution::device_"
1290+
"policy",
1291+
QUEUESTR),
1292+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1293+
LITERAL("false"), LITERAL("true"),
1294+
ARG(7)))),
1295+
OTHERWISE(CALL_FACTORY_ENTRY(
1296+
"cub::DeviceSegmentedRadixSort::SortKeys",
1297+
CALL("dpct::segmented_sort_keys",
1298+
CALL("oneapi::dpl::execution::device_policy",
1299+
QUEUESTR),
1300+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1301+
LITERAL("false"), LITERAL("true")))))),
1302+
OTHERWISE(CASE_FACTORY_ENTRY(
1303+
CASE(makeCheckAnd(
1304+
CheckArgCount(11, std::greater_equal<>(),
1305+
/* IncludeDefaultArg */ false),
1306+
makeCheckNot(CheckArgIsDefaultCudaStream(10))),
1307+
CALL_FACTORY_ENTRY(
1308+
"cub::DeviceSegmentedRadixSort::SortKeys",
1309+
CALL("dpct::segmented_sort_keys",
1310+
CALL("oneapi::dpl::execution::device_policy",
1311+
STREAM(10)),
1312+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1313+
ARG(7), LITERAL("false"), ARG(8), ARG(9)))),
1314+
CASE(CheckArgCount(10, std::greater_equal<>(),
1315+
/* IncludeDefaultArg */ false),
1316+
CALL_FACTORY_ENTRY(
1317+
"cub::DeviceSegmentedRadixSort::SortKeys",
1318+
CALL("dpct::segmented_sort_keys",
1319+
CALL("oneapi::dpl::execution::device_policy",
1320+
QUEUESTR),
1321+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1322+
ARG(7), LITERAL("false"), ARG(8), ARG(9)))),
1323+
CASE(CheckArgCount(9, std::greater_equal<>(),
1324+
/* IncludeDefaultArg */ false),
1325+
CALL_FACTORY_ENTRY(
1326+
"cub::DeviceSegmentedRadixSort::SortKeys",
1327+
CALL("dpct::segmented_sort_keys",
1328+
CALL("oneapi::dpl::execution::device_policy",
1329+
QUEUESTR),
1330+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1331+
ARG(7), LITERAL("false"), ARG(8)))),
1332+
OTHERWISE(CALL_FACTORY_ENTRY(
1333+
"cub::DeviceSegmentedRadixSort::SortKeys",
1334+
CALL("dpct::segmented_sort_keys",
1335+
CALL("oneapi::dpl::execution::device_policy",
1336+
QUEUESTR),
1337+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1338+
LITERAL("false"))))))))))))
1339+
1340+
// cub::DeviceSegmentedRadixSort::SortKeysDescending
1341+
CASE_FACTORY_ENTRY(
1342+
CASE(CheckCubRedundantFunctionCall(),
1343+
REMOVE_API_FACTORY_ENTRY(
1344+
"cub::DeviceSegmentedRadixSort::SortKeysDescending")),
1345+
OTHERWISE(FEATURE_REQUEST_FACTORY(
1346+
HelperFeatureEnum::DplExtrasAlgorithm_sort_keys,
1347+
HEADER_INSERT_FACTORY(
1348+
HeaderType::HT_DPCT_DPL_Utils,
1349+
REMOVE_CUB_TEMP_STORAGE_FACTORY(CASE_FACTORY_ENTRY(
1350+
CASE(
1351+
CheckArgType(2, "cub::DoubleBuffer"),
1352+
CASE_FACTORY_ENTRY(
1353+
CASE(makeCheckAnd(
1354+
CheckArgCount(10, std::greater_equal<>(),
1355+
/* IncludeDefaultArg */ false),
1356+
makeCheckNot(CheckArgIsDefaultCudaStream(9))),
1357+
CALL_FACTORY_ENTRY(
1358+
"cub::DeviceSegmentedRadixSort::"
1359+
"SortKeysDescending",
1360+
CALL("dpct::segmented_sort_keys",
1361+
CALL("oneapi::dpl::execution::device_"
1362+
"policy",
1363+
STREAM(9)),
1364+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1365+
LITERAL("true"), LITERAL("true"), ARG(7),
1366+
ARG(8)))),
1367+
CASE(CheckArgCount(9, std::greater_equal<>(),
1368+
/* IncludeDefaultArg */ false),
1369+
CALL_FACTORY_ENTRY(
1370+
"cub::DeviceSegmentedRadixSort::"
1371+
"SortKeysDescending",
1372+
CALL("dpct::segmented_sort_keys",
1373+
CALL("oneapi::dpl::execution::device_"
1374+
"policy",
1375+
QUEUESTR),
1376+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1377+
LITERAL("true"), LITERAL("true"), ARG(7),
1378+
ARG(8)))),
1379+
CASE(CheckArgCount(8, std::greater_equal<>(),
1380+
/* IncludeDefaultArg */ false),
1381+
CALL_FACTORY_ENTRY(
1382+
"cub::DeviceSegmentedRadixSort::"
1383+
"SortKeysDescending",
1384+
CALL("dpct::segmented_sort_keys",
1385+
CALL("oneapi::dpl::execution::device_"
1386+
"policy",
1387+
QUEUESTR),
1388+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1389+
LITERAL("true"), LITERAL("true"),
1390+
ARG(7)))),
1391+
OTHERWISE(CALL_FACTORY_ENTRY(
1392+
"cub::DeviceSegmentedRadixSort::SortKeysDescending",
1393+
CALL("dpct::segmented_sort_keys",
1394+
CALL("oneapi::dpl::execution::device_policy",
1395+
QUEUESTR),
1396+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
1397+
LITERAL("true"), LITERAL("true")))))),
1398+
OTHERWISE(CASE_FACTORY_ENTRY(
1399+
CASE(
1400+
makeCheckAnd(
1401+
CheckArgCount(11, std::greater_equal<>(),
1402+
/* IncludeDefaultArg */ false),
1403+
makeCheckNot(CheckArgIsDefaultCudaStream(10))),
1404+
CALL_FACTORY_ENTRY(
1405+
"cub::DeviceSegmentedRadixSort::SortKeysDescending",
1406+
CALL("dpct::segmented_sort_keys",
1407+
CALL("oneapi::dpl::execution::device_policy",
1408+
STREAM(10)),
1409+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1410+
LITERAL("true"), ARG(8), ARG(9)))),
1411+
CASE(
1412+
CheckArgCount(10, std::greater_equal<>(),
1413+
/* IncludeDefaultArg */ false),
1414+
CALL_FACTORY_ENTRY(
1415+
"cub::DeviceSegmentedRadixSort::SortKeysDescending",
1416+
CALL("dpct::segmented_sort_keys",
1417+
CALL("oneapi::dpl::execution::device_policy",
1418+
QUEUESTR),
1419+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1420+
LITERAL("true"), ARG(8), ARG(9)))),
1421+
CASE(
1422+
CheckArgCount(9, std::greater_equal<>(),
1423+
/* IncludeDefaultArg */ false),
1424+
CALL_FACTORY_ENTRY(
1425+
"cub::DeviceSegmentedRadixSort::SortKeysDescending",
1426+
CALL("dpct::segmented_sort_keys",
1427+
CALL("oneapi::dpl::execution::device_policy",
1428+
QUEUESTR),
1429+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1430+
LITERAL("true"), ARG(8)))),
1431+
OTHERWISE(CALL_FACTORY_ENTRY(
1432+
"cub::DeviceSegmentedRadixSort::SortKeysDescending",
1433+
CALL("dpct::segmented_sort_keys",
1434+
CALL("oneapi::dpl::execution::device_policy",
1435+
QUEUESTR),
1436+
ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7),
1437+
LITERAL("true"))))))))))))
1438+
12491439
// cub::DeviceSegmentedRadixSort::SortPairs
12501440
CASE_FACTORY_ENTRY(
12511441
CASE(CheckCubRedundantFunctionCall(),

clang/lib/DPCT/APINames_CUB.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ ENTRY_MEMBER_FUNCTION(cub::DeviceSelect, UniqueByKey, UniqueByKey, true, NO_FLAG
168168
ENTRY_MEMBER_FUNCTION(cub::DeviceSpmv, CsrMV, CsrMV, false, NO_FLAG, P4, "Comment")
169169
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortPairs, SortPairs, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027")
170170
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortPairsDescending, SortPairsDescending, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027")
171-
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortKeys, SortKeys, false, NO_FLAG, P4, "Comment")
172-
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortKeysDescending, SortKeysDescending, false, NO_FLAG, P4, "Comment")
171+
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortKeys, SortKeys, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027")
172+
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedRadixSort, SortKeysDescending, SortKeysDescending, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027")
173173
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedReduce, Reduce, Reduce, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027/DPCT1091/DPCT1092")
174174
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedReduce, Sum, Sum, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027/DPCT1092")
175175
ENTRY_MEMBER_FUNCTION(cub::DeviceSegmentedReduce, Min, Min, true, NO_FLAG, P4, "Successful: DPCT1026/DPCT1027/DPCT1092")

clang/lib/DPCT/CUBAPIMigration.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "AnalysisInfo.h"
1212
#include "CallExprRewriter.h"
1313
#include "MigrationRuleManager.h"
14-
1514
#include "clang/AST/Decl.h"
1615
#include "clang/AST/DeclCXX.h"
1716
#include "clang/AST/Expr.h"
@@ -77,13 +76,11 @@ REGISTER_RULE(CubIntrinsicRule, PassKind::PK_Analysis)
7776
void CubTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
7877
auto TargetTypeName = [&]() {
7978
return hasAnyName("cub::Sum", "cub::Max", "cub::Min", "cub::Equality",
80-
"cub::KeyValuePair",
81-
"cub::CountingInputIterator",
79+
"cub::KeyValuePair", "cub::CountingInputIterator",
8280
"cub::TransformInputIterator",
8381
"cub::ConstantInputIterator",
8482
"cub::ArgIndexInputIterator",
85-
"cub::DiscardOutputIterator",
86-
"cub::DoubleBuffer");
83+
"cub::DiscardOutputIterator", "cub::DoubleBuffer");
8784
};
8885

8986
MF.addMatcher(
@@ -109,7 +106,7 @@ bool CubTypeRule::CanMappingToSyclNativeBinaryOp(StringRef OpTypeName) {
109106

110107
bool CubTypeRule::CanMappingToSyclType(StringRef OpTypeName) {
111108
return CanMappingToSyclNativeBinaryOp(OpTypeName) ||
112-
OpTypeName == "cub::Equality" ||
109+
OpTypeName == "cub::Equality" ||
113110

114111
// Ignore template arguments, e.g. cub::KeyValuePair<int, int>
115112
OpTypeName.startswith("cub::KeyValuePair");
@@ -133,16 +130,18 @@ void CubDeviceLevelRule::runRule(
133130
void CubMemberCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
134131
MF.addMatcher(
135132
cxxMemberCallExpr(
136-
allOf(on(hasType(hasCanonicalType(qualType(hasDeclaration(
137-
namedDecl(hasName("cub::ArgIndexInputIterator"))))))),
138-
callee(cxxMethodDecl(hasName("normalize")))))
133+
allOf(on(hasType(hasCanonicalType(qualType(hasDeclaration(
134+
namedDecl(hasName("cub::ArgIndexInputIterator"))))))),
135+
callee(cxxMethodDecl(hasName("normalize")))))
139136
.bind("memberCall"),
140137
this);
141138

142139
MF.addMatcher(
143-
memberExpr(hasObjectExpression(hasType(hasCanonicalType(qualType(hasDeclaration(namedDecl(hasName("cub::DoubleBuffer"))))))),
144-
member(hasAnyName("Current", "Alternate", "d_buffers")))
145-
.bind("memberExpr"),
140+
memberExpr(
141+
hasObjectExpression(hasType(hasCanonicalType(qualType(
142+
hasDeclaration(namedDecl(hasName("cub::DoubleBuffer"))))))),
143+
member(hasAnyName("Current", "Alternate", "d_buffers")))
144+
.bind("memberExpr"),
146145
this);
147146
}
148147

@@ -166,7 +165,8 @@ void CubIntrinsicRule::registerMatcher(ast_matchers::MatchFinder &MF) {
166165
this);
167166
}
168167

169-
void CubIntrinsicRule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) {
168+
void CubIntrinsicRule::runRule(
169+
const ast_matchers::MatchFinder::MatchResult &Result) {
170170
if (const auto *CE = getNodeAsType<CallExpr>(Result, "IntrinsicCall")) {
171171
ExprAnalysis EA;
172172
EA.analyze(CE);
@@ -353,8 +353,8 @@ void removeVarDecl(const VarDecl *VD) {
353353
const auto &Mgr = DpctGlobalInfo::getSourceManager();
354354
const auto Range = DS->getSourceRange();
355355
const CharSourceRange CRange(Range, true);
356-
auto Replacement = std::make_shared<ExtReplacement>(
357-
Mgr, CRange, "", nullptr);
356+
auto Replacement =
357+
std::make_shared<ExtReplacement>(Mgr, CRange, "", nullptr);
358358
DpctGlobalInfo::getInstance().addReplacement(Replacement);
359359
return;
360360
}
@@ -1273,7 +1273,6 @@ void CubRule::processWarpLevelMemberCall(const CXXMemberCallExpr *WarpMC) {
12731273
emplaceTransformation(new ReplaceStmt(WarpMC, Repl));
12741274
}
12751275

1276-
12771276
if (auto FuncInfo = DeviceFunctionDecl::LinkRedecls(FD)) {
12781277
FuncInfo->addSubGroupSizeRequest(WarpSize, WarpMC->getBeginLoc(),
12791278
NewFuncName);

0 commit comments

Comments
 (0)