Skip to content

Commit

Permalink
[Kernel] Factor out epilogues from cutlass kernels (vllm-project#5391)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: zifeitong <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
  • Loading branch information
5 people authored Jun 13, 2024
1 parent 02b6058 commit 5529939
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 232 deletions.
8 changes: 4 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")

#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
Expand Down
6 changes: 1 addition & 5 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype)
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)


# bench
Expand Down
6 changes: 3 additions & 3 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);

#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,14 @@
using namespace cute;

/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/

namespace {
Expand Down Expand Up @@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
}
};

template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;

using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
Expand All @@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute1 =
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;

static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;

ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};

typename EVTCompute0::Arguments evt0_compute_args{b_args};

typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};

template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;

using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;

using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;

using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;

using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
Expand All @@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};

template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

Expand All @@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());

auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();

using ScaleAArgs = typename Gemm::ScaleA::Arguments;
using ScaleBArgs = typename Gemm::ScaleB::Arguments;

ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};

typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};

typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
evt0_compute_args};
typename Gemm::D::Arguments d_args{c_ptr, c_stride};

using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

typename Gemm::EVTD::Arguments epilogue_args{
evt1_compute_args,
evt_args,
d_args,
};

Expand Down Expand Up @@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,

} // namespace

void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
Expand All @@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
}
}

void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
Expand All @@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}

void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
Expand All @@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
}
}
}
Loading

0 comments on commit 5529939

Please sign in to comment.