diff --git a/examples/matmul/matmul_2.cuh b/examples/matmul/matmul_2.cuh
index 4e8cefd..f56bdba 100644
--- a/examples/matmul/matmul_2.cuh
+++ b/examples/matmul/matmul_2.cuh
@@ -34,14 +34,17 @@ __device__ void warpgroup_wait() {
 template <int BlockMajorSize, int BlockMinorSize>
 void create_tensor_map(CUtensorMap *tma_map, bf16* gmem_ptr, int blocks_height, int blocks_width) {
     void* gmem_address = (void*)gmem_ptr;
-    uint64_t gmem_prob_shape[5] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height, 1, 1, 1};
-    uint64_t gmem_prob_stride[5] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width, 0, 0, 0};
-    uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
-    uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
+    // specifying dimensions from fastest (contiguous) to slowest; in our case: K and then M/N (for A/B)
+    uint64_t gmem_shape[2] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height};
+    uint64_t gmem_stride[2] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width};
+    // similarly here: BK and then BM/BN (for A/B)
+    uint32_t smem_box_shape[2] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize)};
+    uint32_t smem_box_stride[2] = {1, 1}; // when all elements of smem_box_stride are 1, smem_box_shape gives the number of elements to load
+    uint32_t tensor_rank = 2; // our inputs A and B are 2D tensors (matrices)
 
     CUresult result = cuTensorMapEncodeTiled(
-        tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, gmem_address, gmem_prob_shape,
-        gmem_prob_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
+        tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, tensor_rank, gmem_address, gmem_shape,
+        gmem_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
         CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
 
     assert(result == CUDA_SUCCESS);
@@ -88,8 +91,8 @@ __device__ void wgmma64(float d[4][8], bf16* sA, bf16* sB) {
 
 template <int BM, int BN, int BK, int NUM_THREADS>
 __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K, bf16* C, const CUtensorMap* tensorMapA, const CUtensorMap* tensorMapB) {
-    __shared__ alignas(128) bf16 sA[BM*BK];
-    __shared__ alignas(128) bf16 sB[BK*BN];
+    __shared__ alignas(1024) bf16 sA[BM*BK]; // 128B alignment is required by multi-dimensional bulk tensor async copy,
+    __shared__ alignas(1024) bf16 sB[BK*BN]; // but with 128B swizzle mode the buffers must be 1024-byte aligned!
     float d[WGMMA_N/16][8];
     static_assert(sizeof(d) * 128 == BM * BN * sizeof(float));
     memset(d, 0, sizeof(d));
@@ -104,14 +107,15 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
     if (threadIdx.x == 0) {
         init(&barA, blockDim.x);
         init(&barB, blockDim.x);
-        cde::fence_proxy_async_shared_cta();
+        cde::fence_proxy_async_shared_cta(); // make the initialized barriers visible in the async proxy
     }
-    __syncthreads();
+    __syncthreads(); // so that the initialized barriers are visible to all threads
 
     barrier::arrival_token tokenA, tokenB;
     for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter) {
         // Load
         if (threadIdx.x == 0) {
+            // block_k_iter*BK, num_block_m*BM <- offsets into matrix A in GMEM for this thread block
             cde::cp_async_bulk_tensor_2d_global_to_shared(&sA[0], tensorMapA, block_k_iter*BK, num_block_m*BM, barA);
             tokenA = cuda::device::barrier_arrive_tx(barA, 1, sizeof(sA));
             cde::cp_async_bulk_tensor_2d_global_to_shared(&sB[0], tensorMapB, block_k_iter*BK, num_block_n*BN, barB);
@@ -122,8 +126,7 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
         }
         barA.wait(std::move(tokenA));
         barB.wait(std::move(tokenB));
-        __syncthreads();
-
+
         // Compute
         warpgroup_arrive();
         wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]);
@@ -152,7 +155,7 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
             block_C[IDX(row, col+1)] = d[w][1];
             block_C[IDX(row+8, col)] = d[w][2];
             block_C[IDX(row+8, col+1)] = d[w][3];
-            
+
             block_C[IDX(row, col+8)] = d[w][4];
             block_C[IDX(row, col+9)] = d[w][5];
             block_C[IDX(row+8, col+8)] = d[w][6];
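
To make the descriptor setup above concrete, here is a minimal host-side sketch of how create_tensor_map might be driven for the two operands and handed to the kernel. The helper name setup_tensor_maps, the tile sizes, the pointers A and B, and the assumption that B is stored with K as its contiguous dimension (mirroring how the kernel indexes tensorMapB) are illustrative, not taken from the diff:

    // Hypothetical host-side setup, not part of the diff: build the two TMA
    // descriptors on the host and copy them to device memory so the kernel can
    // take them by pointer (matching the const CUtensorMap* parameters above).
    constexpr int BM = 64, BN = 64, BK = 64; // illustrative tile sizes

    void setup_tensor_maps(bf16* A, bf16* B, int M, int N, int K,
                           CUtensorMap** d_tma_map_A, CUtensorMap** d_tma_map_B) {
        CUtensorMap host_map_A, host_map_B;
        // A assumed M x K, row-major: K is the minor (contiguous) dimension, M the major one.
        create_tensor_map<BM, BK>(&host_map_A, A, M / BM, K / BK);
        // B assumed stored N x K (K contiguous), as the kernel's use of tensorMapB suggests.
        create_tensor_map<BN, BK>(&host_map_B, B, N / BN, K / BK);

        cudaMalloc(d_tma_map_A, sizeof(CUtensorMap));
        cudaMalloc(d_tma_map_B, sizeof(CUtensorMap));
        cudaMemcpy(*d_tma_map_A, &host_map_A, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
        cudaMemcpy(*d_tma_map_B, &host_map_B, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
        // *d_tma_map_A / *d_tma_map_B are then passed to matmulKernel2 as tensorMapA / tensorMapB.
    }

Keeping the CUtensorMap structs in global memory and passing pointers matches the kernel signature in the diff; passing them by value as const __grid_constant__ kernel parameters is the other common option.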