@@ -1051,7 +1051,7 @@ void IncludesCallbacks::InclusionDirective(
10511051 if (FileName.compare(StringRef("cuda/std/tuple")) == 0) {
10521052 DpctGlobalInfo::getInstance().insertHeader(HashLoc, HT_Tuple);
10531053 }
1054-
1054+
10551055 }
10561056
10571057 if (!isChildPath(CudaPath, IncludePath) &&
@@ -1974,54 +1974,48 @@ REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration)
19741974void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
19751975 MF.addMatcher(
19761976 typeLoc(
1977- loc(qualType(hasDeclaration(namedDecl(
1978- hasAnyName(
1979- "cudaError", "curandStatus", "cublasStatus", "CUstream",
1980- "CUstream_st", "thrust::complex", "thrust::device_vector",
1981- "thrust::device_ptr", "thrust::device_reference",
1982- "thrust::host_vector", "cublasHandle_t",
1983- "CUevent_st", "__half", "half", "__half2", "half2",
1984- "cudaMemoryAdvise", "cudaError_enum", "cudaDeviceProp",
1985- "cudaPitchedPtr", "thrust::counting_iterator",
1986- "thrust::transform_iterator", "thrust::permutation_iterator",
1987- "thrust::iterator_difference", "cusolverDnHandle_t",
1988- "cusolverDnParams_t", "gesvdjInfo_t",
1989- "thrust::device_malloc_allocator", "thrust::divides",
1990- "thrust::tuple", "thrust::maximum", "thrust::multiplies",
1991- "thrust::plus", "cudaDataType_t", "cudaError_t", "CUresult",
1992- "CUdevice", "cudaEvent_t", "cublasStatus_t", "cuComplex",
1993- "cuFloatComplex", "cuDoubleComplex", "CUevent",
1994- "cublasFillMode_t", "cublasDiagType_t", "cublasSideMode_t",
1995- "cublasOperation_t", "cusolverStatus_t", "cusolverEigType_t",
1996- "cusolverEigMode_t", "curandStatus_t", "cudaStream_t",
1997- "cusparseStatus_t", "cusparseDiagType_t",
1998- "cusparseFillMode_t", "cusparseIndexBase_t",
1999- "cusparseMatrixType_t", "cusparseOperation_t",
2000- "cusparseMatDescr_t", "cusparseHandle_t", "CUcontext",
2001- "cublasPointerMode_t", "cusparsePointerMode_t",
2002- "cublasGemmAlgo_t", "cusparseSolveAnalysisInfo_t",
2003- "cudaDataType", "cublasDataType_t", "curandState_t",
2004- "curandState", "curandStateXORWOW_t", "curandStateXORWOW",
2005- "curandStatePhilox4_32_10_t", "curandStatePhilox4_32_10",
2006- "curandStateMRG32k3a_t", "curandStateMRG32k3a",
2007- "thrust::minus", "thrust::negate", "thrust::logical_or",
2008- "thrust::identity", "thrust::equal_to", "thrust::less",
2009- "cudaSharedMemConfig", "curandGenerator_t", "cufftHandle",
2010- "cufftReal", "cufftDoubleReal", "cufftComplex",
2011- "cufftDoubleComplex", "cufftResult_t", "cufftResult",
2012- "cufftType_t", "cufftType", "thrust::pair", "CUdeviceptr",
2013- "cudaDeviceAttr", "CUmodule", "CUjit_option",
2014- "CUfunction", "cudaMemcpyKind",
2015- "cudaComputeMode", "__nv_bfloat16",
2016- "cooperative_groups::__v1::thread_block_tile",
2017- "cooperative_groups::__v1::thread_block",
2018- "libraryPropertyType_t", "libraryPropertyType",
2019- "cudaDataType_t", "cudaDataType", "cublasComputeType_t",
2020- "cublasAtomicsMode_t", "CUmem_advise_enum", "CUmem_advise",
2021- "thrust::tuple_element", "thrust::tuple_size", "cublasMath_t",
2022- "cudaPointerAttributes", "thrust::zip_iterator",
2023- "cusolverEigRange_t", "cudaUUID_t")
2024- )))))
1977+ loc(qualType(hasDeclaration(namedDecl(hasAnyName(
1978+ "cudaError", "curandStatus", "cublasStatus", "CUstream",
1979+ "CUstream_st", "thrust::complex", "thrust::device_vector",
1980+ "thrust::device_ptr", "thrust::device_reference",
1981+ "thrust::host_vector", "cublasHandle_t", "CUevent_st", "__half",
1982+ "half", "__half2", "half2", "cudaMemoryAdvise", "cudaError_enum",
1983+ "cudaDeviceProp", "cudaPitchedPtr", "thrust::counting_iterator",
1984+ "thrust::transform_iterator", "thrust::permutation_iterator",
1985+ "thrust::iterator_difference", "cusolverDnHandle_t",
1986+ "cusolverDnParams_t", "gesvdjInfo_t",
1987+ "thrust::device_malloc_allocator", "thrust::divides",
1988+ "thrust::tuple", "thrust::maximum", "thrust::multiplies",
1989+ "thrust::plus", "cudaDataType_t", "cudaError_t", "CUresult",
1990+ "CUdevice", "cudaEvent_t", "cublasStatus_t", "cuComplex",
1991+ "cuFloatComplex", "cuDoubleComplex", "CUevent",
1992+ "cublasFillMode_t", "cublasDiagType_t", "cublasSideMode_t",
1993+ "cublasOperation_t", "cusolverStatus_t", "cusolverEigType_t",
1994+ "cusolverEigMode_t", "curandStatus_t", "cudaStream_t",
1995+ "cusparseStatus_t", "cusparseDiagType_t", "cusparseFillMode_t",
1996+ "cusparseIndexBase_t", "cusparseMatrixType_t",
1997+ "cusparseOperation_t", "cusparseMatDescr_t", "cusparseHandle_t",
1998+ "CUcontext", "cublasPointerMode_t", "cusparsePointerMode_t",
1999+ "cublasGemmAlgo_t", "cusparseSolveAnalysisInfo_t", "cudaDataType",
2000+ "cublasDataType_t", "curandState_t", "curandState",
2001+ "curandStateXORWOW_t", "curandStateXORWOW",
2002+ "curandStatePhilox4_32_10_t", "curandStatePhilox4_32_10",
2003+ "curandStateMRG32k3a_t", "curandStateMRG32k3a", "thrust::minus",
2004+ "thrust::negate", "thrust::logical_or", "thrust::identity",
2005+ "thrust::equal_to", "thrust::less", "cudaSharedMemConfig",
2006+ "curandGenerator_t", "curandRngType_t", "cufftHandle",
2007+ "cufftReal", "cufftDoubleReal", "cufftComplex",
2008+ "cufftDoubleComplex", "cufftResult_t", "cufftResult",
2009+ "cufftType_t", "cufftType", "thrust::pair", "CUdeviceptr",
2010+ "cudaDeviceAttr", "CUmodule", "CUjit_option", "CUfunction",
2011+ "cudaMemcpyKind", "cudaComputeMode", "__nv_bfloat16",
2012+ "cooperative_groups::__v1::thread_block_tile",
2013+ "cooperative_groups::__v1::thread_block", "libraryPropertyType_t",
2014+ "libraryPropertyType", "cudaDataType_t", "cudaDataType",
2015+ "cublasComputeType_t", "cublasAtomicsMode_t", "CUmem_advise_enum",
2016+ "CUmem_advise", "thrust::tuple_element", "thrust::tuple_size",
2017+ "cublasMath_t", "cudaPointerAttributes", "thrust::zip_iterator",
2018+ "cusolverEigRange_t", "cudaUUID_t"))))))
20252019 .bind("cudaTypeDef"),
20262020 this);
20272021 MF.addMatcher(varDecl(hasType(classTemplateSpecializationDecl(
@@ -3967,12 +3961,14 @@ void BLASEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
39673961REGISTER_RULE(BLASEnumsRule, PassKind::PK_Migration)
39683962
39693963// Rule for RANDOM enums.
3970- // Migrate RANDOM status values to corresponding int values
39713964void RandomEnumsRule::registerMatcher(MatchFinder &MF) {
39723965 MF.addMatcher(
39733966 declRefExpr(to(enumConstantDecl(matchesName("CURAND_STATUS.*"))))
39743967 .bind("RANDOMStatusConstants"),
39753968 this);
3969+ MF.addMatcher(declRefExpr(to(enumConstantDecl(matchesName("CURAND_RNG.*"))))
3970+ .bind("RandomTypeEnum"),
3971+ this);
39763972}
39773973
39783974void RandomEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
@@ -3981,6 +3977,23 @@ void RandomEnumsRule::runRule(const MatchFinder::MatchResult &Result) {
39813977 auto *EC = cast<EnumConstantDecl>(DE->getDecl());
39823978 emplaceTransformation(new ReplaceStmt(DE, toString(EC->getInitVal(), 10)));
39833979 }
3980+ if (const DeclRefExpr *DE =
3981+ getNodeAsType<DeclRefExpr>(Result, "RandomTypeEnum")) {
3982+ std::string EnumStr = DE->getNameInfo().getName().getAsString();
3983+ auto Search = MapNames::RandomEngineTypeMap.find(EnumStr);
3984+ if (Search == MapNames::RandomEngineTypeMap.end()) {
3985+ report(DE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, EnumStr);
3986+ return;
3987+ }
3988+ if (EnumStr == "CURAND_RNG_PSEUDO_XORWOW" ||
3989+ EnumStr == "CURAND_RNG_QUASI_SOBOL64" ||
3990+ EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64") {
3991+ report(DE->getBeginLoc(), Diagnostics::DIFFERENT_GENERATOR, false);
3992+ } else if (EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32") {
3993+ report(DE->getBeginLoc(), Diagnostics::DIFFERENT_BASIC_GENERATOR, false);
3994+ }
3995+ emplaceTransformation(new ReplaceStmt(DE, Search->second));
3996+ }
39843997}
39853998
39863999REGISTER_RULE(RandomEnumsRule, PassKind::PK_Migration)
@@ -4559,26 +4572,6 @@ void RandomFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
45594572
45604573 if (FuncName == "curandCreateGenerator" ||
45614574 FuncName == "curandCreateGeneratorHost") {
4562- std::string EnumStr = ExprAnalysis::ref(CE->getArg(1));
4563- if (MapNames::RandomEngineTypeMap.find(EnumStr) ==
4564- MapNames::RandomEngineTypeMap.end()) {
4565- report(PrefixInsertLoc, Diagnostics::NOT_SUPPORTED_PARAMETER, false,
4566- FuncName, "parameter " + EnumStr + " is unsupported");
4567- return;
4568- }
4569-
4570- if (EnumStr == "CURAND_RNG_PSEUDO_XORWOW" ||
4571- EnumStr == "CURAND_RNG_QUASI_SOBOL64" ||
4572- EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64") {
4573- report(CE->getArg(1)->getBeginLoc(), Diagnostics::DIFFERENT_GENERATOR,
4574- false);
4575- } else if (EnumStr == "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32") {
4576- report(CE->getArg(1)->getBeginLoc(),
4577- Diagnostics::DIFFERENT_BASIC_GENERATOR, false);
4578- }
4579-
4580- std::string EngineType =
4581- MapNames::RandomEngineTypeMap.find(EnumStr)->second;
45824575 const auto *const Arg0 = CE->getArg(0);
45834576 requestFeature(HelperFeatureEnum::RngUtils_create_host_rng, CE);
45844577 if (Arg0->getStmtClass() == Stmt::UnaryOperatorClass) {
@@ -4590,13 +4583,14 @@ void RandomFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
45904583 return emplaceTransformation(new ReplaceStmt(
45914584 CE, false,
45924585 buildString(ExprAnalysis::ref(SE), " = dpct::rng::create_host_rng(",
4593- EngineType , ")")));
4586+ ExprAnalysis::ref(CE->getArg(1)) , ")")));
45944587 }
45954588 }
4596- return emplaceTransformation(new ReplaceStmt(
4597- CE, false,
4598- buildString("*(", ExprAnalysis::ref(CE->getArg(0)),
4599- ") = dpct::rng::create_host_rng(", EngineType, ")")));
4589+ return emplaceTransformation(
4590+ new ReplaceStmt(CE, false,
4591+ buildString("*(", ExprAnalysis::ref(CE->getArg(0)),
4592+ ") = dpct::rng::create_host_rng(",
4593+ ExprAnalysis::ref(CE->getArg(1)), ")")));
46004594 }
46014595 if (FuncName == "curandDestroyGenerator") {
46024596 return emplaceTransformation(new ReplaceStmt(
@@ -15199,15 +15193,15 @@ void CudaExtentRule::runRule(
1519915193 // struct Foo { cudaExtent e; Foo() : e() {} }; -> struct Foo { sycl::range<3> e; Foo() : e{0, 0, 0} {} };
1520015194 if (const CXXConstructExpr *Ctor =
1520115195 getNodeAsType<CXXConstructExpr>(Result, "defaultCtor")) {
15202-
15196+
1520315197 // Ignore implicit move/copy ctor
1520415198 if (Ctor->getNumArgs() != 0)
1520515199 return;
1520615200 CharSourceRange CSR;
1520715201 SourceRange SR = Ctor->getParenOrBraceRange();
1520815202 auto &SM = DpctGlobalInfo::getSourceManager();
1520915203 std::string Replacement = "{0, 0, 0}";
15210-
15204+
1521115205 if (SR.isInvalid()) {
1521215206 auto CtorLoc = Ctor->getLocation().isMacroID()
1521315207 ? SM.getSpellingLoc(Ctor->getLocation())
0 commit comments