diff --git a/examples/matmul/matmul_12.cuh b/examples/matmul/matmul_12.cuh index c7766a7..520f9ef 100644 --- a/examples/matmul/matmul_12.cuh +++ b/examples/matmul/matmul_12.cuh @@ -5,17 +5,6 @@ __device__ static inline uint64_t matrix_descriptor_encode(uint64_t x) { return (((x) & 0x3FFFF) >> 0x4); } -// Descriptor for a shared memory matrix. -// Implementation is derived from PTX guide: https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-descriptor-format -__device__ uint64_t make_smem_desc(bf16* ptr) { - // Convert shared memory pointer to integer - uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); - uint64_t desc = matrix_descriptor_encode(addr); - desc |= matrix_descriptor_encode((uint64_t)16) << 16; - desc |= matrix_descriptor_encode((uint64_t)1024) << 32; - desc |= 1llu << 62; // 128B swizzle - return desc; -} __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); @@ -56,8 +45,12 @@ int _prev_m=0, _prev_n=0, _prev_k=0; template __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) { - uint64_t desc_a = make_smem_desc(&sA[0]); - uint64_t desc_b = make_smem_desc(&sB[0]); + // descriptor is always 0x4000004000000000 | matrix_descriptor_encode(addr) + // 0x4000000000000000 (1 << 62) selects the 128B swizzle + // 0x0000004000000000 (0x40 << 32) is the stride dimension: 1024 encoded as 1024 >> 4 + // the low 14 bits are filled by the encoded addr; NOTE(review): the removed make_smem_desc also set 0x10000 (matrix_descriptor_encode(16) << 16), confirm that dropping it is intended + uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sA[0])))); + uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sB[0])))); asm volatile( "{\n" "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " @@ -103,8 +96,8 @@ __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) { template __device__ __forceinline__ void wgmma192(float d[12][8], bf16* sA, bf16* sB) { - uint64_t desc_a = make_smem_desc(&sA[0]); - uint64_t desc_b = make_smem_desc(&sB[0]); + uint64_t desc_a = 0x4000004000000000 | 
(matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sA[0])))); + uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sB[0])))); asm volatile( "{\n" "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " @@ -143,8 +136,8 @@ __device__ __forceinline__ void wgmma192(float d[12][8], bf16* sA, bf16* sB) { template __device__ __forceinline__ void wgmma128(float d[8][8], bf16* sA, bf16* sB) { - uint64_t desc_a = make_smem_desc(&sA[0]); - uint64_t desc_b = make_smem_desc(&sB[0]); + uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sA[0])))); + uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast(__cvta_generic_to_shared(&sB[0])))); asm volatile( "{\n" "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 "