Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions examples/matmul/matmul_12.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,6 @@ __device__ static inline uint64_t matrix_descriptor_encode(uint64_t x) {
return (((x) & 0x3FFFF) >> 0x4);
}

// Descriptor for a shared memory matrix.
// Implementation is derived from PTX guide: https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-descriptor-format
__device__ uint64_t make_smem_desc(bf16* ptr) {
// Convert shared memory pointer to integer
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
uint64_t desc = matrix_descriptor_encode(addr);
desc |= matrix_descriptor_encode((uint64_t)16) << 16;
desc |= matrix_descriptor_encode((uint64_t)1024) << 32;
desc |= 1llu << 62; // 128B swizzle
return desc;
}

__device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
Expand Down Expand Up @@ -56,8 +45,12 @@ int _prev_m=0, _prev_n=0, _prev_k=0;

template<int ScaleD, int ScaleA, int ScaleB, int TransA, int TransB>
__device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) {
uint64_t desc_a = make_smem_desc(&sA[0]);
uint64_t desc_b = make_smem_desc(&sB[0]);
// descriptor will always be 0x4000004000000000 | addr
// ^ represents the 128B swizzle
// ^ represents the stride dimension
// ^^^^ will be replaced by addr
uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sA[0]))));
uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sB[0]))));
asm volatile(
"{\n"
"wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 "
Expand Down Expand Up @@ -103,8 +96,8 @@ __device__ __forceinline__ void wgmma256(float d[16][8], bf16* sA, bf16* sB) {

template<int ScaleD, int ScaleA, int ScaleB, int TransA, int TransB>
__device__ __forceinline__ void wgmma192(float d[12][8], bf16* sA, bf16* sB) {
uint64_t desc_a = make_smem_desc(&sA[0]);
uint64_t desc_b = make_smem_desc(&sB[0]);
uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sA[0]))));
uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sB[0]))));
asm volatile(
"{\n"
"wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 "
Expand Down Expand Up @@ -143,8 +136,8 @@ __device__ __forceinline__ void wgmma192(float d[12][8], bf16* sA, bf16* sB) {

template<int ScaleD, int ScaleA, int ScaleB, int TransA, int TransB>
__device__ __forceinline__ void wgmma128(float d[8][8], bf16* sA, bf16* sB) {
uint64_t desc_a = make_smem_desc(&sA[0]);
uint64_t desc_b = make_smem_desc(&sB[0]);
uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sA[0]))));
uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sB[0]))));
asm volatile(
"{\n"
"wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 "
Expand Down