Add PTX wrapping functions for TMA features (#379)
ahendriksen committed Sep 13, 2023
1 parent 3ae6494 commit abb7cb7
Showing 13 changed files with 1,229 additions and 0 deletions.
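For orientation, the round-trip pattern these tests exercise looks roughly as follows. This is a condensed sketch based on the first test file below; the kernel name and buffer sizes are illustrative, and it assumes an sm_90 target.

#include <cuda/barrier>      // cuda::barrier, cuda::device::experimental
#include <cuda/std/utility>  // cuda::std::move

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

__device__ int gmem[1024];

// Round trip: bulk-copy gmem -> smem, update in place, bulk-copy smem -> gmem.
__global__ void bulk_round_trip()
{
    __shared__ alignas(16) int smem[1024];
    __shared__ barrier bar;  // emits a "dynamic init" warning, suppressed in the tests
    if (threadIdx.x == 0) { init(&bar, blockDim.x); }
    __syncthreads();

    uint64_t token;
    if (threadIdx.x == 0) {
        // One thread issues the copy and arrives with the expected byte count.
        cde::cp_async_bulk_global_to_shared(smem, gmem, sizeof(smem), bar);
        token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem));
    } else {
        token = bar.arrive();
    }
    // Completes once every thread has arrived and all bytes have landed.
    bar.wait(cuda::std::move(token));

    for (int i = threadIdx.x; i < 1024; i += blockDim.x) { smem[i] += i; }
    cde::fence_proxy_async_shared_cta();  // order smem writes before the async-proxy read
    __syncthreads();

    if (threadIdx.x == 0) {
        cde::cp_async_bulk_shared_to_global(gmem, smem, sizeof(smem));
        cde::cp_async_bulk_commit_group();
        cde::cp_async_bulk_wait_group_read<0>();  // all committed groups done reading smem
    }
}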
@@ -0,0 +1,91 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
//
// UNSUPPORTED: libcpp-has-no-threads
// UNSUPPORTED: pre-sm-90

// <cuda/barrier>

#include <cuda/barrier>
#include <cuda/std/utility> // cuda::std::move
#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS

// Suppress warning about barrier in shared memory
TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init)

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

static constexpr int buf_len = 1024;
__device__ int gmem_buffer[buf_len];

__device__ void test()
{
// SETUP: fill global memory buffer
for (int i = threadIdx.x; i < buf_len; i += blockDim.x) {
gmem_buffer[i] = i;
}
// Ensure that the writes to global memory are visible to other observers,
// including accesses made through the async proxy.
__threadfence();
__syncthreads();

// TEST: Add i to buffer[i]
__shared__ alignas(16) int smem_buffer[buf_len];
__shared__ barrier bar;
if (threadIdx.x == 0) { init(&bar, blockDim.x); }
__syncthreads();

// Load data:
uint64_t token;
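// One thread issues the bulk copy and arrives with an expected transaction
// count of sizeof(smem_buffer) bytes; the remaining threads only arrive.
// bar.wait() then completes once all threads have arrived and the copy has
// delivered all of those bytes.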
if (threadIdx.x == 0) {
cde::cp_async_bulk_global_to_shared(smem_buffer, gmem_buffer, sizeof(smem_buffer), bar);
token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
} else {
token = bar.arrive();
}
bar.wait(cuda::std::move(token));

// Update in shared memory
for (int i = threadIdx.x; i < buf_len; i += blockDim.x) {
smem_buffer[i] += i;
}
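// Order the ordinary (generic-proxy) writes to smem_buffer above before the
// async-proxy read performed by the bulk copy below.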
cde::fence_proxy_async_shared_cta();
__syncthreads();

// Write back to global memory:
if (threadIdx.x == 0) {
cde::cp_async_bulk_shared_to_global(gmem_buffer, smem_buffer, sizeof(smem_buffer));
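// Batch the copy above into a bulk async-group; wait_group_read<0> then blocks
// until all committed groups have finished reading shared memory.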
cde::cp_async_bulk_commit_group();
cde::cp_async_bulk_wait_group_read<0>();
}
__threadfence();
__syncthreads();

// TEAR-DOWN: check that global memory is correct
for (int i = threadIdx.x; i < buf_len; i += blockDim.x) {
assert(gmem_buffer[i] == 2 * i);
}
}

int main(int, char**)
{
NV_IF_TARGET(NV_IS_HOST,(
// Required by concurrent_agents_launch to know how many threads we're launching
cuda_thread_count = 512;
));

NV_DISPATCH_TARGET(
NV_IS_DEVICE, (
test();
)
);
return 0;
}
@@ -0,0 +1,24 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
//
// UNSUPPORTED: libcpp-has-no-threads
// UNSUPPORTED: pre-sm-90

// <cuda/barrier>

#include <cuda/barrier>

#ifndef __cccl_lib_experimental_ctk12_cp_async_exposure
static_assert(false, "should define __cccl_lib_experimental_ctk12_cp_async_exposure");
#endif

int main(int, char**){
return 0;
}
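As a usage note, downstream code would typically guard its use of these experimental functions on the same feature-test macro. A minimal sketch (the namespace alias is illustrative):

#include <cuda/barrier>

#ifdef __cccl_lib_experimental_ctk12_cp_async_exposure
namespace cde = cuda::device::experimental;  // cde::cp_async_bulk_* available (sm_90+)
#endif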
@@ -0,0 +1,69 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
//
// UNSUPPORTED: libcpp-has-no-threads
// UNSUPPORTED: pre-sm-90

// <cuda/barrier>

#include <cuda/barrier>
#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS

// Suppress warning about barrier in shared memory
TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init)

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

// The kernels below are intended to be compiled, not run; this checks that
// all of the generated PTX is valid.
__global__ void test_bulk_tensor(CUtensorMap* map) {
__shared__ int smem;
__shared__ barrier bar;

cde::cp_async_bulk_tensor_1d_global_to_shared(&smem, map, 0, bar);
cde::cp_async_bulk_tensor_2d_global_to_shared(&smem, map, 0, 0, bar);
cde::cp_async_bulk_tensor_3d_global_to_shared(&smem, map, 0, 0, 0, bar);
cde::cp_async_bulk_tensor_4d_global_to_shared(&smem, map, 0, 0, 0, 0, bar);
cde::cp_async_bulk_tensor_5d_global_to_shared(&smem, map, 0, 0, 0, 0, 0, bar);

cde::cp_async_bulk_tensor_1d_shared_to_global(map, 0, &smem);
cde::cp_async_bulk_tensor_2d_shared_to_global(map, 0, 0, &smem);
cde::cp_async_bulk_tensor_3d_shared_to_global(map, 0, 0, 0, &smem);
cde::cp_async_bulk_tensor_4d_shared_to_global(map, 0, 0, 0, 0, &smem);
cde::cp_async_bulk_tensor_5d_shared_to_global(map, 0, 0, 0, 0, 0, &smem);
}

__global__ void test_bulk(void * gmem) {
__shared__ int smem;
__shared__ barrier bar;
cde::cp_async_bulk_global_to_shared(&smem, gmem, 1024, bar);
cde::cp_async_bulk_shared_to_global(gmem, &smem, 1024);
}

__global__ void test_fences_async_group(void * gmem) {
cde::fence_proxy_async_shared_cta();

cde::cp_async_bulk_commit_group();
// Instantiate wait_group_read<N> for N = 0..8; each waits until at most N of
// the most recently committed bulk async-groups are still pending.
cde::cp_async_bulk_wait_group_read<0>();
cde::cp_async_bulk_wait_group_read<1>();
cde::cp_async_bulk_wait_group_read<2>();
cde::cp_async_bulk_wait_group_read<3>();
cde::cp_async_bulk_wait_group_read<4>();
cde::cp_async_bulk_wait_group_read<5>();
cde::cp_async_bulk_wait_group_read<6>();
cde::cp_async_bulk_wait_group_read<7>();
cde::cp_async_bulk_wait_group_read<8>();
}

int main(int, char**){
return 0;
}
@@ -0,0 +1,188 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
//
// UNSUPPORTED: libcpp-has-no-threads
// UNSUPPORTED: pre-sm-90
// UNSUPPORTED: nvrtc
// NVRTC_SKIP_KERNEL_RUN // This will take effect once PR 433 is merged (the line above should then be removed).

// <cuda/barrier>

#include <cuda/barrier>
#include <cuda/std/utility> // cuda::std::move
#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS

// NVRTC does not support cuda.h (because it pulls in stdlib.h)
#ifndef TEST_COMPILER_NVRTC
#include <cudaTypedefs.h> // PFN_cuTensorMapEncodeTiled, CUtensorMap
#endif // !TEST_COMPILER_NVRTC

// Suppress warning about barrier in shared memory
TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init)

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

constexpr size_t GMEM_WIDTH = 1024; // Width of tensor (in # elements)
constexpr size_t GMEM_HEIGHT = 1024; // Height of tensor (in # elements)
constexpr size_t gmem_len = GMEM_WIDTH * GMEM_HEIGHT;

constexpr int SMEM_WIDTH = 32; // Width of shared memory buffer (in # elements)
constexpr int SMEM_HEIGHT = 8; // Height of shared memory buffer (in # elements)

static constexpr int buf_len = SMEM_HEIGHT * SMEM_WIDTH;
__device__ int gmem_tensor[gmem_len];

// We need a complete type with a known size. Under NVRTC, cuda.h cannot be
// included, so the definition of CUtensorMap is unavailable (only the forward
// declaration inside <cuda/barrier> is). We therefore use this stand-in type
// and reinterpret_cast it in the kernel.
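// 128 bytes with 64-byte alignment, matching the size and alignment of CUtensorMap.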
struct fake_cutensormap {
alignas(64) uint64_t opaque[16];
};
__constant__ fake_cutensormap global_fake_tensor_map;

__device__ void test(int base_i, int base_j)
{
CUtensorMap *global_tensor_map = reinterpret_cast<CUtensorMap*>(&global_fake_tensor_map);

// SETUP: fill global memory buffer
for (int i = threadIdx.x; i < gmem_len; i += blockDim.x) {
gmem_tensor[i] = i;
}
// Ensure that the writes to global memory are visible to other observers,
// including accesses made through the async proxy.
__threadfence();
__syncthreads();

// TEST: Load a 2D tile into shared memory, update each element to 2 * x + 1, and write it back.
__shared__ alignas(128) int smem_buffer[buf_len];
__shared__ barrier bar;
if (threadIdx.x == 0) { init(&bar, blockDim.x); }
__syncthreads();

// Load data:
uint64_t token;
if (threadIdx.x == 0) {
// Coordinates are passed fastest-moving first: base_j indexes the contiguous
// (width) dimension, base_i the rows.
cde::cp_async_bulk_tensor_2d_global_to_shared(smem_buffer, global_tensor_map, base_j, base_i, bar);
token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer));
} else {
token = bar.arrive();
}
bar.wait(cuda::std::move(token));

// Check smem
for (int i = 0; i < SMEM_HEIGHT; ++i) {
for (int j = 0; j < SMEM_WIDTH; ++j) {
const int gmem_lin_idx = (base_i + i) * GMEM_WIDTH + base_j + j;
const int smem_lin_idx = i * SMEM_WIDTH + j;

assert(smem_buffer[smem_lin_idx] == gmem_lin_idx);
}
}

__syncthreads();

// Update smem
for (int i = threadIdx.x; i < buf_len; i += blockDim.x) {
smem_buffer[i] = 2 * smem_buffer[i] + 1;
}
cde::fence_proxy_async_shared_cta();
__syncthreads();

// Write back to global memory:
if (threadIdx.x == 0) {
cde::cp_async_bulk_tensor_2d_shared_to_global(global_tensor_map, base_j, base_i, smem_buffer);
cde::cp_async_bulk_commit_group();
cde::cp_async_bulk_wait_group_read<0>();
}
__threadfence();
__syncthreads();

// TEAR-DOWN: check that global memory is correct
for (int i = 0; i < SMEM_HEIGHT; ++i) {
for (int j = 0; j < SMEM_WIDTH; ++j) {
int gmem_lin_idx = (base_i + i) * GMEM_WIDTH + base_j + j;

assert(gmem_tensor[gmem_lin_idx] == 2 * gmem_lin_idx + 1);
}
}
__syncthreads();
}

#ifndef TEST_COMPILER_NVRTC
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
void* driver_ptr = nullptr;
cudaDriverEntryPointQueryResult driver_status;
auto code = cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status);
assert(code == cudaSuccess && "Could not get driver API");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
}
#endif // ! TEST_COMPILER_NVRTC

int main(int, char**)
{
NV_IF_TARGET(NV_IS_HOST,(
// Required by concurrent_agents_launch to know how many threads we're launching
cuda_thread_count = 512;

int * tensor_ptr = nullptr;
auto code = cudaGetSymbolAddress((void**)&tensor_ptr, gmem_tensor);
assert(code == cudaSuccess && "getsymboladdress failed.");

// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
CUtensorMap local_tensor_map{};
// rank is the number of dimensions of the array.
constexpr uint32_t rank = 2;
uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT};
// The stride is the number of bytes to traverse from the first element of one row to the next.
// It must be a multiple of 16.
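// Here: 1024 elements * 4 bytes = 4096 bytes, a multiple of 16.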
uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(int)};
// The box_size is the size of the shared memory buffer that is used as the
// destination of a TMA transfer.
uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT};
// The distance between elements in units of sizeof(element). A stride of 2
// can be used to load only the real component of a complex-valued tensor, for instance.
uint32_t elem_stride[rank] = {1, 1};

// Get a function pointer to the cuTensorMapEncodeTiled driver API.
auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled();

// Create the tensor descriptor.
CUresult res = cuTensorMapEncodeTiled(
&local_tensor_map, // CUtensorMap *tensorMap,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT32,
rank, // cuuint32_t tensorRank,
tensor_ptr, // void *globalAddress,
size, // const cuuint64_t *globalDim,
stride, // const cuuint64_t *globalStrides,
box_size, // const cuuint32_t *boxDim,
elem_stride, // const cuuint32_t *elementStrides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);

assert(res == CUDA_SUCCESS && "tensormap creation failed.");
code = cudaMemcpyToSymbol(global_fake_tensor_map, &local_tensor_map, sizeof(CUtensorMap));
assert(code == cudaSuccess && "memcpytosymbol failed.");
));

NV_DISPATCH_TARGET(
NV_IS_DEVICE, (
test(0, 0);
test(4, 0);
test(4, 4);
)
);
return 0;
}