
Conversation

@solrex (Contributor) commented on May 24, 2025

Inspired by #1932 and #2037, this PR implements blockwise scaling GEMM kernels on platforms before SM90:

  • FP8 blockwise/groupwise scaling kernel for Ada (L20, L40S, 4090); requires the accumulator type to be float
  • INT8 blockwise/groupwise scaling kernel for Ampere (A100/A800, A10, A30); requires the accumulator type to be int
  • CUTLASS 3.x API
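
For reference, the scale granularities in the examples can be read as: one scale per ScaleGranularityM rows of A, one per ScaleGranularityN columns of B, and one scale per K block, applied to the block accumulator before it is added to the running sum. Below is a minimal host-side sketch of that semantics in plain float; layouts and the function name are illustrative only, not the PR's API.

#include <vector>

// Reference blockwise/groupwise-scaled GEMM:
//   D[m][n] = sum over K blocks of SFA[m-group][kb] * SFB[n-group][kb] * (A block * B block)
// A is MxK row-major, B is KxN row-major, D is MxN row-major.
// SFA is (M/GranM) x (K/GranK), SFB is (N/GranN) x (K/GranK), both row-major.
// Assumes M, N, K are divisible by their granularities; real inputs are FP8/INT8
// with float/int32 block accumulators, widened here to float for clarity.
void ref_blockwise_gemm(int M, int N, int K, int GranM, int GranN, int GranK,
                        const std::vector<float>& A, const std::vector<float>& B,
                        const std::vector<float>& SFA, const std::vector<float>& SFB,
                        std::vector<float>& D) {
  int KBlocks = K / GranK;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < KBlocks; ++kb) {
        float block_acc = 0.f;                        // per-K-block accumulator
        for (int k = kb * GranK; k < (kb + 1) * GranK; ++k) {
          block_acc += A[m * K + k] * B[k * N + n];
        }
        acc += SFA[(m / GranM) * KBlocks + kb]        // one scale per GranM rows of A
             * SFB[(n / GranN) * KBlocks + kb]        // one scale per GranN cols of B
             * block_acc;
      }
      D[m * N + n] = acc;
    }
  }
}

With groupwise scaling (example 85a), GranM = 1 and GranN = 128; with blockwise scaling (85b), GranM = GranN = 128; GranK = 128 in both cases.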

@solrex solrex changed the title Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 block/groupwise on Ampere Add SM80/89 blockwise scaling kernel, support FP8 block/groupwise on Ada, INT8 on Ampere May 24, 2025
@solrex solrex force-pushed the sm80-blockscale branch from 2b2a88b to 5c58e77 on May 26, 2025 18:03
@hwu36 (Collaborator) commented on May 28, 2025

@jackkosaian

@solrex (Contributor, Author) commented on May 28, 2025

The following are example benchmark results on an L40S with CUDA 12.4 and CUTLASS main:

FP8:

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute
Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _64, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 64)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 2.79446e-06, MRE: 12.0697, greatest error: 0.0196838
  Disposition: Passed
  Avg runtime: 0.00905421 ms
  GFLOPS: 237181

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85b_ada_fp8_gemm_with_blockwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 2.61817e-06, MRE: 11.7382, greatest error: 0.0210075
  Disposition: Passed
  Avg runtime: 0.0233175 ms
  GFLOPS: 92097.5

INT8: 

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85c_ampere_int8_gemm_with_groupwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _64, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 64)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 0, MRE: 81.7363, greatest error: 0
  Disposition: Passed
  Avg runtime: 0.00911155 ms
  GFLOPS: 235688

$ ./examples/85_ada_ampere_gemm_with_blockwise_scaling/85d_ampere_int8_gemm_with_blockwise_scaling_cute
  Problem Size: 1024x1024x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Running... 
  Result MSE: 0, MRE: 77.9124, greatest error: 0
  Disposition: Passed
  Avg runtime: 0.0239155 ms
  GFLOPS: 89794.6
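
As a sanity check, the GFLOPS figures in these logs follow from 2·M·N·K multiply-add operations over the average runtime. A small stand-alone check (not part of the PR) using the 85a numbers above:

#include <cstdio>

int main() {
  double M = 1024, N = 1024, K = 1024;
  double ops = 2.0 * M * N * K;            // one multiply-accumulate counted as 2 ops
  double runtime_ms = 0.00905421;          // avg runtime reported by 85a above
  double gflops = ops / (runtime_ms * 1e-3) / 1e9;
  std::printf("GFLOPS: %.0f\n", gflops);   // prints ~237181, matching the log
  return 0;
}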

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@sunjianxide

Thank you very much for your work. Recently, I saw a CUTLASS PR #2378 that uses FP16 accumulation on SM89. Do you think this can be combined with the kernel from your current PR? I experimented with it, but encountered many issues. Would you have time to work on this? Thank you very much.

@solrex solrex force-pushed the sm80-blockscale branch from 3521a01 to ab3b26e on July 10, 2025 07:27
@solrex (Contributor, Author) commented on Jul 10, 2025

Thank you very much for your work. Recently, I saw a CUTLASS PR #2378 that uses FP16 accumulation on SM89. Do you think this can be combined with the kernel from your current PR? I experimented with it, but encountered many issues. Would you have time to work on this? Thank you very much.

I can try to add an FP16 accumulator option to the examples after that PR is merged.

@solrex (Contributor, Author) commented on Jul 31, 2025

Thank you very much for your work. Recently, I saw a CUTLASS PR #2378 that uses FP16 accumulation on SM89. Do you think this can be combined with the kernel from your current PR? I experimented with it, but encountered many issues. Would you have time to work on this? Thank you very much.

@sunjianxide I've attempted it; it compiles successfully but fails at runtime, and I haven't been able to diagnose the cause yet. You can check the patch below:

diff --git a/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu b/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
index f9078d0c..089eddac 100644
--- a/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
+++ b/examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu
@@ -73,7 +73,7 @@
 
 using namespace cute;
 
-template <typename ArchTag, typename Element, int BLK_M, int BLK_N, int BLK_K, int PipelineStages = 3, int WARP_M = 2, int WARP_N = 2>
+template <typename ArchTag, typename Element, typename Accumulator, int BLK_M, int BLK_N, int BLK_K, int PipelineStages = 3, int WARP_M = 2, int WARP_N = 2>
 struct SM8x_Byte_Gemm_Traits {
   static constexpr int MMA_WARP_M = WARP_M * 16;
   static constexpr int MMA_WARP_N = WARP_N * 16;
@@ -90,7 +90,10 @@ struct SM8x_Byte_Gemm_Traits {
   );
 
   using MmaAtom = cute::conditional_t<cute::is_same_v<Element, cutlass::float_e4m3_t>,
-    MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>,
+    cute::conditional_t<cute::is_same_v<Accumulator, cutlass::half_t>,
+      MMA_Atom<SM89_16x8x32_F16E4M3E4M3F16_TN>,
+      MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>
+    >,
     MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>>;
 
   using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<BLK_K>>;              // Threadblock-level tile size
@@ -154,7 +157,7 @@ constexpr int PipelineStages = 4;
 constexpr int BLK_M = 64;
 constexpr int BLK_N = 128;
 constexpr int BLK_K = 128;
-using GemmTrait = SM8x_Byte_Gemm_Traits<ArchTag, ElementA, BLK_M, BLK_N, BLK_K, PipelineStages>;
+using GemmTrait = SM8x_Byte_Gemm_Traits<ArchTag, ElementA, cutlass::half_t, BLK_M, BLK_N, BLK_K, PipelineStages>;
 using TileShape = GemmTrait::TileShape;
 using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsyncBlockScaling<PipelineStages, GemmTrait::ClusterShape>;
 
diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
index 8461657b..37c19693 100644
--- a/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
+++ b/include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp
@@ -154,8 +154,10 @@ struct CollectiveMma<
 
   // Block scaling gmem-to-smem copy atom
   //  we can have partial tiles in M or N, so don't vectorize those loads
-  using CopyAtomSFA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
-  using CopyAtomSFB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
+  using CopyAtomSFA = cute::conditional_t<sizeof(ElementBlockScale) < 4,
+    Copy_Atom<UniversalCopy<ElementBlockScale>, ElementBlockScale>,
+    Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>>;
+  using CopyAtomSFB = CopyAtomSFA;
   using GmemTiledCopySFA = decltype(make_tiled_copy(
     CopyAtomSFA{},
     Layout<Shape<Int<32>>>{},
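
One note on the second hunk: cp.async on SM80/89 only supports 4-, 8-, and 16-byte transactions, so a per-element copy atom over a 2-byte scale type such as cutlass::half_t cannot lower to SM80_CP_ASYNC_CACHEALWAYS; the sizeof(ElementBlockScale) < 4 fallback to UniversalCopy works around that. A minimal compile-time illustration of the constraint (hypothetical helper, not CUTLASS code):

#include <cstddef>

// Mirrors the size check in the hunk above: cp.async copies must move 4, 8,
// or 16 bytes per element, so 2-byte block scales need a plain (non-async) copy.
template <class ElementBlockScale>
constexpr bool can_lower_to_cp_async() {
  return sizeof(ElementBlockScale) == 4 || sizeof(ElementBlockScale) == 8 ||
         sizeof(ElementBlockScale) == 16;
}

static_assert(can_lower_to_cp_async<float>(), "4-byte scales can use cp.async");
static_assert(!can_lower_to_cp_async<short>(), "2-byte scales fall back to UniversalCopy");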
