Skip to content

Commit dc87bf2

Browse files
authored
[SYCLomatic] Support migration of curandRngType. (#750)
Signed-off-by: Tang, Jiajun [email protected]
1 parent 60c3a94 commit dc87bf2

File tree

5 files changed

+128
-85
lines changed

5 files changed

+128
-85
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 71 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ void IncludesCallbacks::InclusionDirective(
10511051
if (FileName.compare(StringRef("cuda/std/tuple")) == 0) {
10521052
DpctGlobalInfo::getInstance().insertHeader(HashLoc, HT_Tuple);
10531053
}
1054-
1054+
10551055
}
10561056

10571057
if (!isChildPath(CudaPath, IncludePath) &&
@@ -1974,54 +1974,48 @@ REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration)
19741974
void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
19751975
MF.addMatcher(
19761976
typeLoc(
1977-
loc(qualType(hasDeclaration(namedDecl(
1978-
hasAnyName(
1979-
"cudaError", "curandStatus", "cublasStatus", "CUstream",
1980-
"CUstream_st", "thrust::complex", "thrust::device_vector",
1981-
"thrust::device_ptr", "thrust::device_reference",
1982-
"thrust::host_vector", "cublasHandle_t",
1983-
"CUevent_st", "__half", "half", "__half2", "half2",
1984-
"cudaMemoryAdvise", "cudaError_enum", "cudaDeviceProp",
1985-
"cudaPitchedPtr", "thrust::counting_iterator",
1986-
"thrust::transform_iterator", "thrust::permutation_iterator",
1987-
"thrust::iterator_difference", "cusolverDnHandle_t",
1988-
"cusolverDnParams_t", "gesvdjInfo_t",
1989-
"thrust::device_malloc_allocator", "thrust::divides",
1990-
"thrust::tuple", "thrust::maximum", "thrust::multiplies",
1991-
"thrust::plus", "cudaDataType_t", "cudaError_t", "CUresult",
1992-
"CUdevice", "cudaEvent_t", "cublasStatus_t", "cuComplex",
1993-
"cuFloatComplex", "cuDoubleComplex", "CUevent",
1994-
"cublasFillMode_t", "cublasDiagType_t", "cublasSideMode_t",
1995-
"cublasOperation_t", "cusolverStatus_t", "cusolverEigType_t",
1996-
"cusolverEigMode_t", "curandStatus_t", "cudaStream_t",
1997-
"cusparseStatus_t", "cusparseDiagType_t",
1998-
"cusparseFillMode_t", "cusparseIndexBase_t",
1999-
"cusparseMatrixType_t", "cusparseOperation_t",
2000-
"cusparseMatDescr_t", "cusparseHandle_t", "CUcontext",
2001-
"cublasPointerMode_t", "cusparsePointerMode_t",
2002-
"cublasGemmAlgo_t", "cusparseSolveAnalysisInfo_t",
2003-
"cudaDataType", "cublasDataType_t", "curandState_t",
2004-
"curandState", "curandStateXORWOW_t", "curandStateXORWOW",
2005-
"curandStatePhilox4_32_10_t", "curandStatePhilox4_32_10",
2006-
"curandStateMRG32k3a_t", "curandStateMRG32k3a",
2007-
"thrust::minus", "thrust::negate", "thrust::logical_or",
2008-
"thrust::identity", "thrust::equal_to", "thrust::less",
2009-
"cudaSharedMemConfig", "curandGenerator_t", "cufftHandle",
2010-
"cufftReal", "cufftDoubleReal", "cufftComplex",
2011-
"cufftDoubleComplex", "cufftResult_t", "cufftResult",
2012-
"cufftType_t", "cufftType", "thrust::pair", "CUdeviceptr",
2013-
"cudaDeviceAttr", "CUmodule", "CUjit_option",
2014-
"CUfunction", "cudaMemcpyKind",
2015-
"cudaComputeMode", "__nv_bfloat16",
2016-
"cooperative_groups::__v1::thread_block_tile",
2017-
"cooperative_groups::__v1::thread_block",
2018-
"libraryPropertyType_t", "libraryPropertyType",
2019-
"cudaDataType_t", "cudaDataType", "cublasComputeType_t",
2020-
"cublasAtomicsMode_t", "CUmem_advise_enum", "CUmem_advise",
2021-
"thrust::tuple_element", "thrust::tuple_size", "cublasMath_t",
2022-
"cudaPointerAttributes", "thrust::zip_iterator",
2023-
"cusolverEigRange_t", "cudaUUID_t")
2024-
)))))
1977+
loc(qualType(hasDeclaration(namedDecl(hasAnyName(
1978+
"cudaError", "curandStatus", "cublasStatus", "CUstream",
1979+
"CUstream_st", "thrust::complex", "thrust::device_vector",
1980+
"thrust::device_ptr", "thrust::device_reference",
1981+
"thrust::host_vector", "cublasHandle_t", "CUevent_st", "__half",
1982+
"half", "__half2", "half2", "cudaMemoryAdvise", "cudaError_enum",
1983+
"cudaDeviceProp", "cudaPitchedPtr", "thrust::counting_iterator",
1984+
"thrust::transform_iterator", "thrust::permutation_iterator",
1985+
"thrust::iterator_difference", "cusolverDnHandle_t",
1986+
"cusolverDnParams_t", "gesvdjInfo_t",
1987+
"thrust::device_malloc_allocator", "thrust::divides",
1988+
"thrust::tuple", "thrust::maximum", "thrust::multiplies",
1989+
"thrust::plus", "cudaDataType_t", "cudaError_t", "CUresult",
1990+
"CUdevice", "cudaEvent_t", "cublasStatus_t", "cuComplex",
1991+
"cuFloatComplex", "cuDoubleComplex", "CUevent",
1992+
"cublasFillMode_t", "cublasDiagType_t", "cublasSideMode_t",
1993+
"cublasOperation_t", "cusolverStatus_t", "cusolverEigType_t",
1994+
"cusolverEigMode_t", "curandStatus_t", "cudaStream_t",
1995+
"cusparseStatus_t", "cusparseDiagType_t", "cusparseFillMode_t",
1996+
"cusparseIndexBase_t", "cusparseMatrixType_t",
1997+
"cusparseOperation_t", "cusparseMatDescr_t", "cusparseHandle_t",
1998+
"CUcontext", "cublasPointerMode_t", "cusparsePointerMode_t",
1999+
"cublasGemmAlgo_t", "cusparseSolveAnalysisInfo_t", "cudaDataType",
2000+
"cublasDataType_t", "curandState_t", "curandState",
2001+
"curandStateXORWOW_t", "curandStateXORWOW",
2002+
"curandStatePhilox4_32_10_t", "curandStatePhilox4_32_10",
2003+
"curandStateMRG32k3a_t", "curandStateMRG32k3a", "thrust::minus",
2004+
"thrust::negate", "thrust::logical_or", "thrust::identity",
2005+
"thrust::equal_to", "thrust::less", "cudaSharedMemConfig",
2006+
"curandGenerator_t", "curandRngType_t", "cufftHandle",
2007+
"cufftReal", "cufftDoubleReal", "cufftComplex",
2008+
"cufftDoubleComplex", "cufftResult_t", "cufftResult",
2009+
"cufftType_t", "cufftType", "thrust::pair", "CUdeviceptr",
2010+
"cudaDeviceAttr", "CUmodule", "CUjit_option", "CUfunction",
2011+
"cudaMemcpyKind", "cudaComputeMode", "__nv_bfloat16",
2012+
"cooperative_groups::__v1::thread_block_tile",
2013+
"cooperative_groups::__v1::thread_block", "libraryPropertyType_t",
2014+
"libraryPropertyType", "cudaDataType_t", "cudaDataType",
2015+
"cublasComputeType_t", "cublasAtomicsMode_t", "CUmem_advise_enum",
2016+
"CUmem_advise", "thrust::tuple_element", "thrust::tuple_size",
2017+
"cublasMath_t", "cudaPointerAttributes", "thrust::zip_iterator",
2018+
"cusolverEigRange_t", "cudaUUID_t"))))))
20252019
.bind("cudaTypeDef"),
20262020
this);
20272021
MF.addMatcher(varDecl(hasType(classTemplateSpecializationDecl(
@@ -3967,12 +3961,14 @@ void BLASEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
39673961
REGISTER_RULE(BLASEnumsRule, PassKind::PK_Migration)
39683962

39693963
// Rule for RANDOM enums.
3970-
// Migrate RANDOM status values to corresponding int values
39713964
void RandomEnumsRule::registerMatcher(MatchFinder &MF) {
39723965
MF.addMatcher(
39733966
declRefExpr(to(enumConstantDecl(matchesName("CURAND_STATUS.*"))))
39743967
.bind("RANDOMStatusConstants"),
39753968
this);
3969+
MF.addMatcher(declRefExpr(to(enumConstantDecl(matchesName("CURAND_RNG.*"))))
3970+
.bind("RandomTypeEnum"),
3971+
this);
39763972
}
39773973

39783974
void RandomEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
@@ -3981,6 +3977,23 @@ void RandomEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
39813977
auto *EC = cast<EnumConstantDecl>(DE->getDecl());
39823978
emplaceTransformation(new ReplaceStmt(DE, toString(EC->getInitVal(), 10)));
39833979
}
3980+
if (const DeclRefExpr *DE =
3981+
getNodeAsType<DeclRefExpr>(Result, "RandomTypeEnum")) {
3982+
std::string EnumStr = DE->getNameInfo().getName().getAsString();
3983+
auto Search = MapNames::RandomEngineTypeMap.find(EnumStr);
3984+
if (Search == MapNames::RandomEngineTypeMap.end()) {
3985+
report(DE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, EnumStr);
3986+
return;
3987+
}
3988+
if (EnumStr == "CURAND_RNG_PSEUDO_XORWOW" ||
3989+
EnumStr == "CURAND_RNG_QUASI_SOBOL64" ||
3990+
EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64") {
3991+
report(DE->getBeginLoc(), Diagnostics::DIFFERENT_GENERATOR, false);
3992+
} else if (EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32") {
3993+
report(DE->getBeginLoc(), Diagnostics::DIFFERENT_BASIC_GENERATOR, false);
3994+
}
3995+
emplaceTransformation(new ReplaceStmt(DE, Search->second));
3996+
}
39843997
}
39853998

39863999
REGISTER_RULE(RandomEnumsRule, PassKind::PK_Migration)
@@ -4559,26 +4572,6 @@ void RandomFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
45594572

45604573
if (FuncName == "curandCreateGenerator" ||
45614574
FuncName == "curandCreateGeneratorHost") {
4562-
std::string EnumStr = ExprAnalysis::ref(CE->getArg(1));
4563-
if (MapNames::RandomEngineTypeMap.find(EnumStr) ==
4564-
MapNames::RandomEngineTypeMap.end()) {
4565-
report(PrefixInsertLoc, Diagnostics::NOT_SUPPORTED_PARAMETER, false,
4566-
FuncName, "parameter " + EnumStr + " is unsupported");
4567-
return;
4568-
}
4569-
4570-
if (EnumStr == "CURAND_RNG_PSEUDO_XORWOW" ||
4571-
EnumStr == "CURAND_RNG_QUASI_SOBOL64" ||
4572-
EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64") {
4573-
report(CE->getArg(1)->getBeginLoc(), Diagnostics::DIFFERENT_GENERATOR,
4574-
false);
4575-
} else if (EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32") {
4576-
report(CE->getArg(1)->getBeginLoc(),
4577-
Diagnostics::DIFFERENT_BASIC_GENERATOR, false);
4578-
}
4579-
4580-
std::string EngineType =
4581-
MapNames::RandomEngineTypeMap.find(EnumStr)->second;
45824575
const auto *const Arg0 = CE->getArg(0);
45834576
requestFeature(HelperFeatureEnum::RngUtils_create_host_rng, CE);
45844577
if (Arg0->getStmtClass() == Stmt::UnaryOperatorClass) {
@@ -4590,13 +4583,14 @@ void RandomFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
45904583
return emplaceTransformation(new ReplaceStmt(
45914584
CE, false,
45924585
buildString(ExprAnalysis::ref(SE), " = dpct::rng::create_host_rng(",
4593-
EngineType, ")")));
4586+
ExprAnalysis::ref(CE->getArg(1)), ")")));
45944587
}
45954588
}
4596-
return emplaceTransformation(new ReplaceStmt(
4597-
CE, false,
4598-
buildString("*(", ExprAnalysis::ref(CE->getArg(0)),
4599-
") = dpct::rng::create_host_rng(", EngineType, ")")));
4589+
return emplaceTransformation(
4590+
new ReplaceStmt(CE, false,
4591+
buildString("*(", ExprAnalysis::ref(CE->getArg(0)),
4592+
") = dpct::rng::create_host_rng(",
4593+
ExprAnalysis::ref(CE->getArg(1)), ")")));
46004594
}
46014595
if (FuncName == "curandDestroyGenerator") {
46024596
return emplaceTransformation(new ReplaceStmt(
@@ -15199,15 +15193,15 @@ void CudaExtentRule::runRule(
1519915193
// struct Foo { cudaExtent e; Foo() : e() {} }; -> struct Foo { sycl::range<3> e; Foo() : e{0, 0, 0} {} };
1520015194
if (const CXXConstructExpr *Ctor =
1520115195
getNodeAsType<CXXConstructExpr>(Result, "defaultCtor")) {
15202-
15196+
1520315197
// Ignore implicit move/copy ctor
1520415198
if (Ctor->getNumArgs() != 0)
1520515199
return;
1520615200
CharSourceRange CSR;
1520715201
SourceRange SR = Ctor->getParenOrBraceRange();
1520815202
auto &SM = DpctGlobalInfo::getSourceManager();
1520915203
std::string Replacement = "{0, 0, 0}";
15210-
15204+
1521115205
if (SR.isInvalid()) {
1521215206
auto CtorLoc = Ctor->getLocation().isMacroID()
1521315207
? SM.getSpellingLoc(Ctor->getLocation())

clang/lib/DPCT/ExprAnalysis.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ void ExprAnalysis::analyzeExpr(const DeclRefExpr *DRE) {
508508
REPLACE_ENUM(MapNames::BLASEnumsMap);
509509
REPLACE_ENUM(MapNames::FunctionAttrMap);
510510
REPLACE_ENUM(CuDNNTypeRule::CuDNNEnumNamesMap);
511+
REPLACE_ENUM(MapNames::RandomEngineTypeMap);
511512
REPLACE_ENUM(MapNames::SOLVEREnumsMap);
512513
REPLACE_ENUM(MapNames::SPBLASEnumsMap);
513514
#undef REPLACE_ENUM
@@ -1108,6 +1109,10 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE) {
11081109
case TypeLoc::Decltype:
11091110
analyzeDecltypeType(TYPELOC_CAST(DecltypeTypeLoc));
11101111
break;
1112+
case TypeLoc::Enum: {
1113+
TyName = DpctGlobalInfo::getTypeName(TL.getType());
1114+
break;
1115+
}
11111116
default:
11121117
return;
11131118
}

clang/lib/DPCT/MapNames.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ void MapNames::setExplicitNamespaceMap() {
327327
std::make_shared<TypeNameRule>(
328328
getDpctNamespace() + "rng::host_rng_ptr",
329329
HelperFeatureEnum::RngUtils_typedef_host_rng_ptr)},
330+
{"curandRngType_t", std::make_shared<TypeNameRule>(
331+
getDpctNamespace() + "rng::random_engine_type",
332+
HelperFeatureEnum::RngUtils_random_engine_type)},
333+
{"curandRngType", std::make_shared<TypeNameRule>(
334+
getDpctNamespace() + "rng::random_engine_type",
335+
HelperFeatureEnum::RngUtils_random_engine_type)},
330336
{"curandStatus_t", std::make_shared<TypeNameRule>("int")},
331337
{"curandStatus", std::make_shared<TypeNameRule>("int")},
332338
{"cusparseStatus_t", std::make_shared<TypeNameRule>("int")},

clang/test/dpct/curand.cu

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,7 @@ void bar4(){
235235

236236
void bar5(){
237237
//CHECK:dpct::rng::host_rng_ptr rng;
238-
//CHECK-NEXT:/*
239-
//CHECK-NEXT:DPCT1028:{{[0-9]+}}: The curandCreateGenerator was not migrated because parameter
240-
//CHECK-NEXT:(curandRngType)101 is unsupported.
241-
//CHECK-NEXT:*/
242-
//CHECK-NEXT:curandCreateGenerator(&rng, (curandRngType)101);
238+
//CHECK-NEXT:rng = dpct::rng::create_host_rng((dpct::rng::random_engine_type)101);
243239
//CHECK-NEXT:rng->set_seed(1337ull);
244240
curandGenerator_t rng;
245241
curandCreateGenerator(&rng, (curandRngType)101);
@@ -267,7 +263,48 @@ int bar6(){
267263

268264
void bar7() {
269265
curandGenerator_t rng;
270-
//CHECK:rng =
271-
//CHECK: dpct::rng::create_host_rng(dpct::rng::random_engine_type::philox4x32x10);
272-
curandCreateGeneratorHost(&rng, CURAND_RNG_PSEUDO_PHILOX4_32_10);
266+
// CHECK: dpct::rng::random_engine_type rngT;
267+
// CHECK-NEXT: dpct::rng::random_engine_type rngT1 = dpct::rng::random_engine_type::mcg59;
268+
// CHECK-NEXT: /*
269+
// CHECK-NEXT: DPCT1032:{{[0-9]+}}: A different random number generator is used. You may need to
270+
// CHECK-NEXT: adjust the code.
271+
// CHECK-NEXT: */
272+
// CHECK-NEXT: dpct::rng::random_engine_type rngT2 = dpct::rng::random_engine_type::mcg59;
273+
// CHECK-NEXT: dpct::rng::random_engine_type rngT3 = dpct::rng::random_engine_type::mrg32k3a;
274+
// CHECK-NEXT: dpct::rng::random_engine_type rngT4 = dpct::rng::random_engine_type::mt2203;
275+
// CHECK-NEXT: dpct::rng::random_engine_type rngT5 = dpct::rng::random_engine_type::mt19937;
276+
// CHECK-NEXT: dpct::rng::random_engine_type rngT6 =
277+
// CHECK-NEXT: dpct::rng::random_engine_type::philox4x32x10;
278+
// CHECK-NEXT: dpct::rng::random_engine_type rngT7 = dpct::rng::random_engine_type::sobol;
279+
// CHECK-NEXT: dpct::rng::random_engine_type rngT8 = dpct::rng::random_engine_type::sobol;
280+
// CHECK-NEXT: /*
281+
// CHECK-NEXT: DPCT1033:{{[0-9]+}}: Migrated code uses a basic Sobol generator. Initialize
282+
// CHECK-NEXT: oneapi::mkl::rng::sobol generator with user-defined direction numbers to use
283+
// CHECK-NEXT: it as Scrambled Sobol generator.
284+
// CHECK-NEXT: */
285+
// CHECK-NEXT: dpct::rng::random_engine_type rngT9 = dpct::rng::random_engine_type::sobol;
286+
// CHECK-NEXT: /*
287+
// CHECK-NEXT: DPCT1032:{{[0-9]+}}: A different random number generator is used. You may need to
288+
// CHECK-NEXT: adjust the code.
289+
// CHECK-NEXT: */
290+
// CHECK-NEXT: dpct::rng::random_engine_type rngT10 = dpct::rng::random_engine_type::sobol;
291+
// CHECK-NEXT: /*
292+
// CHECK-NEXT: DPCT1032:{{[0-9]+}}: A different random number generator is used. You may need to
293+
// CHECK-NEXT: adjust the code.
294+
// CHECK-NEXT: */
295+
// CHECK-NEXT: dpct::rng::random_engine_type rngT11 = dpct::rng::random_engine_type::sobol;
296+
curandRngType_t rngT;
297+
curandRngType_t rngT1 = CURAND_RNG_PSEUDO_DEFAULT;
298+
curandRngType_t rngT2 = CURAND_RNG_PSEUDO_XORWOW;
299+
curandRngType_t rngT3 = CURAND_RNG_PSEUDO_MRG32K3A;
300+
curandRngType_t rngT4 = CURAND_RNG_PSEUDO_MTGP32;
301+
curandRngType_t rngT5 = CURAND_RNG_PSEUDO_MT19937;
302+
curandRngType_t rngT6 = CURAND_RNG_PSEUDO_PHILOX4_32_10;
303+
curandRngType_t rngT7 = CURAND_RNG_QUASI_DEFAULT;
304+
curandRngType_t rngT8 = CURAND_RNG_QUASI_SOBOL32;
305+
curandRngType_t rngT9 = CURAND_RNG_QUASI_SCRAMBLED_SOBOL32;
306+
curandRngType_t rngT10 = CURAND_RNG_QUASI_SOBOL64;
307+
curandRngType_t rngT11 = CURAND_RNG_QUASI_SCRAMBLED_SOBOL64;
308+
// CHECK: rng = dpct::rng::create_host_rng(rngT);
309+
curandCreateGeneratorHost(&rng, rngT);
273310
}

clang/test/dpct/test_api_level/RngUtils/api_test3.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// RUN: rm -rf %T/RngUtils/api_test3_out
55

66
// CHECK: 32
7+
// TEST_FEATURE: RngUtils_random_engine_type
78
// TEST_FEATURE: RngUtils_create_host_rng
89
// TEST_FEATURE: RngUtils_typedef_host_rng_ptr
910

0 commit comments

Comments
 (0)