Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions examples/matmul/matmul_2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@ __device__ void warpgroup_wait() {
template <int BlockMajorSize, int BlockMinorSize>
void create_tensor_map(CUtensorMap *tma_map, bf16* gmem_ptr, int blocks_height, int blocks_width) {
void* gmem_address = (void*)gmem_ptr;
uint64_t gmem_prob_shape[5] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height, 1, 1, 1};
uint64_t gmem_prob_stride[5] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width, 0, 0, 0};
uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
// specifying dimensions from fastest (contiguous) to slowest; in our case: K and then M/N (for A/B)
uint64_t gmem_shape[2] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height};
uint64_t gmem_stride[2] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width};
// similarly here: BK and then BM/BN (for A/B)
uint32_t smem_box_shape[2] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize)};
uint32_t smem_box_stride[2] = {1, 1}; // when all elements of smem_box_stride are one, smem_box_shape specifies the number of elements to load
uint32_t tensor_rank = 2; // our input matrices A,B are 2D tensors (matrices)

CUresult result = cuTensorMapEncodeTiled(
tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, gmem_address, gmem_prob_shape,
gmem_prob_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, tensor_rank, gmem_address, gmem_shape,
gmem_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

assert(result == CUDA_SUCCESS);
Expand Down Expand Up @@ -88,8 +91,8 @@ __device__ void wgmma64(float d[4][8], bf16* sA, bf16* sB) {

template<int BM, int BN, int BK, int WGMMA_M, int WGMMA_N, int WGMMA_K, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K, bf16* C, const CUtensorMap* tensorMapA, const CUtensorMap* tensorMapB) {
__shared__ alignas(128) bf16 sA[BM*BK];
__shared__ alignas(128) bf16 sB[BK*BN];
__shared__ alignas(1024) bf16 sA[BM*BK]; // 128B alignment is required by multi dimensional bulk tensor async copy
__shared__ alignas(1024) bf16 sB[BK*BN]; // but for swizzle mode 128B this should be 1024 alignment!
float d[WGMMA_N/16][8];
static_assert(sizeof(d) * 128 == BM * BN * sizeof(float));
memset(d, 0, sizeof(d));
Expand All @@ -104,14 +107,15 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
if (threadIdx.x == 0) {
init(&barA, blockDim.x);
init(&barB, blockDim.x);
cde::fence_proxy_async_shared_cta();
cde::fence_proxy_async_shared_cta(); // make initialized barrier visible in async proxy
}
__syncthreads();
__syncthreads(); // so that the initialized barrier is visible to all threads

barrier::arrival_token tokenA, tokenB;
for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter) {
// Load
if (threadIdx.x == 0) {
// block_k_iter*BK, num_block_m*BM <- offsets into matrix A in GMEM for this thread block
cde::cp_async_bulk_tensor_2d_global_to_shared(&sA[0], tensorMapA, block_k_iter*BK, num_block_m*BM, barA);
tokenA = cuda::device::barrier_arrive_tx(barA, 1, sizeof(sA));
cde::cp_async_bulk_tensor_2d_global_to_shared(&sB[0], tensorMapB, block_k_iter*BK, num_block_n*BN, barB);
Expand All @@ -122,8 +126,7 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
}
barA.wait(std::move(tokenA));
barB.wait(std::move(tokenB));
__syncthreads();


// Compute
warpgroup_arrive();
wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]);
Expand Down Expand Up @@ -152,7 +155,7 @@ __global__ void __launch_bounds__(NUM_THREADS) matmulKernel2(int M, int N, int K
block_C[IDX(row, col+1)] = d[w][1];
block_C[IDX(row+8, col)] = d[w][2];
block_C[IDX(row+8, col+1)] = d[w][3];

block_C[IDX(row, col+8)] = d[w][4];
block_C[IDX(row, col+9)] = d[w][5];
block_C[IDX(row+8, col+8)] = d[w][6];
Expand Down