Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 on Ampere #2328
base: main
Conversation
The following are the example benchmark results on L40S with CUDA 12.4 and CUTLASS main:
Thank you very much for your work. I recently saw a CUTLASS PR (#2378) that uses FP16 accumulation on SM89. Do you think it could be combined with the kernel from this PR? I experimented with it but ran into many issues. Would you have time to work on this? Thank you very much.
* FP8 blockwise/groupwise kernel for Ada (L20, L40S, 4090)
* INT8 blockwise/groupwise kernel for Ampere (A100/A800)
I can try to add an option for an F16 accumulator in the examples after this PR is merged.
@sunjianxide I've attempted it, and it compiles successfully, but it fails at runtime. I haven't been able to diagnose the cause yet. You can check the patch:

```diff
diff --git a/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu b/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
index f9078d0c..089eddac 100644
--- a/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
+++ b/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
@@ -73,7 +73,7 @@
 using namespace cute;

-template <typename ArchTag, typename Element, int BLK_M, int BLK_N, int BLK_K, int PipelineStages = 3, int WARP_M = 2, int WARP_N = 2>
+template <typename ArchTag, typename Element, typename Accumulator, int BLK_M, int BLK_N, int BLK_K, int PipelineStages = 3, int WARP_M = 2, int WARP_N = 2>
 struct SM8x_Byte_Gemm_Traits {
   static constexpr int MMA_WARP_M = WARP_M * 16;
   static constexpr int MMA_WARP_N = WARP_N * 16;
@@ -90,7 +90,10 @@ struct SM8x_Byte_Gemm_Traits {
   );

   using MmaAtom = cute::conditional_t<cute::is_same_v<Element, cutlass::float_e4m3_t>,
-                    MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>,
+                    cute::conditional_t<cute::is_same_v<Accumulator, cutlass::half_t>,
+                      MMA_Atom<SM89_16x8x32_F16E4M3E4M3F16_TN>,
+                      MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>
+                    >,
                     MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>>;

   using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<BLK_K>>;  // Threadblock-level tile size
@@ -154,7 +157,7 @@
 constexpr int PipelineStages = 4;
 constexpr int BLK_M = 64;
 constexpr int BLK_N = 128;
 constexpr int BLK_K = 128;
-using GemmTrait = SM8x_Byte_Gemm_Traits<ArchTag, ElementA, BLK_M, BLK_N, BLK_K, PipelineStages>;
+using GemmTrait = SM8x_Byte_Gemm_Traits<ArchTag, ElementA, cutlass::half_t, BLK_M, BLK_N, BLK_K, PipelineStages>;
 using TileShape = GemmTrait::TileShape;
 using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsyncBlockScaling<PipelineStages, GemmTrait::ClusterShape>;
diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
index 8461657b..37c19693 100644
--- a/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
+++ b/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
@@ -154,8 +154,10 @@ struct CollectiveMma<
   // Block scaling gmem-to-smem copy atom
   // we can have partial tiles in M or N, so don't vectorize those loads
-  using CopyAtomSFA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
-  using CopyAtomSFB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
+  using CopyAtomSFA = cute::conditional_t<sizeof(ElementBlockScale) < 4,
+      Copy_Atom<UniversalCopy<ElementBlockScale>, ElementBlockScale>,
+      Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>>;
+  using CopyAtomSFB = CopyAtomSFA;
   using GmemTiledCopySFA = decltype(make_tiled_copy(
       CopyAtomSFA{},
       Layout<Shape<Int<32>>>{},
```
Inspired by #1932 and #2037, this PR implements a blockwise-scaling kernel on platforms before SM90.