diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk.pass.cpp new file mode 100644 index 00000000000..d32aea85727 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk.pass.cpp @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 + +// + +#include +#include // cuda::std::move +#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS + +// Suppress warning about barrier in shared memory +TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init) + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +static constexpr int buf_len = 1024; +__device__ int gmem_buffer[buf_len]; + +__device__ void test() +{ + // SETUP: fill global memory buffer + for (int i = threadIdx.x; i < buf_len; i += blockDim.x) { + gmem_buffer[i] = i; + } + // Ensure that writes to global memory are visible to others, including + // those in the async proxy. + __threadfence(); + __syncthreads(); + + // TEST: Add i to buffer[i] + __shared__ alignas(16) int smem_buffer[buf_len]; + __shared__ barrier bar; + if (threadIdx.x == 0) { init(&bar, blockDim.x); } + __syncthreads(); + + // Load data: + uint64_t token; + if (threadIdx.x == 0) { + cde::cp_async_bulk_global_to_shared(smem_buffer, gmem_buffer, sizeof(smem_buffer), bar); + token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer)); + } else { + token = bar.arrive(); + } + bar.wait(cuda::std::move(token)); + + // Update in shared memory + for (int i = threadIdx.x; i < buf_len; i += blockDim.x) { + smem_buffer[i] += i; + } + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + + // Write back to global memory: + if (threadIdx.x == 0) { + cde::cp_async_bulk_shared_to_global(gmem_buffer, smem_buffer, sizeof(smem_buffer)); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + __threadfence(); + __syncthreads(); + + // TEAR-DOWN: check that global memory is correct + for (int i = threadIdx.x; i < buf_len; i += blockDim.x) { + assert(gmem_buffer[i] == 2 * i); + } +} + +int main(int, char**) +{ + NV_IF_TARGET(NV_IS_HOST,( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + )); + + NV_DISPATCH_TARGET( + NV_IS_DEVICE, ( + test(); + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_feature_test.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_feature_test.pass.cpp new file mode 100644 index 00000000000..e98ee5e6aad --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_feature_test.pass.cpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: libcpp-has-no-threads
+// UNSUPPORTED: pre-sm-90
+
+// <cuda/barrier>
+
+#include <cuda/barrier>
+
+#ifndef __cccl_lib_experimental_ctk12_cp_async_exposure
+static_assert(false, "should define __cccl_lib_experimental_ctk12_cp_async_exposure");
+#endif
+
+int main(int, char**){
+  return 0;
+}
diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_ptx_compiles.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_ptx_compiles.pass.cpp
new file mode 100644
index 00000000000..eb4e9b876d8
--- /dev/null
+++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_ptx_compiles.pass.cpp
@@ -0,0 +1,69 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of libcu++, the C++ Standard Library for your entire system,
+// under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: libcpp-has-no-threads
+// UNSUPPORTED: pre-sm-90
+
+// <cuda/barrier>
+
+#include <cuda/barrier>
+#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS
+
+// Suppress warning about barrier in shared memory
+TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init)
+
+using barrier = cuda::barrier<cuda::thread_scope_block>;
+namespace cde = cuda::device::experimental;
+
+// Kernels below are intended to be compiled, but not run. This is to check if
+// all generated PTX is valid.
+__global__ void test_bulk_tensor(CUtensorMap* map) { + __shared__ int smem; + __shared__ barrier bar; + + cde::cp_async_bulk_tensor_1d_global_to_shared(&smem, map, 0, bar); + cde::cp_async_bulk_tensor_2d_global_to_shared(&smem, map, 0, 0, bar); + cde::cp_async_bulk_tensor_3d_global_to_shared(&smem, map, 0, 0, 0, bar); + cde::cp_async_bulk_tensor_4d_global_to_shared(&smem, map, 0, 0, 0, 0, bar); + cde::cp_async_bulk_tensor_5d_global_to_shared(&smem, map, 0, 0, 0, 0, 0, bar); + + cde::cp_async_bulk_tensor_1d_shared_to_global(map, 0, &smem); + cde::cp_async_bulk_tensor_2d_shared_to_global(map, 0, 0, &smem); + cde::cp_async_bulk_tensor_3d_shared_to_global(map, 0, 0, 0, &smem); + cde::cp_async_bulk_tensor_4d_shared_to_global(map, 0, 0, 0, 0, &smem); + cde::cp_async_bulk_tensor_5d_shared_to_global(map, 0, 0, 0, 0, 0, &smem); +} + +__global__ void test_bulk(void * gmem) { + __shared__ int smem; + __shared__ barrier bar; + cde::cp_async_bulk_global_to_shared(&smem, gmem, 1024, bar); + cde::cp_async_bulk_shared_to_global(gmem, &smem, 1024); +} + +__global__ void test_fences_async_group(void * gmem) { + cde::fence_proxy_async_shared_cta(); + + cde::cp_async_bulk_commit_group(); + // Wait for up to 8 groups + cde::cp_async_bulk_wait_group_read<0>(); + cde::cp_async_bulk_wait_group_read<1>(); + cde::cp_async_bulk_wait_group_read<2>(); + cde::cp_async_bulk_wait_group_read<3>(); + cde::cp_async_bulk_wait_group_read<4>(); + cde::cp_async_bulk_wait_group_read<5>(); + cde::cp_async_bulk_wait_group_read<6>(); + cde::cp_async_bulk_wait_group_read<7>(); + cde::cp_async_bulk_wait_group_read<8>(); +} + +int main(int, char**){ + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor.pass.cpp new file mode 100644 index 00000000000..fb58b583153 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor.pass.cpp @@ -0,0 +1,188 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) 
+ +// + +#include +#include // cuda::std::move +#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS + +// NVRTC does not support cuda.h (due to import of stdlib.h) +#ifndef TEST_COMPILER_NVRTC +#include // PFN_cuTensorMapEncodeTiled, CUtensorMap +#endif // !TEST_COMPILER_NVRTC + +// Suppress warning about barrier in shared memory +TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init) + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +constexpr size_t GMEM_WIDTH = 1024; // Width of tensor (in # elements) +constexpr size_t GMEM_HEIGHT = 1024; // Height of tensor (in # elements) +constexpr size_t gmem_len = GMEM_WIDTH * GMEM_HEIGHT; + +constexpr int SMEM_WIDTH = 32; // Width of shared memory buffer (in # elements) +constexpr int SMEM_HEIGHT = 8; // Height of shared memory buffer (in # elements) + +static constexpr int buf_len = SMEM_HEIGHT * SMEM_WIDTH; +__device__ int gmem_tensor[gmem_len]; + +// We need a type with a size. On NVRTC, cuda.h cannot be imported, so we don't +// have access to the definition of CUTensorMap (only to the declaration of CUtensorMap inside +// cuda/barrier). So we use this type instead and reinterpret_cast in the +// kernel. +struct fake_cutensormap { + alignas(64) uint64_t opaque[16]; +}; +__constant__ fake_cutensormap global_fake_tensor_map; + +__device__ void test(int base_i, int base_j) +{ + CUtensorMap *global_tensor_map = reinterpret_cast(&global_fake_tensor_map); + + // SETUP: fill global memory buffer + for (int i = threadIdx.x; i < gmem_len; i += blockDim.x) { + gmem_tensor[i] = i; + } + // Ensure that writes to global memory are visible to others, including + // those in the async proxy. + __threadfence(); + __syncthreads(); + + // TEST: Add i to buffer[i] + __shared__ alignas(128) int smem_buffer[buf_len]; + __shared__ barrier bar; + if (threadIdx.x == 0) { init(&bar, blockDim.x); } + __syncthreads(); + + // Load data: + uint64_t token; + if (threadIdx.x == 0) { + // Fastest moving coordinate first. 
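+    // The tensor map created on the host below describes a row-major
+    // (GMEM_WIDTH x GMEM_HEIGHT) tensor, so base_j (the offset along the
+    // contiguous width dimension) is passed before base_i (the row offset).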
+ cde::cp_async_bulk_tensor_2d_global_to_shared(smem_buffer, global_tensor_map, base_j, base_i, bar); + token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer)); + } else { + token = bar.arrive(); + } + bar.wait(cuda::std::move(token)); + + // Check smem + for (int i = 0; i < SMEM_HEIGHT; ++i) { + for (int j = 0; j < SMEM_HEIGHT; ++j) { + const int gmem_lin_idx = (base_i + i) * GMEM_WIDTH + base_j + j; + const int smem_lin_idx = i * SMEM_WIDTH + j; + + assert(smem_buffer[smem_lin_idx] == gmem_lin_idx); + } + } + + __syncthreads(); + + // Update smem + for (int i = threadIdx.x; i < buf_len; i += blockDim.x) { + smem_buffer[i] = 2 * smem_buffer[i] + 1; + } + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + + // Write back to global memory: + if (threadIdx.x == 0) { + cde::cp_async_bulk_tensor_2d_shared_to_global(global_tensor_map, base_j, base_i, smem_buffer); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + __threadfence(); + __syncthreads(); + + // TEAR-DOWN: check that global memory is correct + for (int i = 0; i < SMEM_HEIGHT; ++i) { + for (int j = 0; j < SMEM_HEIGHT; ++j) { + int gmem_lin_idx = (base_i + i) * GMEM_WIDTH + base_j + j; + + assert(gmem_tensor[gmem_lin_idx] == 2 * gmem_lin_idx + 1); + } + } + __syncthreads(); +} + +#ifndef TEST_COMPILER_NVRTC +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + auto code = cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status); + assert(code == cudaSuccess && "Could not get driver API"); + return reinterpret_cast(driver_ptr); +} +#endif // ! TEST_COMPILER_NVRTC + +int main(int, char**) +{ + NV_IF_TARGET(NV_IS_HOST,( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + + int * tensor_ptr = nullptr; + auto code = cudaGetSymbolAddress((void**)&tensor_ptr, gmem_tensor); + assert(code == cudaSuccess && "getsymboladdress failed."); + + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html + CUtensorMap local_tensor_map{}; + // rank is the number of dimensions of the array. + constexpr uint32_t rank = 2; + uint64_t size[rank] = {GMEM_WIDTH, GMEM_HEIGHT}; + // The stride is the number of bytes to traverse from the first element of one row to the next. + // It must be a multiple of 16. + uint64_t stride[rank - 1] = {GMEM_WIDTH * sizeof(int)}; + // The box_size is the size of the shared memory buffer that is used as the + // destination of a TMA transfer. + uint32_t box_size[rank] = {SMEM_WIDTH, SMEM_HEIGHT}; + // The distance between elements in units of sizeof(element). A stride of 2 + // can be used to load only the real component of a complex-valued tensor, for instance. + uint32_t elem_stride[rank] = {1, 1}; + + // Get a function pointer to the cuTensorMapEncodeTiled driver API. + auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); + + // Create the tensor descriptor. 
+ CUresult res = cuTensorMapEncodeTiled( + &local_tensor_map, // CUtensorMap *tensorMap, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT32, + rank, // cuuint32_t tensorRank, + tensor_ptr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + box_size, // const cuuint32_t *boxDim, + elem_stride, // const cuuint32_t *elementStrides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + + assert(res == CUDA_SUCCESS && "tensormap creation failed."); + code = cudaMemcpyToSymbol(global_fake_tensor_map, &local_tensor_map, sizeof(CUtensorMap)); + assert(code == cudaSuccess && "memcpytosymbol failed."); + )); + + NV_DISPATCH_TARGET( + NV_IS_DEVICE, ( + test(0, 0); + test(4, 0); + test(4, 4); + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_1d.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_1d.pass.cpp new file mode 100644 index 00000000000..4af32de2114 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_1d.pass.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) + +// + +#include "cp_async_bulk_tensor_generic.h" + +// Define the size of contiguous tensor in global and shared memory. +// +// Note that the first dimension is the one with stride 1. This one must be a +// multiple of 4 to ensure that each new dimension starts at a 16-byte aligned +// offset. +// +// We have a separate variable for host and device because a constexpr +// std::initializer_list cannot be shared between host and device as some of its +// member functions take a const reference, which is unsupported by nvcc. 
+ constexpr std::initializer_list GMEM_DIMS {256}; +__device__ constexpr std::initializer_list GMEM_DIMS_DEV{256}; + constexpr std::initializer_list SMEM_DIMS {32}; +__device__ constexpr std::initializer_list SMEM_DIMS_DEV{32}; + +__device__ constexpr std::initializer_list TEST_SMEM_COORDS[] = { + {0}, + {4}, + {8} +}; + +constexpr size_t gmem_len = tensor_len(GMEM_DIMS); +constexpr size_t smem_len = tensor_len(SMEM_DIMS); + +__device__ int gmem_tensor[gmem_len]; + +int main(int, char**) +{ + NV_DISPATCH_TARGET( + NV_IS_HOST, ( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + init_tensor_map(gmem_tensor, GMEM_DIMS, SMEM_DIMS); + ), + NV_IS_DEVICE, ( + for (auto smem_coord : TEST_SMEM_COORDS) { + test(smem_coord, SMEM_DIMS_DEV, GMEM_DIMS_DEV, gmem_tensor, gmem_len); + } + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_2d.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_2d.pass.cpp new file mode 100644 index 00000000000..be0c29f5eeb --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_2d.pass.cpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) + +// + +#include "cp_async_bulk_tensor_generic.h" + +// Define the size of contiguous tensor in global and shared memory. +// +// Note that the first dimension is the one with stride 1. This one must be a +// multiple of 4 to ensure that each new dimension starts at a 16-byte aligned +// offset. +// +// We have a separate variable for host and device because a constexpr +// std::initializer_list cannot be shared between host and device as some of its +// member functions take a const reference, which is unsupported by nvcc. 
+ constexpr std::initializer_list GMEM_DIMS {8, 11}; +__device__ constexpr std::initializer_list GMEM_DIMS_DEV{8, 11}; + constexpr std::initializer_list SMEM_DIMS {4, 2}; +__device__ constexpr std::initializer_list SMEM_DIMS_DEV{4, 2}; + +__device__ constexpr std::initializer_list TEST_SMEM_COORDS[] = { + {0, 0}, + {4, 1}, + {4, 5}, + {0, 5}, +}; + +constexpr size_t gmem_len = tensor_len(GMEM_DIMS); +constexpr size_t smem_len = tensor_len(SMEM_DIMS); + +__device__ int gmem_tensor[gmem_len]; + +int main(int, char**) +{ + NV_DISPATCH_TARGET( + NV_IS_HOST, ( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + init_tensor_map(gmem_tensor, GMEM_DIMS, SMEM_DIMS); + ), + NV_IS_DEVICE, ( + for (auto smem_coord : TEST_SMEM_COORDS) { + test(smem_coord, SMEM_DIMS_DEV, GMEM_DIMS_DEV, gmem_tensor, gmem_len); + } + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_3d.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_3d.pass.cpp new file mode 100644 index 00000000000..0b3a12f3539 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_3d.pass.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) + +// + +#include "cp_async_bulk_tensor_generic.h" + +// Define the size of contiguous tensor in global and shared memory. +// +// Note that the first dimension is the one with stride 1. This one must be a +// multiple of 4 to ensure that each new dimension starts at a 16-byte aligned +// offset. +// +// We have a separate variable for host and device because a constexpr +// std::initializer_list cannot be shared between host and device as some of its +// member functions take a const reference, which is unsupported by nvcc. 
+ constexpr std::initializer_list GMEM_DIMS {8, 11, 13}; +__device__ constexpr std::initializer_list GMEM_DIMS_DEV{8, 11, 13}; + constexpr std::initializer_list SMEM_DIMS {4, 2, 4}; +__device__ constexpr std::initializer_list SMEM_DIMS_DEV{4, 2, 4}; + +__device__ constexpr std::initializer_list TEST_SMEM_COORDS[] = { + {0, 0, 0}, + {4, 1, 3}, + {4, 5, 1} +}; + +constexpr size_t gmem_len = tensor_len(GMEM_DIMS); +constexpr size_t smem_len = tensor_len(SMEM_DIMS); + +__device__ int gmem_tensor[gmem_len]; + +int main(int, char**) +{ + NV_DISPATCH_TARGET( + NV_IS_HOST, ( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + init_tensor_map(gmem_tensor, GMEM_DIMS, SMEM_DIMS); + ), + NV_IS_DEVICE, ( + for (auto smem_coord : TEST_SMEM_COORDS) { + test(smem_coord, SMEM_DIMS_DEV, GMEM_DIMS_DEV, gmem_tensor, gmem_len); + } + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_4d.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_4d.pass.cpp new file mode 100644 index 00000000000..68371a45ca0 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_4d.pass.cpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) + +// + +#include "cp_async_bulk_tensor_generic.h" + +// Define the size of contiguous tensor in global and shared memory. +// +// Note that the first dimension is the one with stride 1. This one must be a +// multiple of 4 to ensure that each new dimension starts at a 16-byte aligned +// offset. +// +// We have a separate variable for host and device because a constexpr +// std::initializer_list cannot be shared between host and device as some of its +// member functions take a const reference, which is unsupported by nvcc. 
+ constexpr std::initializer_list GMEM_DIMS {8, 11, 13, 3}; +__device__ constexpr std::initializer_list GMEM_DIMS_DEV{8, 11, 13, 3}; + constexpr std::initializer_list SMEM_DIMS {4, 2, 4, 1}; +__device__ constexpr std::initializer_list SMEM_DIMS_DEV{4, 2, 4, 1}; + +__device__ constexpr std::initializer_list TEST_SMEM_COORDS[] = { + {0, 0, 0, 0}, + {4, 1, 3, 0}, + {4, 8, 7, 2}, + {4, 5, 1, 1} +}; + +constexpr size_t gmem_len = tensor_len(GMEM_DIMS); +constexpr size_t smem_len = tensor_len(SMEM_DIMS); + +__device__ int gmem_tensor[gmem_len]; + +int main(int, char**) +{ + NV_DISPATCH_TARGET( + NV_IS_HOST, ( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + init_tensor_map(gmem_tensor, GMEM_DIMS, SMEM_DIMS); + ), + NV_IS_DEVICE, ( + for (auto smem_coord : TEST_SMEM_COORDS) { + test(smem_coord, SMEM_DIMS_DEV, GMEM_DIMS_DEV, gmem_tensor, gmem_len); + } + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_5d.pass.cpp b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_5d.pass.cpp new file mode 100644 index 00000000000..cbf6141a0af --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_5d.pass.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 +// UNSUPPORTED: nvrtc +// NVRTC_SKIP_KERNEL_RUN // This will have effect once PR 433 is merged (line above should be removed.) + +// + +#include "cp_async_bulk_tensor_generic.h" + +// Define the size of contiguous tensor in global and shared memory. +// +// Note that the first dimension is the one with stride 1. This one must be a +// multiple of 4 to ensure that each new dimension starts at a 16-byte aligned +// offset. +// +// We have a separate variable for host and device because a constexpr +// std::initializer_list cannot be shared between host and device as some of its +// member functions take a const reference, which is unsupported by nvcc. 
+ constexpr std::initializer_list GMEM_DIMS {8, 11, 13, 3, 3}; +__device__ constexpr std::initializer_list GMEM_DIMS_DEV{8, 11, 13, 3, 3}; + constexpr std::initializer_list SMEM_DIMS {4, 2, 4, 1, 1}; +__device__ constexpr std::initializer_list SMEM_DIMS_DEV{4, 2, 4, 1, 1}; + +__device__ constexpr std::initializer_list TEST_SMEM_COORDS[] = { + {0, 0, 0, 0, 0}, + {4, 1, 3, 0, 1}, + {4, 5, 1, 1, 2} +}; + +constexpr size_t gmem_len = tensor_len(GMEM_DIMS); +constexpr size_t smem_len = tensor_len(SMEM_DIMS); + +__device__ int gmem_tensor[gmem_len]; + +int main(int, char**) +{ + NV_DISPATCH_TARGET( + NV_IS_HOST, ( + //Required by concurrent_agents_launch to know how many we're launching + cuda_thread_count = 512; + init_tensor_map(gmem_tensor, GMEM_DIMS, SMEM_DIMS); + ), + NV_IS_DEVICE, ( + for (auto smem_coord : TEST_SMEM_COORDS) { + test(smem_coord, SMEM_DIMS_DEV, GMEM_DIMS_DEV, gmem_tensor, gmem_len); + } + ) + ); + return 0; +} diff --git a/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_generic.h b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_generic.h new file mode 100644 index 00000000000..8d29fa6df11 --- /dev/null +++ b/libcudacxx/.upstream-tests/test/cuda/barrier/cp_async_bulk_tensor_generic.h @@ -0,0 +1,283 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// +// +// UNSUPPORTED: libcpp-has-no-threads +// UNSUPPORTED: pre-sm-90 + +// + +#ifndef TEST_CP_ASYNC_BULK_TENSOR_GENERIC_H_ +#define TEST_CP_ASYNC_BULK_TENSOR_GENERIC_H_ + +#include +#include // cuda::std::move +#include "test_macros.h" // TEST_NV_DIAG_SUPPRESS + +// NVRTC does not support cuda.h (due to import of stdlib.h) +#ifndef TEST_COMPILER_NVRTC +#include +#include // PFN_cuTensorMapEncodeTiled, CUtensorMap +#endif // ! TEST_COMPILER_NVRTC + +// Suppress warning about barrier in shared memory +TEST_NV_DIAG_SUPPRESS(static_var_with_dynamic_init) + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +/* + * This header supports the 1d, 2d, ..., 5d test of the TMA PTX wrappers. + * + * The functions below help convert Nd coordinates into something useful. + * + */ + +// Compute the total number of elements in a tensor +constexpr __host__ __device__ int tensor_len(std::initializer_list dims) { + int len = 1; + for (int d : dims) { + len *= d; + } + return len; +} + +// Function to convert: +// a linear index into a shared memory tensor +// into +// a linear index into a global memory tensor. 
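+//
+// Worked example (using the 2d test's sizes): with smem_dims = {4, 2},
+// gmem_dims = {8, 11} and smem_coord = {4, 1}, shared-memory linear index 5
+// is element (1, 1) of the tile, which sits at global coordinate
+// (4 + 1, 1 + 1) = (5, 2), i.e. global linear index 2 * 8 + 5 = 21.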
+inline __device__ +int smem_lin_idx_to_gmem_lin_idx( + int smem_lin_idx, + std::initializer_list smem_coord, + std::initializer_list smem_dims, + std::initializer_list gmem_dims) { + assert(smem_coord.size() == smem_dims.size()); + assert(smem_coord.size() == gmem_dims.size()); + + int gmem_lin_idx = 0; + int gmem_stride = 1; + for (int i = 0; i < (int) smem_coord.size(); ++i) { + int smem_i_idx = smem_lin_idx % smem_dims.begin()[i]; + gmem_lin_idx += (smem_coord.begin()[i] + smem_i_idx) * gmem_stride; + + smem_lin_idx /= smem_dims.begin()[i]; + gmem_stride *= gmem_dims.begin()[i]; + } + return gmem_lin_idx; +} + +__device__ inline void cp_tensor_global_to_shared( + CUtensorMap* tensor_map, + std::initializer_list indices, + void *smem, + barrier &bar) { + + const int* idxs = indices.begin(); + + switch (indices.size()) { + case 1: cde::cp_async_bulk_tensor_1d_global_to_shared(smem, tensor_map, idxs[0], bar); break; + case 2: cde::cp_async_bulk_tensor_2d_global_to_shared(smem, tensor_map, idxs[0], idxs[1], bar); break; + case 3: cde::cp_async_bulk_tensor_3d_global_to_shared(smem, tensor_map, idxs[0], idxs[1], idxs[2], bar); break; + case 4: cde::cp_async_bulk_tensor_4d_global_to_shared(smem, tensor_map, idxs[0], idxs[1], idxs[2], idxs[3], bar); break; + case 5: cde::cp_async_bulk_tensor_5d_global_to_shared(smem, tensor_map, idxs[0], idxs[1], idxs[2], idxs[3], idxs[4], bar); break; + default: + assert(false && "Wrong number of dimensions."); + } +} + +__device__ inline void cp_tensor_shared_to_global( + CUtensorMap* tensor_map, + std::initializer_list indices, + void *smem) { + + const int* idxs = indices.begin(); + + switch (indices.size()) { + case 1: cde::cp_async_bulk_tensor_1d_shared_to_global(tensor_map, idxs[0], smem); break; + case 2: cde::cp_async_bulk_tensor_2d_shared_to_global(tensor_map, idxs[0], idxs[1], smem); break; + case 3: cde::cp_async_bulk_tensor_3d_shared_to_global(tensor_map, idxs[0], idxs[1], idxs[2], smem); break; + case 4: cde::cp_async_bulk_tensor_4d_shared_to_global(tensor_map, idxs[0], idxs[1], idxs[2], idxs[3], smem); break; + case 5: cde::cp_async_bulk_tensor_5d_shared_to_global(tensor_map, idxs[0], idxs[1], idxs[2], idxs[3], idxs[4], smem); break; + default: + assert(false && "Wrong number of dimensions."); + } +} + +// To define a tensor map in constant memory, we need a type with a size. On +// NVRTC, cuda.h cannot be imported, so we don't have access to the definition +// of CUTensorMap (only to the declaration of CUtensorMap inside cuda/barrier). +// So we use this type instead and reinterpret_cast in the kernel. +struct fake_cutensormap { + alignas(64) uint64_t opaque[16]; +}; +__constant__ fake_cutensormap global_fake_tensor_map; + +/* + * This test has as primary purpose to make sure that the indices in the mapping + * from C++ to PTX didn't get mixed up. + * + * How does it test this? + * + * 1. It fills a global memory tensor with linear coordinates 0, 1, ... + * 2. It loads a tile into shared memory at some coordinate (x, y, ... ) + * 3. It checks that the coordinates that were received in shared memory match the expected. + * 4. It modifies the coordinates (c = 2 * c + 1) + * 5. It writes the tile back to global memory + * 6. It checks that all the values in global are properly modified. 
+ */ +template +__device__ void test(std::initializer_list smem_coord, + std::initializer_list smem_dims, + std::initializer_list gmem_dims, + int* gmem_tensor, + int gmem_len) +{ + CUtensorMap *global_tensor_map = reinterpret_cast(&global_fake_tensor_map); + + // SETUP: fill global memory buffer + for (int i = threadIdx.x; i < gmem_len; i += blockDim.x) { + gmem_tensor[i] = i; + } + // Ensure that writes to global memory are visible to others, including + // those in the async proxy. + __threadfence(); + __syncthreads(); + + // TEST: Add i to buffer[i] + __shared__ alignas(128) int smem_buffer[smem_len]; + __shared__ barrier bar; + if (threadIdx.x == 0) { init(&bar, blockDim.x); } + __syncthreads(); + + // Load data: + uint64_t token; + if (threadIdx.x == 0) { + // Fastest moving coordinate first. + cp_tensor_global_to_shared(global_tensor_map, smem_coord, smem_buffer, bar); + token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buffer)); + } else { + token = bar.arrive(); + } + bar.wait(cuda::std::move(token)); + + // Check smem + for (int i = threadIdx.x; i < smem_len; i += blockDim.x) { + int gmem_lin_idx = smem_lin_idx_to_gmem_lin_idx(i, smem_coord, smem_dims, gmem_dims); + assert(smem_buffer[i] == gmem_lin_idx); + } + + __syncthreads(); + + // Update smem + for (int i = threadIdx.x; i < smem_len; i += blockDim.x) { + smem_buffer[i] = 2 * smem_buffer[i] + 1; + } + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + + // Write back to global memory: + if (threadIdx.x == 0) { + cp_tensor_shared_to_global(global_tensor_map, smem_coord, smem_buffer); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + __threadfence(); + __syncthreads(); + + // // TEAR-DOWN: check that global memory is correct + for (int i = threadIdx.x; i < smem_len; i += blockDim.x) { + int gmem_lin_idx = smem_lin_idx_to_gmem_lin_idx(i, smem_coord, smem_dims, gmem_dims); + + assert(gmem_tensor[gmem_lin_idx] == 2 * gmem_lin_idx + 1); + } + __syncthreads(); +} + +#ifndef TEST_COMPILER_NVRTC +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + auto code = cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status); + assert(code == cudaSuccess && "Could not get driver API"); + return reinterpret_cast(driver_ptr); +} +#endif + +#ifndef TEST_COMPILER_NVRTC +template +CUtensorMap map_encode(T *tensor_ptr, std::initializer_list gmem_dims, std::initializer_list smem_dims) { + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html + CUtensorMap tensor_map{}; + assert(gmem_dims.size() == smem_dims.size()); + // rank is the number of dimensions of the array. + int rank = gmem_dims.size(); + uint64_t size[rank]; + for (int i = 0; i < rank; ++i) { + size[i] = gmem_dims.begin()[i]; + } + // The stride is the number of bytes to traverse from the first element of one row to the next. + // It must be a multiple of 16. + uint64_t stride[rank - 1]; + int base_stride = sizeof(T); + for (int i = 0; i < rank - 1; ++i) { + base_stride *= gmem_dims.begin()[i]; + stride[i] = base_stride; + } + // The box_size is the size of the shared memory buffer that is used as the + // destination of a TMA transfer. Casting from int -> uint32_t. + const uint32_t *box_size = reinterpret_cast(smem_dims.begin()); + + // The distance between elements in units of sizeof(element). 
A stride of 2 + // can be used to load only the real component of a complex-valued tensor, for instance. + uint32_t elem_stride[rank]; // = {1, .., 1}; + for (int i = 0; i < rank; ++i) { + elem_stride[i] = 1; + } + + // Get a function pointer to the cuTensorMapEncodeTiled driver API. + auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); + + // Create the tensor descriptor. + CUresult res = cuTensorMapEncodeTiled( + &tensor_map, // CUtensorMap *tensorMap, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT32, + rank, // cuuint32_t tensorRank, + tensor_ptr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + box_size, // const cuuint32_t *boxDim, + elem_stride, // const cuuint32_t *elementStrides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + assert(res == CUDA_SUCCESS && "tensormap creation failed."); + + return tensor_map; +} + +template +void init_tensor_map(const T& gmem_tensor_symbol, std::initializer_list gmem_dims, std::initializer_list smem_dims) { + // Get pointer to gmem_tensor to create tensor map. + int * tensor_ptr = nullptr; + auto code = cudaGetSymbolAddress((void**)&tensor_ptr, gmem_tensor_symbol); + assert(code == cudaSuccess && "Could not get symbol address."); + + // Create tensor map + CUtensorMap local_tensor_map = map_encode(tensor_ptr, gmem_dims, smem_dims); + + // Copy it to device + code = cudaMemcpyToSymbol(global_fake_tensor_map, &local_tensor_map, sizeof(CUtensorMap)); + assert(code == cudaSuccess && "Could not copy symbol to device."); +} +#endif // ! TEST_COMPILER_NVRTC + +#endif // TEST_CP_ASYNC_BULK_TENSOR_GENERIC_H_ diff --git a/libcudacxx/include/cuda/barrier b/libcudacxx/include/cuda/barrier index 7e57cd585f8..15182816c0a 100644 --- a/libcudacxx/include/cuda/barrier +++ b/libcudacxx/include/cuda/barrier @@ -13,4 +13,273 @@ #include "std/barrier" +// Forward-declare CUtensorMap for use in cp_async_bulk_tensor_* PTX wrapping +// functions. These functions take a pointer to CUtensorMap, so do not need to +// know its size. This type is defined in cuda.h (driver API) as: +// +// typedef struct CUtensorMap_st { [ .. snip .. ] } CUtensorMap; +// +// We need to forward-declare both CUtensorMap_st (the struct) and CUtensorMap +// (the typedef): +struct CUtensorMap_st; +typedef struct CUtensorMap_st CUtensorMap; + +_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL + +// Experimental exposure of TMA PTX: +// +// - cp_async_bulk_global_to_shared +// - cp_async_bulk_shared_to_global +// - cp_async_bulk_tensor_{1,2,3,4,5}d_global_to_shared +// - cp_async_bulk_tensor_{1,2,3,4,5}d_shared_to_global +// - fence_proxy_async_shared_cta +// - cp_async_bulk_commit_group +// - cp_async_bulk_wait_group_read<0, …, 7> + +// These PTX wrappers are only available when the code is compiled compute +// capability 9.0 and above. The check for (!defined(__CUDA_MINIMUM_ARCH__)) is +// necessary to prevent cudafe from ripping out the device functions before +// device compilation begins. 
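+//
+// A minimal usage sketch for the tensor variants (mirroring
+// cp_async_bulk_tensor.pass.cpp in the tests; here `cde` is assumed to be an
+// alias for cuda::device::experimental, `map` a pointer to a CUtensorMap
+// created on the host with cuTensorMapEncodeTiled, `x0`/`y0` the tile's
+// starting coordinates with the fastest-moving coordinate first, and `bar`
+// an initialized cuda::barrier<cuda::thread_scope_block> in shared memory):
+//
+//   cde::cp_async_bulk_tensor_2d_global_to_shared(smem, map, x0, y0, bar);
+//   auto token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem));
+//   bar.wait(cuda::std::move(token));
+//   // ... read/modify the tile in smem ...
+//   cde::fence_proxy_async_shared_cta();
+//   __syncthreads();
+//   cde::cp_async_bulk_tensor_2d_shared_to_global(map, x0, y0, smem);
+//   cde::cp_async_bulk_commit_group();
+//   cde::cp_async_bulk_wait_group_read<0>();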
+#if (!defined(__CUDA_MINIMUM_ARCH__)) || (defined(__CUDA_MINIMUM_ARCH__) && 900 <= __CUDA_MINIMUM_ARCH__) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_global_to_shared(void *__dest, const void *__src, _CUDA_VSTD::uint32_t __size, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + _LIBCUDACXX_DEBUG_ASSERT(__size % 16 == 0, "Size must be multiple of 16."); + _LIBCUDACXX_DEBUG_ASSERT(__isShared(__dest), "Destination must be shared memory address."); + _LIBCUDACXX_DEBUG_ASSERT(__isGlobal(__src), "Source must be global memory address."); + + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__src))), + "r"(__size), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_shared_to_global(void *__dest, const void * __src, _CUDA_VSTD::uint32_t __size) +{ + _LIBCUDACXX_DEBUG_ASSERT(__size % 16 == 0, "Size must be multiple of 16."); + _LIBCUDACXX_DEBUG_ASSERT(__isGlobal(__dest), "Destination must be global memory address."); + _LIBCUDACXX_DEBUG_ASSERT(__isShared(__src), "Source must be shared memory address."); + + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__dest))), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))), + "r"(__size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_1d_global_to_shared( + void *__dest, const CUtensorMap *__tensor_map , int __c0, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + asm volatile( + "cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2}], [%3];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(__tensor_map), + "r"(__c0), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_2d_global_to_shared( + void *__dest, const CUtensorMap *__tensor_map , int __c0, int __c1, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3}], [%4];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_3d_global_to_shared( + void *__dest, const 
CUtensorMap *__tensor_map, int __c0, int __c1, int __c2, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + asm volatile( + "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3, %4}], [%5];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_4d_global_to_shared( + void *__dest, const CUtensorMap *__tensor_map , int __c0, int __c1, int __c2, int __c3, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3, %4, %5}], [%6];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(__c3), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_5d_global_to_shared( + void *__dest, const CUtensorMap *__tensor_map , int __c0, int __c1, int __c2, int __c3, int __c4, ::cuda::barrier<::cuda::thread_scope_block> &__bar) +{ + asm volatile( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3, %4, %5, %6}], [%7];\n" + : + : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))), + "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(__c3), + "r"(__c4), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_1d_shared_to_global( + const CUtensorMap *__tensor_map, int __c0, const void *__src) +{ + asm volatile( + "cp.async.bulk.tensor.1d.global.shared::cta.tile.bulk_group " + "[%0, {%1}], [%2];\n" + : + : "l"(__tensor_map), + "r"(__c0), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_2d_shared_to_global( + const CUtensorMap *__tensor_map, int __c0, int __c1, const void *__src) +{ + asm volatile( + "cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group " + "[%0, {%1, %2}], [%3];\n" + : + : "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_3d_shared_to_global( + const CUtensorMap *__tensor_map, int __c0, int __c1, int __c2, const void *__src) +{ + asm volatile( + "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group " + 
"[%0, {%1, %2, %3}], [%4];\n" + : + : "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *__tensor_map, int __c0, int __c1, int __c2, int __c3, const void *__src) +{ + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group " + "[%0, {%1, %2, %3, %4}], [%5];\n" + : + : "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(__c3), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_tensor_5d_shared_to_global( + const CUtensorMap *__tensor_map, int __c0, int __c1, int __c2, int __c3, int __c4, const void *__src) +{ + asm volatile( + "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group " + "[%0, {%1, %2, %3, %4, %5}], [%6];\n" + : + : "l"(__tensor_map), + "r"(__c0), + "r"(__c1), + "r"(__c2), + "r"(__c3), + "r"(__c4), + "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar +inline _LIBCUDACXX_DEVICE +void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta; \n":::"memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_commit_group() +{ + asm volatile("cp.async.bulk.commit_group;\n" ::: "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +inline _LIBCUDACXX_DEVICE +void cp_async_bulk_wait_group_read() +{ + static_assert(n_prior <= 63, "cp_async_bulk_wait_group_read: waiting for more than 63 groups is not supported."); + asm volatile("cp.async.bulk.wait_group.read %0; \n" + : + : "n"(n_prior) + : "memory"); +} + +#endif // __CUDA_MINIMUM_ARCH__ + +_LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL + #endif // _CUDA_BARRIER diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/__config b/libcudacxx/include/cuda/std/detail/libcxx/include/__config index 8284eeb10b1..d0bd126cf4e 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/__config +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/__config @@ -1482,6 +1482,8 @@ typedef __char32_t char32_t; #define _LIBCUDACXX_END_NAMESPACE_CUDA } } #define _LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE namespace cuda { namespace device { inline namespace _LIBCUDACXX_ABI_NAMESPACE { #define _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE } } } +#define _LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL namespace cuda { namespace device { namespace experimental { inline namespace _LIBCUDACXX_ABI_NAMESPACE { +#define _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL } } } } #endif // Inline namespaces are available in Clang/GCC/MSVC regardless of C++ dialect. 
diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/version b/libcudacxx/include/cuda/std/detail/libcxx/include/version index f52f38ccdc0..d6b4b45fdba 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/version +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/version @@ -226,6 +226,7 @@ __cpp_lib_void_t 201411L // hardware supports it. #if (!defined(__CUDA_MINIMUM_ARCH__)) || (defined(__CUDA_MINIMUM_ARCH__) && 900 <= __CUDA_MINIMUM_ARCH__) # define __cccl_lib_local_barrier_arrive_tx +# define __cccl_lib_experimental_ctk12_cp_async_exposure #endif // We unconditionally define `__cccl_lib_meow` so that there is only one place to set the value
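With the feature-test macro in place, downstream code can detect the experimental exposure from <cuda/barrier> alone. A minimal consumer-side sketch follows (hypothetical device code that mirrors cp_async_bulk.pass.cpp above; the kernel name, buffer size, and variable names are illustrative, and it assumes compilation for sm_90 with a transfer size that is a multiple of 16 bytes):

    #include <cuda/barrier>
    #include <cuda/std/utility> // cuda::std::move

    #ifdef __cccl_lib_experimental_ctk12_cp_async_exposure
    namespace cde = cuda::device::experimental;
    using barrier = cuda::barrier<cuda::thread_scope_block>;

    __global__ void round_trip(int* gmem) {
      __shared__ alignas(16) int smem[256];
      __shared__ barrier bar;
      if (threadIdx.x == 0) { init(&bar, blockDim.x); }
      __syncthreads();

      // Bulk load: one thread issues the copy and accounts for the expected
      // transaction bytes; all other threads simply arrive.
      barrier::arrival_token token;
      if (threadIdx.x == 0) {
        cde::cp_async_bulk_global_to_shared(smem, gmem, sizeof(smem), bar);
        token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem));
      } else {
        token = bar.arrive();
      }
      bar.wait(cuda::std::move(token));

      // ... use or modify smem ...

      // Make shared-memory writes visible to the async proxy, then bulk store.
      cde::fence_proxy_async_shared_cta();
      __syncthreads();
      if (threadIdx.x == 0) {
        cde::cp_async_bulk_shared_to_global(gmem, smem, sizeof(smem));
        cde::cp_async_bulk_commit_group();
        cde::cp_async_bulk_wait_group_read<0>();
      }
    }
    #endif // __cccl_lib_experimental_ctk12_cp_async_exposure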