ROCm Support: HIP kernel Generator #1201

Open · wants to merge 13 commits into base: main
setup.py (26 changes: 17 additions & 9 deletions)

@@ -46,9 +46,11 @@ def read_version(file_path="version.txt"):
     CUDAExtension,
     BuildExtension,
     CUDA_HOME,
+    ROCM_HOME,
     IS_WINDOWS
 )

+IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
+
 def get_extensions():
     debug_mode = os.getenv('DEBUG', '0') == '1'
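As a quick sanity check of the new detection logic, the following standalone snippet (illustrative, not part of the PR) prints what the two guards evaluate to on a given machine:

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

# torch.version.hip is a version string on ROCm builds of PyTorch and None
# otherwise; ROCM_HOME is the detected ROCm install prefix (or None).
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
print(f"hip={torch.version.hip!r} CUDA_HOME={CUDA_HOME!r} ROCM_HOME={ROCM_HOME!r}")
print(f"IS_ROCM={IS_ROCM}")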
@@ -57,11 +59,11 @@ def get_extensions():

     if not torch.cuda.is_available():
         print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
-    if CUDA_HOME is None and torch.cuda.is_available():
-        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+    if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available():
Member (inline review comment):

Some of these if conditions might cause our binary builds to break - check CI, since we run binary builds when people modify setup.py.

Until you get your first PR merged I'll have to manually approve every CI run, so it might be easiest if you make a simple doc-level PR first and I can merge that quickly.

Collaborator (author) reply:

Thanks for approving the CI flow; we will look into this.
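For context on the flagged condition: Python's `and` binds tighter than `or`, so the new check parses as `CUDA_HOME is None or (not IS_ROCM and torch.cuda.is_available())`. A minimal standalone sketch (variable values are assumptions for illustration) shows why this can warn even on a working ROCm machine, which may be part of what the reviewer is flagging:

CUDA_HOME = None       # typical on a ROCm box with no CUDA toolkit installed
IS_ROCM = True
cuda_available = True  # stands in for torch.cuda.is_available()

# `and` binds tighter than `or`, so the PR's condition groups like this:
fires = CUDA_HOME is None or (not IS_ROCM and cuda_available)
print(fires)  # True: the warning would print even though ROCm is usable

# A grouping that likely matches the intent:
fires_intended = CUDA_HOME is None and not IS_ROCM and cuda_available
print(fires_intended)  # False on a working ROCm setup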

+        print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
         print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")

-    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
     extension = CUDAExtension if use_cuda else CppExtension

     if not IS_WINDOWS:
@@ -71,15 +73,14 @@ def get_extensions():
                 "-O3" if not debug_mode else "-O0",
                 "-fdiagnostics-color=always",
             ],
-            "nvcc": [
-                "-O3" if not debug_mode else "-O0",
-                "-t=0",
-            ]
         }
+        if use_cuda and not IS_ROCM:
+            extra_compile_args["nvcc"] = ["-O3" if not debug_mode else "-O0", "-t=0",]

         if debug_mode:
             extra_compile_args["cxx"].append("-g")
-            extra_compile_args["nvcc"].append("-g")
+            if "nvcc" in extra_compile_args:
+                extra_compile_args["nvcc"].append("-g")
             extra_link_args.extend(["-O0", "-g"])

     else:
@@ -107,9 +108,16 @@ def get_extensions():
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))

+    if use_cuda:
+        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout")
+        hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))
+
-    if use_cuda:
+    if not IS_ROCM and use_cuda:
         sources += cuda_sources

+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
+        sources += hip_sources
+
     if len(sources) == 0:
         return None

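Taken together, the gating above reduces to a small pure function; this sketch (hypothetical names, not code from the PR) summarizes which sources each build compiles:

def select_sources(use_cuda, is_rocm, cpp_sources, cuda_sources, hip_sources):
    # CUDA builds compile the whole .cu tree; ROCm builds currently compile
    # only the tensor_core_tiled_layout kernels (see the TODO above).
    sources = list(cpp_sources)
    if use_cuda and not is_rocm:
        sources += cuda_sources
    if use_cuda and is_rocm:
        sources += hip_sources
    return sources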
The remaining hunks apply to the CUDA kernel source for the tensor core tiled layout:
@@ -1,4 +1,4 @@
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
@@ -7,13 +7,24 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>

+#if defined(USE_ROCM)
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+#endif
+
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
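divUp is integer ceiling division: the number of b-sized blocks needed to cover a elements. A one-line Python equivalent (illustrative only):

def div_up(a: int, b: int) -> int:
    # Matches the kernel's a / b + (a % b != 0) for non-negative integers
    return a // b + (a % b != 0)

assert div_up(10, 4) == 3  # three blocks of four cover ten elements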

+#if defined(USE_ROCM)
+constexpr int32_t kWarpSize = 64;
+#else
 constexpr int32_t kWarpSize = 32;
+#endif
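The 64-lane value reflects the wavefront size of AMD CDNA GPUs, versus 32-thread warps on NVIDIA hardware. A hedged runtime probe (the warp_size property may not be exposed by every PyTorch build, hence the getattr fallback):

import torch

if torch.cuda.is_available():  # also true on ROCm builds of PyTorch
    props = torch.cuda.get_device_properties(0)
    # 32 on NVIDIA GPUs, 64 on AMD CDNA GPUs (wavefront size)
    print(getattr(props, "warp_size", "warp_size not exposed by this build"))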

//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;

   // First, we extract the i4s and construct an intermediate fp16 number.
+#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
+  // AMD MI300X ISA that performs two bitwise operations in a single instruction:
+  // v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
+  // - First ANDs `i4s` with `MASK` (0x000f000f) to extract the 4-bit values
+  // - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
+#if defined(USE_ROCM)
+  asm volatile("v_and_or_b32 %0, %1, %2, %3"
+               : "=v"(h[0])
+               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                : "=r"(h[0])
                : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif

 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
+    // (i4s & 0x000f000f) | 0x43004300
+#if defined(USE_ROCM)
+    asm volatile("v_and_or_b32 %0, %1, %2, %3"
+                 : "=v"(h[ii])
+                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
     asm volatile(
         "lop3.b32 %0, %1, %2, %3, %4;\n"
         : "=r"(h[ii])
         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
   }

   // This is the BF16 {-136, -136} represented as an integer.
+#if defined(USE_ROCM)
+#if ROCM_VERSION >= 60200
+  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
+  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
+#else
+  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
+  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
+#endif
+#else
   static constexpr uint32_t BF16_BIAS = 0xC308C308;
   static constexpr uint32_t BF16_ONE = 0x3F803F80;
+#endif

   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
+#if defined(USE_ROCM)
+    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
+#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
+#endif
   }

   return result;
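The magic constants are worth a gloss: 0x4300 is the bf16 encoding of 128.0, so OR-ing a 4-bit value into the low mantissa bits yields 128 + v, and the final fma with BF16_ONE (1.0) and BF16_BIAS (-136.0) maps the unsigned nibble to the signed value v - 8. This standalone Python check (illustrative; it relies on bf16 being the high half of an IEEE-754 float32) verifies the arithmetic on the CPU:

import struct

def bf16_bits_to_float(bits: int) -> float:
    # bf16 is the top 16 bits of an IEEE-754 float32
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

MAGIC = 0x4300                     # bf16 encoding of 128.0
BIAS = bf16_bits_to_float(0xC308)  # -136.0, i.e. -(128 + 8)

for nibble in range(16):
    encoded = (nibble & 0xF) | MAGIC      # what the AND/OR asm produces
    value = bf16_bits_to_float(encoded)   # equals 128 + nibble exactly
    dequant = value * 1.0 + BIAS          # the fma with BF16_ONE and BF16_BIAS
    assert dequant == nibble - 8          # uint4 0..15 maps to int4 -8..7
print("ok:", [n - 8 for n in range(16)])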
@@ -123,11 +167,16 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
-  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
-
-  // Vectorize scales and zeros
-  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
-  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+
+  if (scales_and_zeros) {
+    const auto& sz = *scales_and_zeros;
+    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
+
+    scale2 = __bfloat162bfloat162(pSZ[0]);
+    zero2 = __bfloat162bfloat162(pSZ[1]);
+  }

 #pragma unroll
   for (int i = 0; i < 4; i++) {
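For reference, the per-group math the kernel implements is ordinary affine dequantization: each 4-bit value v becomes (v - 8) * scale + zero, with one {scale, zero} pair per quantization group. A compact Python sketch (the flat layout and names are illustrative, not the packed tensor-core tile format):

def dequant_int4_groupwise(nibbles, scales, zeros, group_size):
    # nibbles: flat list of uint4 values in [0, 15]
    # scales/zeros: one {scale, zero} pair per group of group_size elements
    out = []
    for i, v in enumerate(nibbles):
        g = i // group_size
        out.append((v - 8) * scales[g] + zeros[g])
    return out

# One group of eight values with scale 0.5 and zero 1.0
print(dequant_int4_groupwise([0, 4, 8, 12, 15, 7, 9, 3], [0.5], [1.0], 8))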