[Hotfix][CI/Build][Kernel] CUDA 11.8 does not support layernorm optimizations (vllm-project#3782)
mawong-amd authored Apr 8, 2024
1 parent bc0c019 commit 59a6abf
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions cmake/utils.cmake
@@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

   if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
     list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+  endif()
+  if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
     list(REMOVE_ITEM GPU_FLAGS
         "-D__CUDA_NO_HALF_OPERATORS__"
         "-D__CUDA_NO_HALF_CONVERSIONS__"
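For context on what these defines control, below is a minimal, hypothetical kernel (not part of this commit). PyTorch's default extension flags include -D__CUDA_NO_HALF_OPERATORS__ and -D__CUDA_NO_HALF_CONVERSIONS__, which hide the __half operator overloads and float conversion constructor in cuda_fp16.h; utils.cmake now strips them only when CUDA is at least 12.0, so code like this sketch builds there, while CUDA 11.8 keeps the conservative defaults.

    // half_ops_example.cu -- hypothetical illustration only, not from the commit.
    // With the __CUDA_NO_HALF_* macros defined, cuda_fp16.h omits the __half
    // arithmetic operators and the float->__half conversion constructor, so this
    // kernel fails to compile; with the macros removed (CUDA >= 12.0 above), it builds.
    #include <cuda_fp16.h>

    __global__ void half_axpy(__half* y, const __half* x, float alpha, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        // __half(alpha) uses the float conversion constructor, and the * / +
        // below use the __half operator overloads; both are suppressed by the
        // __CUDA_NO_HALF_CONVERSIONS__ / __CUDA_NO_HALF_OPERATORS__ defines.
        y[i] = __half(alpha) * x[i] + y[i];
      }
    }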
6 changes: 4 additions & 2 deletions csrc/layernorm_kernels.cu
@@ -59,6 +59,8 @@ __global__ void rms_norm_kernel(
 template<typename torch_type>
 struct _typeConvert { static constexpr bool exists = false; };

+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
+// CUDA < 12.0 runs into issues with packed type conversion
 template<>
 struct _typeConvert<c10::Half> {
   static constexpr bool exists = true;
@@ -85,8 +87,8 @@ struct _typeConvert<c10::BFloat16> {
   __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
   __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
 };
-#endif
-
+#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))

 /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
    for appropriate specializations of fused_add_rms_norm_kernel.
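To illustrate how the new guard interacts with the kernels, here is a simplified, hypothetical sketch of the _typeConvert pattern (not the exact vLLM code): the specialization is only compiled on ROCm or CUDA >= 12.0, so on CUDA 11.8 a kernel that tests _typeConvert<T>::exists takes a scalar fallback instead of the packed __half2 conversion path that the in-diff comment flags as problematic.

    // typeconvert_sketch.cu -- simplified illustration, not the vLLM implementation.
    #include <cuda_fp16.h>

    // Fallback: no vectorized conversion support for this type.
    template<typename T>
    struct _typeConvert { static constexpr bool exists = false; };

    #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
    // Packed FP16 conversions are only enabled on ROCm or CUDA >= 12.0,
    // mirroring the guard added in this commit.
    template<>
    struct _typeConvert<__half> {
      static constexpr bool exists = true;
      using packed_type = __half2;
      __device__ static inline __half convert(float x) { return __float2half_rn(x); }
      __device__ static inline __half2 convert(float2 x) { return __float22half2_rn(x); }
    };
    #endif

    // A kernel can branch at compile time: the packed-conversion path is only
    // instantiated when the specialization above was compiled in.
    template<typename T>
    __global__ void scale_to(T* __restrict__ out, const float* __restrict__ in,
                             float factor, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      if constexpr (_typeConvert<T>::exists) {
        out[i] = _typeConvert<T>::convert(in[i] * factor);  // specialized path
      } else {
        out[i] = static_cast<T>(in[i] * factor);            // scalar fallback
      }
    }

The real kernels in this file select between a vectorized and a scalar fused_add_rms_norm_kernel based on the same exists flag, which is why gating the specializations is enough to keep CUDA 11.8 builds working without touching the callers.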
