Refactor C++ ops to take output tensor.
tgale96 committed Sep 25, 2023
1 parent 88b3cd1 commit 28257c3
Showing 6 changed files with 85 additions and 45 deletions.
71 changes: 42 additions & 29 deletions csrc/grouped_gemm.cu
@@ -222,11 +222,11 @@ void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
CUBLAS_GEMM_DEFAULT));
}

torch::Tensor CublasGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_b) {
void CublasGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_b) {
int64_t bs = batch_sizes.size(0), k = a.size(1);
int64_t n = trans_b ? b.size(1) : b.size(2);
int64_t b_rows = b.size(1), b_cols = b.size(2);
@@ -242,13 +242,12 @@ torch::Tensor CublasGroupedGemm(torch::Tensor a,
b_ptr += b_rows * b_cols;
c_ptr += m * n;
}
return c;
}

torch::Tensor CublasGroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
void CublasGroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
@@ -262,12 +261,12 @@ torch::Tensor CublasGroupedGemmVariableK(torch::Tensor a,
b_ptr += k * n;
c_ptr += m * n;
}
return c;
}

torch::Tensor GroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor batch_sizes) {
void GroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
// We expected a CUDA tensor with two dimensions and shape
// (tokens, hidden_out) for 'b'.
TORCH_CHECK(b.is_cuda());
@@ -281,22 +280,27 @@ torch::Tensor GroupedGemmVariableK(torch::Tensor a,
// Validate that we have the same contraction dimension.
TORCH_CHECK(tokens == b.size(0));

// Allocate the output.
auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(a.device());
torch::Tensor c = torch::empty({num_experts, m, n}, options);
// Validate the output shape.
TORCH_CHECK(c.is_cuda());
TORCH_CHECK(c.ndimension() == 3);
TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.size(0) == num_experts);
TORCH_CHECK(c.size(1) == m);
TORCH_CHECK(c.size(2) == n);

// Run the computation.
return CublasGroupedGemmVariableK(a, b, c, batch_sizes);
CublasGroupedGemmVariableK(a, b, c, batch_sizes);
}

// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed sized batches.
//
// TODO(tgale): Validate alignment is true for every batch element.
torch::Tensor GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b) {
void GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b) {
// NOTE: We only support 'trans_a' or 'trans_b', not both.
TORCH_CHECK(!(trans_a && trans_b));

@@ -312,7 +316,10 @@ torch::Tensor GroupedGemm(torch::Tensor a,
TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

// Defer to the variable 'k' helper for the rest of the op.
if (trans_a) return GroupedGemmVariableK(a, b, batch_sizes);
if (trans_a) {
GroupedGemmVariableK(a, b, c, batch_sizes);
return;
}

// We expected a CUDA tensor with three dimensions and shape
// (num_experts, hidden_in, hidden_out) for 'b'.
@@ -329,23 +336,29 @@ torch::Tensor GroupedGemm(torch::Tensor a,
// Validate that we have one size per expert.
TORCH_CHECK(batch_sizes.size(0) == num_experts);

// Allocate the output.
auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(a.device());
torch::Tensor c = torch::empty({tokens, hidden_out}, options);
// Validate the output shape.
TORCH_CHECK(c.is_cuda());
TORCH_CHECK(c.ndimension() == 2);
TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.size(0) == tokens);
TORCH_CHECK(c.size(1) == hidden_out);

// NOTE: We support transposition through the 'trans_b' flag.
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(b.is_contiguous());

// NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
#if GROUPED_GEMM_DEVICE_CAPABILITY >= 90
return CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
return;
#elif GROUPED_GEMM_DEVICE_CAPABILITY >= 80
// TODO(tgale): Support transposition with CUTLASS grouped GEMM.
if (trans_b) {
return CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
return;
}
return CutlassGroupedGemm(a, b, c, batch_sizes);
CutlassGroupedGemm(a, b, c, batch_sizes);
return;
#else
#error "Unsupported compute capability " GROUPED_GEMM_STRINGIFY(GROUPED_GEMM_DEVICE_CAPABILITY)
#endif
9 changes: 5 additions & 4 deletions csrc/grouped_gemm.h
@@ -2,9 +2,10 @@

namespace grouped_gemm {

torch::Tensor GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b);
void GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b);

} // namespace grouped_gemm
2 changes: 1 addition & 1 deletion csrc/ops.cu
@@ -5,7 +5,7 @@
namespace grouped_gemm {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grouped_gemm", &GroupedGemm, "Grouped GEMM.");
m.def("gmm", &GroupedGemm, "Grouped GEMM.");
}

} // namespace grouped_gemm
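
The pybind entry point is renamed from "grouped_gemm" to "gmm" and, per the refactor, now takes the output tensor explicitly. A hypothetical direct call against the compiled extension (not from the commit; shapes, dtypes, and the CPU int64 batch_sizes are assumptions based on the checks in csrc/grouped_gemm.cu):

import torch
import grouped_gemm_backend as backend  # the compiled extension built from csrc/

a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)      # (tokens, hidden_in)
b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)  # (num_experts, hidden_in, hidden_out)
c = torch.empty(8, 32, device="cuda", dtype=torch.bfloat16)      # (tokens, hidden_out), caller-allocated
batch_sizes = torch.tensor([4, 4], dtype=torch.int64)            # one size per expert; assumed host-resident

backend.gmm(a, b, c, batch_sizes, False, False)  # trans_a=False, trans_b=False; result is written into c
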
2 changes: 2 additions & 0 deletions grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
import grouped_gemm.ops
import grouped_gemm.backend
29 changes: 29 additions & 0 deletions grouped_gemm/backend.py
@@ -0,0 +1,29 @@
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# TODO(tgale): Wrap this in a try-block with better
# error message and instructions for building the
# c++ operations.
import grouped_gemm_backend as backend


def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
assert not (trans_a and trans_b)
assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
assert a.ndim == 2, "Expected 2d tensor for 'a'"
assert b.ndim == (2 if trans_a else 3)

shape = (
(batch_sizes.shape[0], a.shape[1], b.shape[1])
if trans_a else
(a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
)
return torch.empty(*shape, device=a.device, dtype=a.dtype)

def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
if c is None:
c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
return c
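
For reference, a hypothetical use of the new Python wrapper (not part of the commit); the shapes follow _allocate_output above and the checks in csrc/grouped_gemm.cu:

import torch
from grouped_gemm import backend

num_experts, tokens, hidden_in, hidden_out = 4, 1024, 512, 256
a = torch.randn(tokens, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([256] * num_experts, dtype=torch.int64)  # assumed int64, summing to `tokens`

# Let the wrapper allocate the output (the c=None path) ...
out = backend.gmm(a, b, batch_sizes)

# ... or hand in a pre-allocated buffer, which the refactored C++ op now fills in place.
c = torch.empty(tokens, hidden_out, device="cuda", dtype=torch.bfloat16)
out = backend.gmm(a, b, batch_sizes, c=c)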

17 changes: 6 additions & 11 deletions grouped_gemm/ops.py
@@ -1,20 +1,14 @@
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
from grouped_gemm import backend
import torch

# TODO(tgale): Wrap this in a try-block with better
# error message and instructions for building the
# c++ operations.
import grouped_gemm_backend as backend


class GroupedGemm(torch.autograd.Function):

@staticmethod
def forward(ctx, a, b, batch_sizes, trans_b):
ctx.save_for_backward(a, b, batch_sizes)
ctx.trans_b = trans_b
return backend.grouped_gemm(a, b, batch_sizes, False, trans_b)
return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)

@staticmethod
def backward(ctx, grad):
@@ -24,13 +18,14 @@ def backward(ctx, grad):

agrad = None
if ctx.needs_input_grad[0]:
agrad = backend.grouped_gemm(
grad, b, batch_sizes, False, not trans_b)
agrad = backend.gmm(
grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)

bgrad = None
if ctx.needs_input_grad[1]:
lhs, rhs = (grad, a) if trans_b else (a, grad)
bgrad = backend.grouped_gemm(lhs, rhs, batch_sizes, True, False)
bgrad = backend.gmm(
lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
return agrad, bgrad, None, None
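
The remainder of ops.py is unchanged and not shown, but the autograd Function above can already be exercised directly. A sketch with illustrative shapes (not from the commit); both gradients are routed back through backend.gmm by the backward() above:

import torch
from grouped_gemm.ops import GroupedGemm

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(4, 512, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
batch_sizes = torch.tensor([256, 256, 256, 256], dtype=torch.int64)

out = GroupedGemm.apply(a, b, batch_sizes, False)  # trans_b=False
out.sum().backward()                               # populates a.grad and b.grad via backend.gmm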

