Skip to content

Commit

Permalink
FP8 splitgemm user defined triton kernel (#263)
Browse files Browse the repository at this point in the history
* FP8 splitgemm user defined triton kernel

* yolo

* Trigger CI

* yolo

* yolo

* yolo

* yolo

* Update test_fp8.py
  • Loading branch information
msaroufim authored May 23, 2024
1 parent 5e28109 commit cb8b651
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 0 deletions.
45 changes: 45 additions & 0 deletions test/dtypes/test_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import unittest
import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
try:
from torchao.prototype.fp8 import gemm_split_k
triton_available = True
except ImportError:
triton_available = False

@unittest.skipIf(not triton_available, "Triton is required but not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8Gemm(TestCase):
# @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_gemm_split_k(self):
m, n, k = 256, 256, 512

a = torch.randn((m, k), dtype=torch.float16, device="cuda")
b = torch.randn((k, n), dtype=torch.float16, device="cuda")
c = gemm_split_k(a, b)
c_expected = torch.matmul(a, b)
assert torch.allclose(c, c_expected, atol=0.07) # less than this and the accuracy check fails

# https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html
@unittest.skip("fp8 kernel compilation does not work on a10g")
def test_user_defined_triton_function(self):
import torch._inductor.config
torch._inductor.config.force_disable_caches = True
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
m, n, k = 256, 256, 512

a = torch.randn((m, k), dtype=torch.float16, device="cuda")
b = torch.randn((k, n), dtype=torch.float16, device="cuda")
compiled_function = torch.compile(gemm_split_k, fullgraph=True)(a,b)



if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torchao/prototype/fp8/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .splitk_gemm import gemm_split_k
119 changes: 119 additions & 0 deletions torchao/prototype/fp8/splitk_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Code from https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/fp8/splitk_gemm_fp8.py
import torch
import triton
import triton.language as tl

@triton.jit
def grouped_launch(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):

grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)

width = group_m * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * group_m, group_m)

pid_m = group_id * group_m + (pid % group_size)
pid_n = (pid % width) // group_size

return pid_m, pid_n


@triton.jit()
def col_major(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr):

grid_m = tl.cdiv(m, block_m)

pid_m = pid % grid_m
pid_n = pid // grid_m

return pid_m, pid_n


@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
m, n, k,
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
split_k: tl.constexpr, group_m: tl.constexpr):

pid = tl.program_id(0)
pid_k = tl.program_id(1)
grid_k = tl.cdiv(k, block_k*split_k)

pid_m, pid_n = grouped_launch(pid,
m, n,
block_m, block_n, group_m)

offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = pid_k*block_k + tl.arange(0, block_k)

offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k):

k_remaining = k - k_ * (block_k * split_k)

a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

acc = tl.dot(a, b, acc, out_dtype=tl.float32)

a_ptrs += block_k * split_k * stride_ak
b_ptrs += block_k * split_k * stride_bk

acc = acc.to(tl.float16)

offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)

c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]

tl.atomic_add(c_ptrs, acc, mask=mask)

def gemm_split_k(a, b):

m, k = a.shape
_, n = b.shape

# Need to change these otherwise was getting
# triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 393216, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
# TODO: Should we tune this differently for different hardware?
block_m = 32
block_n = 32
block_k = 256
num_stages = 2
num_warps = 4
split_k = 4
group_m = 8

total_blocks_m = triton.cdiv(m, block_m)
total_blocks_n = triton.cdiv(n, block_n)
total_programs_mn = total_blocks_m * total_blocks_n
total_programs_k = split_k

grid = (total_programs_mn, total_programs_k)

c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
gemm_split_k_kernel[grid](a, b, c,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
m, n, k,
block_m, block_n, block_k,
split_k, group_m, num_stages=num_stages, num_warps=num_warps)

return c
1 change: 1 addition & 0 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def get_model_size_in_bytes(model):
s += b.nelement() * b.element_size()
return s

# TODO: quantization namespace is not the right place ot have this
if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
TORCH_VERSION_AFTER_2_4 = True
else:
Expand Down

0 comments on commit cb8b651

Please sign in to comment.