diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index f40a7db..4d8ea75 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -222,11 +222,11 @@ void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
 		     CUBLAS_GEMM_DEFAULT));
 }
 
-torch::Tensor CublasGroupedGemm(torch::Tensor a,
-                                torch::Tensor b,
-                                torch::Tensor c,
-                                torch::Tensor batch_sizes,
-                                bool trans_b) {
+void CublasGroupedGemm(torch::Tensor a,
+                       torch::Tensor b,
+                       torch::Tensor c,
+                       torch::Tensor batch_sizes,
+                       bool trans_b) {
   int64_t bs = batch_sizes.size(0), k = a.size(1);
   int64_t n = trans_b ? b.size(1) : b.size(2);
   int64_t b_rows = b.size(1), b_cols = b.size(2);
@@ -242,13 +242,12 @@ torch::Tensor CublasGroupedGemm(torch::Tensor a,
     b_ptr += b_rows * b_cols;
     c_ptr += m * n;
   }
-  return c;
 }
 
-torch::Tensor CublasGroupedGemmVariableK(torch::Tensor a,
-                                         torch::Tensor b,
-                                         torch::Tensor c,
-                                         torch::Tensor batch_sizes) {
+void CublasGroupedGemmVariableK(torch::Tensor a,
+                                torch::Tensor b,
+                                torch::Tensor c,
+                                torch::Tensor batch_sizes) {
   int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
   c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
   c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
@@ -262,12 +261,12 @@ torch::Tensor CublasGroupedGemmVariableK(torch::Tensor a,
     b_ptr += k * n;
     c_ptr += m * n;
   }
-  return c;
 }
 
-torch::Tensor GroupedGemmVariableK(torch::Tensor a,
-                                   torch::Tensor b,
-                                   torch::Tensor batch_sizes) {
+void GroupedGemmVariableK(torch::Tensor a,
+                          torch::Tensor b,
+                          torch::Tensor c,
+                          torch::Tensor batch_sizes) {
   // We expected a CUDA tensor with two dimensions and shape
   // (tokens, hidden_out) for 'b'.
   TORCH_CHECK(b.is_cuda());
@@ -281,22 +280,27 @@ torch::Tensor GroupedGemmVariableK(torch::Tensor a,
   // Validate that we have the same contraction dimension.
   TORCH_CHECK(tokens == b.size(0));
 
-  // Allocate the output.
-  auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(a.device());
-  torch::Tensor c = torch::empty({num_experts, m, n}, options);
+  // Validate the output shape.
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(c.ndimension() == 3);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.size(0) == num_experts);
+  TORCH_CHECK(c.size(1) == m);
+  TORCH_CHECK(c.size(2) == n);
 
   // Run the computation.
-  return CublasGroupedGemmVariableK(a, b, c, batch_sizes);
+  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
 }
 
 // NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
 // assumed to be batched with fixed sized batches.
 //
 // TODO(tgale): Validate alignment is true for every batch element.
-torch::Tensor GroupedGemm(torch::Tensor a,
-                          torch::Tensor b,
-                          torch::Tensor batch_sizes,
-                          bool trans_a, bool trans_b) {
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b) {
   // NOTE: We only support 'trans_a' or 'trans_b', not both.
   TORCH_CHECK(!(trans_a && trans_b));
 
@@ -312,7 +316,10 @@ torch::Tensor GroupedGemm(torch::Tensor a,
   TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
 
   // Defer to the variable 'k' helper for the rest of the op.
-  if (trans_a) return GroupedGemmVariableK(a, b, batch_sizes);
+  if (trans_a) {
+    GroupedGemmVariableK(a, b, c, batch_sizes);
+    return;
+  }
 
   // We expected a CUDA tensor with three dimensions and shape
   // (num_experts, hidden_in, hidden_out) for 'b'.
@@ -329,9 +336,12 @@ torch::Tensor GroupedGemm(torch::Tensor a,
   // Validate that we have one size per expert.
   TORCH_CHECK(batch_sizes.size(0) == num_experts);
 
-  // Allocate the output.
-  auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(a.device());
-  torch::Tensor c = torch::empty({tokens, hidden_out}, options);
+  // Validate the output shape.
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(c.ndimension() == 2);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.size(0) == tokens);
+  TORCH_CHECK(c.size(1) == hidden_out);
 
   // NOTE: We support transposition through the 'trans_b' flag.
   TORCH_CHECK(a.is_contiguous());
@@ -339,13 +349,16 @@ torch::Tensor GroupedGemm(torch::Tensor a,
 
   // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
 #if GROUPED_GEMM_DEVICE_CAPABILITY >= 90
-  return CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+  return;
 #elif GROUPED_GEMM_DEVICE_CAPABILITY >= 80
   // TODO(tgale): Support transposition with CUTLASS grouped GEMM.
   if (trans_b) {
-    return CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+    CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+    return;
   }
-  return CutlassGroupedGemm(a, b, c, batch_sizes);
+  CutlassGroupedGemm(a, b, c, batch_sizes);
+  return;
 #else
 #error "Unsupported compute capability " GROUPED_GEMM_STRINGIFY(GROUPED_GEMM_DEVICE_CAPABILITY)
 #endif
diff --git a/csrc/grouped_gemm.h b/csrc/grouped_gemm.h
index 8b4ae90..ae36a62 100644
--- a/csrc/grouped_gemm.h
+++ b/csrc/grouped_gemm.h
@@ -2,9 +2,10 @@
 
 namespace grouped_gemm {
 
-torch::Tensor GroupedGemm(torch::Tensor a,
-                          torch::Tensor b,
-                          torch::Tensor batch_sizes,
-                          bool trans_a, bool trans_b);
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b);
 
 }  // namespace grouped_gemm
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 7ad7e5d..8ea2a25 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -5,7 +5,7 @@
 namespace grouped_gemm {
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("grouped_gemm", &GroupedGemm, "Grouped GEMM.");
+  m.def("gmm", &GroupedGemm, "Grouped GEMM.");
 }
 
 }  // namespace grouped_gemm
diff --git a/grouped_gemm/__init__.py b/grouped_gemm/__init__.py
index e69de29..a6d34ee 100644
--- a/grouped_gemm/__init__.py
+++ b/grouped_gemm/__init__.py
@@ -0,0 +1,2 @@
+import grouped_gemm.ops
+import grouped_gemm.backend
diff --git a/grouped_gemm/backend.py b/grouped_gemm/backend.py
new file mode 100644
index 0000000..32c99c1
--- /dev/null
+++ b/grouped_gemm/backend.py
@@ -0,0 +1,29 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# TODO(tgale): Wrap this in a try-block with better
+# error message and instructions for building the
+# c++ operations.
+import grouped_gemm_backend as backend
+
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
+
diff --git a/grouped_gemm/ops.py b/grouped_gemm/ops.py
index 2720360..1442a84 100644
--- a/grouped_gemm/ops.py
+++ b/grouped_gemm/ops.py
@@ -1,12 +1,6 @@
-# NOTE: Torch needs to be imported before the custom
-# extensions. Otherwise libc10.so cannot be found.
+from grouped_gemm import backend
 import torch
 
-# TODO(tgale): Wrap this in a try-block with better
-# error message and instructions for building the
-# c++ operations.
-import grouped_gemm_backend as backend
-
 
 class GroupedGemm(torch.autograd.Function):
 
@@ -14,7 +8,7 @@ class GroupedGemm(torch.autograd.Function):
     def forward(ctx, a, b, batch_sizes, trans_b):
         ctx.save_for_backward(a, b, batch_sizes)
         ctx.trans_b = trans_b
-        return backend.grouped_gemm(a, b, batch_sizes, False, trans_b)
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
 
     @staticmethod
     def backward(ctx, grad):
@@ -24,13 +18,14 @@ def backward(ctx, grad):
 
         agrad = None
         if ctx.needs_input_grad[0]:
-            agrad = backend.grouped_gemm(
-                grad, b, batch_sizes, False, not trans_b)
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
 
         bgrad = None
         if ctx.needs_input_grad[1]:
             lhs, rhs = (grad, a) if trans_b else (a, grad)
-            bgrad = backend.grouped_gemm(lhs, rhs, batch_sizes, True, False)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
 
         return agrad, bgrad, None, None
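A minimal usage sketch (not part of the patch) of the new `grouped_gemm.backend.gmm` wrapper added above, assuming the extension builds as `grouped_gemm_backend` and runs on a CUDA device with a supported compute capability; keeping `batch_sizes` as an int64 CPU tensor is an assumption consistent with `_allocate_output` but not spelled out in this diff.

```python
import torch
from grouped_gemm import backend

# Problem sizes: 'a' is (tokens, k), 'b' is (num_experts, k, n), and
# batch_sizes holds the number of rows of 'a' routed to each expert.
num_experts, k, n = 4, 128, 256
batch_sizes = torch.tensor([1, 3, 2, 2], dtype=torch.int64)
tokens = int(batch_sizes.sum())

a = torch.randn(tokens, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, k, n, device="cuda", dtype=torch.bfloat16)

# Let the wrapper allocate the output via _allocate_output ...
c = backend.gmm(a, b, batch_sizes)  # shape (tokens, n)

# ... or hand in a preallocated output tensor; the C++ op now only
# validates its shape and dtype instead of allocating it internally.
out = torch.empty(tokens, n, device="cuda", dtype=torch.bfloat16)
backend.gmm(a, b, batch_sizes, c=out)
```

For training, the autograd wrapper in `grouped_gemm/ops.py` (`GroupedGemm.apply`) routes its forward and backward passes through the same `backend.gmm` entry point.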