Skip to content

Commit 57ac2d5

Browse files
authored
[SYCLomatic] Change default engine from philox4x32x10 to mcg59 for both host and device generators (#710)
Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent a9eacb4 commit 57ac2d5

17 files changed

+228
-150
lines changed

clang/lib/DPCT/APINamesCURAND.inc

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,7 @@ CONDITIONAL_FACTORY_ENTRY(
4040
makeCombinedArg(ARG("{0, static_cast<std::uint64_t>(("),
4141
ARG(0)),
4242
ARG(") * 4)}"))))),
43-
FEATURE_REQUEST_FACTORY(
44-
HelperFeatureEnum::RngUtils_rng_generator_get_engine,
45-
CALL_FACTORY_ENTRY(
46-
"skipahead_sequence",
47-
CALL("oneapi::mkl::rng::device::skip_ahead",
48-
MEMBER_CALL(DEREF(makeDerefArgCreatorWithCall(1)),
49-
false, "get_engine"),
50-
makeCombinedArg(
51-
makeCombinedArg(ARG("{0, static_cast<std::uint64_t>(("),
52-
ARG(0)),
53-
ARG(") * 8)}"))))))),
43+
REMOVE_API_FACTORY_ENTRY_WITH_MSG("skipahead_sequence", "this API is not supported for mcg59 engine."))),
5444
CONDITIONAL_FACTORY_ENTRY(
5545
CheckArgType(1, "struct curandStateMRG32k3a *"),
5646
FEATURE_REQUEST_FACTORY(
@@ -75,17 +65,7 @@ CONDITIONAL_FACTORY_ENTRY(
7565
makeCombinedArg(ARG("{0, static_cast<std::uint64_t>("),
7666
ARG(0)),
7767
ARG(" * 4)}"))))),
78-
FEATURE_REQUEST_FACTORY(
79-
HelperFeatureEnum::RngUtils_rng_generator_get_engine,
80-
CALL_FACTORY_ENTRY(
81-
"skipahead_sequence",
82-
CALL("oneapi::mkl::rng::device::skip_ahead",
83-
MEMBER_CALL(DEREF(makeDerefArgCreatorWithCall(1)),
84-
false, "get_engine"),
85-
makeCombinedArg(
86-
makeCombinedArg(ARG("{0, static_cast<std::uint64_t>("),
87-
ARG(0)),
88-
ARG(" * 8)}")))))))
68+
REMOVE_API_FACTORY_ENTRY_WITH_MSG("skipahead_sequence", "this API is not supported for mcg59 engine.")))
8969
)
9070

9171
CONDITIONAL_FACTORY_ENTRY(
@@ -117,9 +97,15 @@ FEATURE_REQUEST_FACTORY(
11797
// bits
11898
FEATURE_REQUEST_FACTORY(
11999
HelperFeatureEnum::RngUtils_rng_generator_generate,
120-
MEMBER_CALL_FACTORY_ENTRY("curand",
121-
DEREF(makeDerefArgCreatorWithCall(0)), false,
122-
"generate<oneapi::mkl::rng::device::bits<std::uint32_t>, 1>"))
100+
CONDITIONAL_FACTORY_ENTRY(
101+
CheckArgType(0, "struct curandStateXORWOW *"),
102+
MEMBER_CALL_FACTORY_ENTRY("curand",
103+
DEREF(makeDerefArgCreatorWithCall(0)), false,
104+
"generate<oneapi::mkl::rng::device::uniform_"
105+
"bits<std::uint32_t>, 1>"),
106+
MEMBER_CALL_FACTORY_ENTRY(
107+
"curand", DEREF(makeDerefArgCreatorWithCall(0)), false,
108+
"generate<oneapi::mkl::rng::device::bits<std::uint32_t>, 1>")))
123109
FEATURE_REQUEST_FACTORY(
124110
HelperFeatureEnum::RngUtils_rng_generator_generate,
125111
MEMBER_CALL_FACTORY_ENTRY("curand4",

clang/lib/DPCT/APINames_cuRAND.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ ENTRY(curandSetStream, curandSetStream, true, NO_FLAG, P4, "Successful")
6060
ENTRY(curandMakeMTGP32KernelState, curandMakeMTGP32KernelState, false, NO_FLAG, P4, "comment")
6161
ENTRY(curand, curand, true, NO_FLAG, P4, "Successful")
6262
ENTRY(curand4, curand4, true, NO_FLAG, P4, "Successful")
63-
ENTRY(curand_init, curand_init, true, NO_FLAG, P4, "DPCT1028")
63+
ENTRY(curand_init, curand_init, true, NO_FLAG, P4, "DPCT1028/DPCT1105")
6464
ENTRY(curand_log_normal, curand_log_normal, true, NO_FLAG, P4, "Successful")
6565
ENTRY(curand_log_normal2, curand_log_normal2, true, NO_FLAG, P4, "Successful")
6666
ENTRY(curand_log_normal2_double, curand_log_normal2_double, true, NO_FLAG, P4, "Successful")
@@ -78,7 +78,7 @@ ENTRY(curand_uniform2_double, curand_uniform2_double, true, NO_FLAG, P4, "Succes
7878
ENTRY(curand_uniform4, curand_uniform4, true, NO_FLAG, P4, "Successful")
7979
ENTRY(curand_uniform_double, curand_uniform_double, true, NO_FLAG, P4, "Successful")
8080
ENTRY(skipahead, skipahead, true, NO_FLAG, P4, "Successful")
81-
ENTRY(skipahead_sequence, skipahead_sequence, true, NO_FLAG, P4, "Successful")
81+
ENTRY(skipahead_sequence, skipahead_sequence, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
8282
ENTRY(skipahead_subsequence, skipahead_subsequence, true, NO_FLAG, P4, "Successful")
8383
ENTRY(curand_uniform4_double, curand_uniform4_double, true, NO_FLAG, P4, "Successful")
8484
ENTRY(curand_normal4_double, curand_normal4_double, true, NO_FLAG, P4, "Successful")

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4743,28 +4743,32 @@ void DeviceRandomFunctionCallRule::runRule(
47434743
FirstOffsetArg = "static_cast<std::uint64_t>(" + RNGOffset + ")";
47444744
}
47454745

4746-
std::string Factor = "8";
4747-
if (GeneratorType == "dpct::rng::device::rng_generator<oneapi::"
4748-
"mkl::rng::device::philox4x32x10<1>>" &&
4749-
(DRefArg3Type == "curandStatePhilox4_32_10_t" ||
4750-
DRefArg3Type == "curandStatePhilox4_32_10")) {
4751-
Factor = "4";
4752-
}
4753-
4754-
if (needExtraParens(CE->getArg(1))) {
4755-
RNGSubseq = "(" + RNGSubseq + ")";
4756-
}
4757-
if (IsRNGSubseqLiteral) {
4758-
SecondOffsetArg = RNGSubseq + " * " + Factor;
4746+
std::string ReplStr;
4747+
if (DRefArg3Type == "curandStateXORWOW") {
4748+
report(FuncNameBegin, Diagnostics::SUBSEQUENCE_IGNORED, false, RNGSubseq);
4749+
ReplStr = RNGStateName + " = " + GeneratorType + "(" + RNGSeed + ", " +
4750+
FirstOffsetArg + ")";
47594751
} else {
4760-
SecondOffsetArg =
4761-
"static_cast<std::uint64_t>(" + RNGSubseq + " * " + Factor + ")";
4762-
}
4752+
std::string Factor = "8";
4753+
if (GeneratorType == "dpct::rng::device::rng_generator<oneapi::"
4754+
"mkl::rng::device::philox4x32x10<1>>" &&
4755+
DRefArg3Type == "curandStatePhilox4_32_10") {
4756+
Factor = "4";
4757+
}
47634758

4764-
std::string ReplStr = RNGStateName + " = " + GeneratorType + "(" + RNGSeed +
4765-
", {" + FirstOffsetArg + ", " + SecondOffsetArg +
4766-
"})";
4759+
if (needExtraParens(CE->getArg(1))) {
4760+
RNGSubseq = "(" + RNGSubseq + ")";
4761+
}
4762+
if (IsRNGSubseqLiteral) {
4763+
SecondOffsetArg = RNGSubseq + " * " + Factor;
4764+
} else {
4765+
SecondOffsetArg =
4766+
"static_cast<std::uint64_t>(" + RNGSubseq + " * " + Factor + ")";
4767+
}
47674768

4769+
ReplStr = RNGStateName + " = " + GeneratorType + "(" + RNGSeed + ", {" +
4770+
FirstOffsetArg + ", " + SecondOffsetArg + "})";
4771+
}
47684772
emplaceTransformation(
47694773
new ReplaceText(FuncNameBegin, FuncCallLength, std::move(ReplStr)));
47704774
} else if (FuncName == "skipahead" || FuncName == "skipahead_sequence" ||

clang/lib/DPCT/CallExprRewriter.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,16 +307,19 @@ class InsertAroundRewriter : public CallExprRewriter {
307307
class RemoveAPIRewriter : public CallExprRewriter {
308308
bool IsAssigned = false;
309309
std::string CalleeName;
310+
std::string Message;
310311

311312
public:
312-
RemoveAPIRewriter(const CallExpr *C, std::string CalleeName)
313-
: CallExprRewriter(C, CalleeName), IsAssigned(isAssigned(C)), CalleeName(CalleeName) {}
313+
RemoveAPIRewriter(const CallExpr *C, std::string CalleeName,
314+
std::string Message = "")
315+
: CallExprRewriter(C, CalleeName), IsAssigned(isAssigned(C)),
316+
CalleeName(CalleeName), Message(Message) {}
314317

315318
std::optional<std::string> rewrite() override {
316-
std::string Msg = "this call is redundant in SYCL.";
319+
std::string Msg =
320+
Message.empty() ? "this call is redundant in SYCL." : Message;
317321
if (IsAssigned) {
318-
report(Diagnostics::FUNC_CALL_REMOVED_0, false,
319-
CalleeName, Msg);
322+
report(Diagnostics::FUNC_CALL_REMOVED_0, false, CalleeName, Msg);
320323
return std::optional<std::string>("0");
321324
}
322325
report(Diagnostics::FUNC_CALL_REMOVED, false,

clang/lib/DPCT/CallExprRewriterCommon.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,8 +1084,11 @@ inline std::shared_ptr<CallExprRewriterFactoryBase> createToStringExprRewriterFa
10841084
}
10851085

10861086
inline std::shared_ptr<CallExprRewriterFactoryBase>
1087-
createRemoveAPIRewriterFactory(const std::string &SourceName) {
1088-
return std::make_shared<CallExprRewriterFactory<RemoveAPIRewriter>>(SourceName);
1087+
createRemoveAPIRewriterFactory(const std::string &SourceName,
1088+
std::string Message = "") {
1089+
return std::make_shared<
1090+
CallExprRewriterFactory<RemoveAPIRewriter, std::string>>(SourceName,
1091+
Message);
10891092
}
10901093

10911094
/// Create AssignableRewriterFactory key-value pair with inner key-value.
@@ -1846,6 +1849,8 @@ class IsArgumentIntegerType {
18461849
{FuncName, createToStringExprRewriterFactory(FuncName, __VA_ARGS__)},
18471850
#define REMOVE_API_FACTORY_ENTRY(FuncName) \
18481851
{FuncName, createRemoveAPIRewriterFactory(FuncName)},
1852+
#define REMOVE_API_FACTORY_ENTRY_WITH_MSG(FuncName, Msg) \
1853+
{FuncName, createRemoveAPIRewriterFactory(FuncName, Msg)},
18491854
#define CASE_FACTORY_ENTRY(...) \
18501855
createCaseRewriterFactory(__VA_ARGS__),
18511856

clang/lib/DPCT/Diagnostics.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ DEF_WARNING(MODULE_LOAD, 1103, "'%0' should be a dynamic library. The dynamic li
223223
DEF_COMMENT(MODULE_LOAD, 1103, "'{0}' should be a dynamic library. The dynamic library should supply wrapped kernel functions.")
224224
DEF_WARNING(MODULE_LOAD_DATA, 1104, "'%0' should point to a dynamic library loaded in memory. The dynamic library should supply wrapped kernel functions.")
225225
DEF_COMMENT(MODULE_LOAD_DATA, 1104, "'{0}' should point to a dynamic library loaded in memory. The dynamic library should supply wrapped kernel functions.")
226-
DEF_WARNING(RESERVE_FOR_FUTURE, 1105, "")
227-
DEF_COMMENT(RESERVE_FOR_FUTURE, 1105, "")
226+
DEF_WARNING(SUBSEQUENCE_IGNORED, 1105, "The mcg59 random number generator is used. The subsequence argument \"%0\" is ignored. You need to verify the migration.")
227+
DEF_COMMENT(SUBSEQUENCE_IGNORED, 1105, "The mcg59 random number generator is used. The subsequence argument \"{0}\" is ignored. You need to verify the migration.")
228228
DEF_WARNING(EXTENSION_DEVICE_INFO, 1106, "'%0' was migrated with the Intel extensions for device information which may not be supported by all compilers or runtimes. You may need to adjust the code.")
229229
DEF_COMMENT(EXTENSION_DEVICE_INFO, 1106, "'{0}' was migrated with the Intel extensions for device information which may not be supported by all compilers or runtimes. You may need to adjust the code.")

clang/lib/DPCT/MapNames.cpp

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ std::unordered_map<std::string, std::shared_ptr<TypeNameRule>>
3232
MapNames::TypeNamesMap;
3333
std::unordered_map<std::string, std::shared_ptr<ClassFieldRule>>
3434
MapNames::ClassFieldMap;
35+
MapNames::MapTy MapNames::RandomEngineTypeMap;
36+
MapNames::MapTy MapNames::DeviceRandomGeneratorTypeMap;
3537
std::unordered_map<std::string, std::shared_ptr<TypeNameRule>>
3638
MapNames::CuDNNTypeNamesMap;
3739
std::unordered_map<std::string, std::shared_ptr<EnumNameRule>>
@@ -435,6 +437,46 @@ void MapNames::setExplicitNamespaceMap() {
435437
// ...
436438
};
437439

440+
// Host Random Engine Type mapping
441+
RandomEngineTypeMap = {
442+
{"CURAND_RNG_PSEUDO_DEFAULT", getDpctNamespace() + "rng::random_engine_type::mcg59"},
443+
{"CURAND_RNG_PSEUDO_XORWOW", getDpctNamespace() + "rng::random_engine_type::mcg59"},
444+
{"CURAND_RNG_PSEUDO_MRG32K3A", getDpctNamespace() + "rng::random_engine_type::mrg32k3a"},
445+
{"CURAND_RNG_PSEUDO_MTGP32", getDpctNamespace() + "rng::random_engine_type::mt2203"},
446+
{"CURAND_RNG_PSEUDO_MT19937", getDpctNamespace() + "rng::random_engine_type::mt19937"},
447+
{"CURAND_RNG_PSEUDO_PHILOX4_32_10",
448+
getDpctNamespace() + "rng::random_engine_type::philox4x32x10"},
449+
{"CURAND_RNG_QUASI_DEFAULT", getDpctNamespace() + "rng::random_engine_type::sobol"},
450+
{"CURAND_RNG_QUASI_SOBOL32", getDpctNamespace() + "rng::random_engine_type::sobol"},
451+
{"CURAND_RNG_QUASI_SCRAMBLED_SOBOL32",
452+
getDpctNamespace() + "rng::random_engine_type::sobol"},
453+
{"CURAND_RNG_QUASI_SOBOL64", getDpctNamespace() + "rng::random_engine_type::sobol"},
454+
{"CURAND_RNG_QUASI_SCRAMBLED_SOBOL64",
455+
getDpctNamespace() + "rng::random_engine_type::sobol"},
456+
};
457+
458+
// Device Random Generator Type mapping
459+
DeviceRandomGeneratorTypeMap = {
460+
{"curandStateXORWOW_t", getDpctNamespace() + "rng::device::rng_generator<oneapi::"
461+
"mkl::rng::device::mcg59<1>>"},
462+
{"curandStateXORWOW", getDpctNamespace() + "rng::device::rng_generator<oneapi::"
463+
"mkl::rng::device::mcg59<1>>"},
464+
{"curandState_t", getDpctNamespace() + "rng::device::rng_generator<oneapi::mkl::"
465+
"rng::device::mcg59<1>>"},
466+
{"curandState", getDpctNamespace() + "rng::device::rng_generator<oneapi::mkl::"
467+
"rng::device::mcg59<1>>"},
468+
{"curandStatePhilox4_32_10_t",
469+
getDpctNamespace() + "rng::device::rng_generator<oneapi::mkl::rng::device::"
470+
"philox4x32x10<1>>"},
471+
{"curandStatePhilox4_32_10",
472+
getDpctNamespace() + "rng::device::rng_generator<"
473+
"oneapi::mkl::rng::device::philox4x32x10<1>>"},
474+
{"curandStateMRG32k3a_t", getDpctNamespace() + "rng::device::rng_generator<"
475+
"oneapi::mkl::rng::device::mrg32k3a<1>>"},
476+
{"curandStateMRG32k3a", getDpctNamespace() + "rng::device::rng_generator<oneapi::"
477+
"mkl::rng::device::mrg32k3a<1>>"},
478+
};
479+
438480
// CuDNN Type names mapping.
439481
CuDNNTypeNamesMap = {
440482
{"cudnnHandle_t",
@@ -4063,47 +4105,6 @@ const std::map<std::string, MapNames::SOLVERFuncReplInfo>
40634105
"oneapi::mkl::lapack::gesvd")},
40644106
};
40654107

4066-
// Host Random Engine Type mapping
4067-
const MapNames::MapTy MapNames::RandomEngineTypeMap{
4068-
{"CURAND_RNG_PSEUDO_DEFAULT",
4069-
"dpct::rng::random_engine_type::philox4x32x10"},
4070-
{"CURAND_RNG_PSEUDO_XORWOW",
4071-
"dpct::rng::random_engine_type::philox4x32x10"},
4072-
{"CURAND_RNG_PSEUDO_MRG32K3A", "dpct::rng::random_engine_type::mrg32k3a"},
4073-
{"CURAND_RNG_PSEUDO_MTGP32", "dpct::rng::random_engine_type::mt2203"},
4074-
{"CURAND_RNG_PSEUDO_MT19937", "dpct::rng::random_engine_type::mt19937"},
4075-
{"CURAND_RNG_PSEUDO_PHILOX4_32_10",
4076-
"dpct::rng::random_engine_type::philox4x32x10"},
4077-
{"CURAND_RNG_QUASI_DEFAULT", "dpct::rng::random_engine_type::sobol"},
4078-
{"CURAND_RNG_QUASI_SOBOL32", "dpct::rng::random_engine_type::sobol"},
4079-
{"CURAND_RNG_QUASI_SCRAMBLED_SOBOL32",
4080-
"dpct::rng::random_engine_type::sobol"},
4081-
{"CURAND_RNG_QUASI_SOBOL64", "dpct::rng::random_engine_type::sobol"},
4082-
{"CURAND_RNG_QUASI_SCRAMBLED_SOBOL64",
4083-
"dpct::rng::random_engine_type::sobol"},
4084-
};
4085-
4086-
// Device Random Generator Type mapping
4087-
const MapNames::MapTy MapNames::DeviceRandomGeneratorTypeMap{
4088-
{"curandStateXORWOW_t", "dpct::rng::device::rng_generator<oneapi::"
4089-
"mkl::rng::device::philox4x32x10<1>>"},
4090-
{"curandStateXORWOW", "dpct::rng::device::rng_generator<oneapi::"
4091-
"mkl::rng::device::philox4x32x10<1>>"},
4092-
{"curandState_t", "dpct::rng::device::rng_generator<oneapi::mkl::"
4093-
"rng::device::philox4x32x10<1>>"},
4094-
{"curandState", "dpct::rng::device::rng_generator<oneapi::mkl::"
4095-
"rng::device::philox4x32x10<1>>"},
4096-
{"curandStatePhilox4_32_10_t",
4097-
"dpct::rng::device::rng_generator<oneapi::mkl::rng::device::"
4098-
"philox4x32x10<1>>"},
4099-
{"curandStatePhilox4_32_10", "dpct::rng::device::rng_generator<"
4100-
"oneapi::mkl::rng::device::philox4x32x10<1>>"},
4101-
{"curandStateMRG32k3a_t", "dpct::rng::device::rng_generator<"
4102-
"oneapi::mkl::rng::device::mrg32k3a<1>>"},
4103-
{"curandStateMRG32k3a", "dpct::rng::device::rng_generator<oneapi::"
4104-
"mkl::rng::device::mrg32k3a<1>>"},
4105-
};
4106-
41074108
const std::map<std::string, std::string> MapNames::RandomGenerateFuncMap{
41084109
{"curandGenerate", {"generate_uniform_bits"}},
41094110
{"curandGenerateLongLong", {"generate_uniform_bits"}},

clang/lib/DPCT/MapNames.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,10 @@ class MapNames {
354354
SOLVERFuncReplInfoMap;
355355

356356
static MapTy ITFName;
357-
static const MapTy RandomEngineTypeMap;
357+
static MapTy RandomEngineTypeMap;
358358
static const std::map<std::string, std::string> RandomGenerateFuncMap;
359359

360-
static const MapTy DeviceRandomGeneratorTypeMap;
360+
static MapTy DeviceRandomGeneratorTypeMap;
361361

362362
static const std::map<std::string, std::vector<unsigned int>>
363363
FFTPlanAPINeedParenIdxMap;

0 commit comments

Comments
 (0)