Skip to content

Commit

Permalink
Fix FillRandomNormal for odd lengths (#4338)
Browse files Browse the repository at this point in the history
In both CUDA and HIP, the normal distribution generator we use requires
the length to be even. This PR handles the odd-length case by generating
the last element on the host when the total length is odd.
  • Loading branch information
WeiqunZhang authored Feb 19, 2025
1 parent b364bec commit 28a1e19
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
24 changes: 16 additions & 8 deletions Src/Base/AMReX_GpuError.H
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ namespace Gpu {
amrex::Abort(errStr); \
}}

// Checks the result of a cuRAND call: evaluates `x` once and, if it is not
// CURAND_STATUS_SUCCESS, calls amrex::Abort with a message naming the file and
// line of the invocation. The message does not include the status code itself.
// Wrapped in do { ... } while(0), so the invocation must end with a semicolon.
#define AMREX_CURAND_SAFE_CALL(x) do { if((x)!=CURAND_STATUS_SUCCESS) { \
std::string errStr(std::string("CURAND error in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); }} while(0)
// Checks the result of a cuRAND call: evaluates `call` once, stores the
// returned curandStatus_t in a block-local variable, and if it is not
// CURAND_STATUS_SUCCESS calls amrex::Abort with a message that contains the
// numeric status code plus the file and line of the invocation.
// NOTE: this expands to a bare { ... } block (not do/while(0)), so a trailing
// semicolon after the invocation is a separate empty statement; take care when
// using it as the unbraced body of an `if` that is followed by `else`.
#define AMREX_CURAND_SAFE_CALL(call) { \
curandStatus_t amrex_i_err = call; \
if (CURAND_STATUS_SUCCESS != amrex_i_err) { \
std::string errStr(std::string("CURAND error ") + std::to_string(amrex_i_err) \
+ std::string(" in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); \
}}

#define AMREX_CUFFT_SAFE_CALL(call) { \
cufftResult_t amrex_i_err = call; \
Expand All @@ -106,10 +110,14 @@ namespace Gpu {
amrex::Abort(errStr); \
}}

// Checks the result of a hipRAND call: evaluates `x` once and, if it is not
// HIPRAND_STATUS_SUCCESS, calls amrex::Abort with a message naming the file and
// line of the invocation. The message does not include the status code itself.
// Wrapped in do { ... } while(0), so the invocation must end with a semicolon.
#define AMREX_HIPRAND_SAFE_CALL(x) do { if((x)!=HIPRAND_STATUS_SUCCESS) { \
std::string errStr(std::string("HIPRAND error in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); }} while(0)
// Checks the result of a hipRAND call: evaluates `call` once, stores the
// returned hiprandStatus_t in a block-local variable, and if it is not
// HIPRAND_STATUS_SUCCESS calls amrex::Abort with a message that contains the
// numeric status code plus the file and line of the invocation.
// NOTE: this expands to a bare { ... } block (not do/while(0)), so a trailing
// semicolon after the invocation is a separate empty statement; take care when
// using it as the unbraced body of an `if` that is followed by `else`.
#define AMREX_HIPRAND_SAFE_CALL(call) { \
hiprandStatus_t amrex_i_err = call; \
if (HIPRAND_STATUS_SUCCESS != amrex_i_err) { \
std::string errStr(std::string("HIPRAND error ") + std::to_string(amrex_i_err) \
+ std::string(" in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); \
}}

#define AMREX_ROCFFT_SAFE_CALL(call) { \
auto amrex_i_err = call; \
Expand Down
31 changes: 25 additions & 6 deletions Src/Base/AMReX_Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,34 @@ void FillRandom (Real* p, Long N)

void FillRandomNormal (Real* p, Long N, Real mean, Real stddev)
{
if (N <= 0) { return; }

#if defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)
if (N == 1) {
auto r = amrex::RandomNormal(mean, stddev);
Gpu::htod_memcpy_async(p, &r, sizeof(Real));
Gpu::streamSynchronize();
return;
}
// The length passed to [cu|hip]randGenerateNormal must be even
Long Neven = (N%2 == 0) ? N : N-1;
#endif

#if defined(AMREX_USE_CUDA)

# ifdef BL_USE_FLOAT
AMREX_CURAND_SAFE_CALL(curandGenerateNormal(gpu_rand_generator, p, N, mean, stddev));
AMREX_CURAND_SAFE_CALL(curandGenerateNormal(gpu_rand_generator, p, Neven, mean, stddev));
# else
AMREX_CURAND_SAFE_CALL(curandGenerateNormalDouble(gpu_rand_generator, p, N, mean, stddev));
AMREX_CURAND_SAFE_CALL(curandGenerateNormalDouble(gpu_rand_generator, p, Neven, mean, stddev));
# endif
Gpu::synchronize();

#elif defined(AMREX_USE_HIP)

# ifdef BL_USE_FLOAT
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormal(gpu_rand_generator, p, N, mean, stddev));
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormal(gpu_rand_generator, p, Neven, mean, stddev));
# else
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormalDouble(gpu_rand_generator, p, N, mean, stddev));
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormalDouble(gpu_rand_generator, p, Neven, mean, stddev));
# endif
Gpu::synchronize();

#elif defined(AMREX_USE_SYCL)

Expand All @@ -318,6 +329,14 @@ void FillRandomNormal (Real* p, Long N, Real mean, Real stddev)
}

#endif

#if defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)
if (Neven < N) {
auto r = amrex::RandomNormal(mean, stddev);
Gpu::htod_memcpy_async(p+(N-1), &r, sizeof(Real));
}
Gpu::synchronize();
#endif
}

} // namespace amrex
Expand Down

0 comments on commit 28a1e19

Please sign in to comment.