[ROCm] Changes not to rely on CUDA_VERSION or HIP_VERSION (pytorch#65610)

Summary:
Pull Request resolved: pytorch#65610

- Replace HIP_PLATFORM_HCC with USE_ROCM.
- Don't rely on CUDA_VERSION or HIP_VERSION; use USE_ROCM and ROCM_VERSION instead (a guard sketch follows this list).

- In the next PR:
   - Remove the mapping from CUDA_VERSION to HIP_VERSION, and from CUDA to HIP, in hipify.
   - HIP_PLATFORM_HCC is deprecated, so add HIP_PLATFORM_AMD to support HIP host code compilation with gcc.
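
The rest of the commit applies those two rules across ATen. Below is a minimal sketch of the resulting guard style (illustrative only: the SKETCH_ macro and the placeholder functions are invented for this example; it assumes only that CUDA builds define CUDA_VERSION while ROCm builds define USE_ROCM and ROCM_VERSION):

// guard_sketch.h -- minimal sketch of the guard style this commit adopts.
// Illustrative only; the branch bodies are placeholders, not PyTorch code.

// (1) Only compare CUDA_VERSION after checking that it is defined. Inside
//     an #if expression an undefined identifier evaluates to 0, so a bare
//     "CUDA_VERSION <= 10100" could select the old-CUDA workaround in a
//     build where CUDA_VERSION is simply not set.
#if defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10100
#  define SKETCH_UNUSED_WORKAROUND                 /* old-CUDA workaround */
#else
#  define SKETCH_UNUSED_WORKAROUND [[maybe_unused]]
#endif

// (2) Gate ROCm-only code on USE_ROCM / ROCM_VERSION rather than on
//     __HIP_PLATFORM_HCC__ / HIP_VERSION / TORCH_HIP_VERSION. The 21000
//     threshold mirrors the ROCM_VERSION checks in the CUDABlas hunks.
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
inline void complex_blas_supported() {}            // placeholder body
#endif

// (3) Select the backend path with USE_ROCM instead of a HIP platform macro.
#if defined(USE_ROCM)
inline const char* backend_name() { return "rocBLAS"; }
#else
inline const char* backend_name() { return "cuBLAS"; }
#endif

Writing the guards this way keeps the conditions meaningful on both toolchains, which is what allows the follow-up PR to drop hipify's CUDA_VERSION-to-HIP_VERSION rewrite.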

cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd

Reviewed By: jbschlosser

Differential Revision: D30909053

Pulled By: ezyang

fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
pruthvistony authored and facebook-github-bot committed Sep 29, 2021
1 parent 9b40eaa commit 085e2f7
Showing 131 changed files with 415 additions and 398 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/Dispatch.h
@@ -76,11 +76,11 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused
// attribute in the type aliasing context. Keep name long and verbose to avoid
// macro collisions.
#if defined(__CUDACC__) && CUDA_VERSION <= 10100
#if defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10100
#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND
#else
#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
#endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
#endif // defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10100

#if defined __cpp_if_constexpr
#define AT_QINT_PRIVATE_CASE_TYPE( \
2 changes: 1 addition & 1 deletion aten/src/ATen/core/Array.h
@@ -17,7 +17,7 @@ struct Array {
C10_HOST_DEVICE T& operator[](int i) {
return data[i];
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_HOST_DEVICE Array() = default;
C10_HOST_DEVICE Array(const Array&) = default;
C10_HOST_DEVICE Array& operator=(const Array&) = default;
10 changes: 5 additions & 5 deletions aten/src/ATen/cuda/Atomic.cuh
@@ -167,7 +167,7 @@ static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
}

static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
__atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
@@ -179,7 +179,7 @@ static inline __device__ void gpuAtomicAdd(bool *address, bool val) {
}

static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
return AtomicFPOp<at::Half>()(address, val,
[](at::Half hsum, at::Half val) {
return hsum + val;
@@ -196,7 +196,7 @@ static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BF
});
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
#if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// from CUDA C Programming Guide
static inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
@@ -212,7 +212,7 @@ static inline __device__ double atomicAdd(double* address, double val)
return __double_as_longlong(val + __longlong_as_double(assumed));
});
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__) && (defined(CUDA_VERSION) && CUDA_VERSION < 8000))

/* Note [hip-clang differences to hcc]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -298,7 +298,7 @@ static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BF
static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }

/* Special case fp32 atomic. */
#if defined(__HIP_PLATFORM_HCC__) && defined(__gfx908__)
#if defined(USE_ROCM) && defined(__gfx908__)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -274,7 +274,7 @@ template <typename Op,
typename IndexType,
int ADims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
@@ -360,7 +360,7 @@ template <typename Op,
int step,
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
44 changes: 22 additions & 22 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -133,7 +133,7 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
@@ -271,7 +271,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
@@ -284,7 +284,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000

cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
@@ -308,11 +308,11 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
#endif // USE_ROCM
}

#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
@@ -332,7 +332,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
#elif defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
b, rocblas_datatype_bf16_r, (int)ldb, strideb,
@@ -344,7 +344,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
@@ -372,7 +372,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
@@ -389,7 +389,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
}
#endif

#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
@@ -417,7 +417,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
handle,
opa,
@@ -450,7 +450,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
@@ -475,7 +475,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
} else {
TORCH_CUDABLAS_CHECK(cublasSgemmEx(
handle,
@@ -499,7 +499,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#endif
}

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
@@ -569,7 +569,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000

template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
@@ -702,7 +702,7 @@ void trsmBatched<c10::complex<double>>(
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)

#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
@@ -718,7 +718,7 @@ void trsmBatched<c10::complex<double>>(
}
#endif

#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
@@ -797,7 +797,7 @@ void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}

#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
@@ -838,7 +838,7 @@ void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {

template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
#if CUDA_VERSION >= 8000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 8000
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@@ -851,7 +851,7 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
result,
CUDA_R_16F,
CUDA_R_32F));
#elif TORCH_HIP_VERSION >= 210
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_hdot(
handle,
n,
@@ -867,7 +867,7 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {

template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@@ -880,7 +880,7 @@ void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
result,
CUDA_R_16BF,
CUDA_R_32F));
#elif TORCH_HIP_VERSION >= 210
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_bfdot(
handle,
n,
12 changes: 6 additions & 6 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -54,17 +54,17 @@ template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif
@@ -90,7 +90,7 @@ template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
@@ -152,15 +152,15 @@ template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAEvent.h
@@ -32,7 +32,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {

CUDAEvent(
DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
device_index_ = device_index;
CUDAGuard guard(device_index_);

@@ -148,7 +148,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {

// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t * handle) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
if (!is_created_) {
// this CUDAEvent object was initially constructed from flags but event_
// is not created yet.