@@ -0,0 +1,368 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "387b6083",
"metadata": {},
"source": [
"# NVSHMEM4Py Device APIs\n",
"\n",
"## Introduction\n",
"\n",
"This notebook documents the NVSHMEM4Py device API for use in Numba CUDA kernels. It is a continuation of `Nvshmem4py` notebook.\n",
"\n",
"## Environment\n",
"\n",
"NVSHMEM4Py Numba-CUDA device APIs require the additional dependency `Numba-CUDA`, with a matching CUDA API version on your machine. Assuming you are using CUDA 12:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5cf4c36",
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!pip install mpi4py nvshmem4py-cu12 cupy-cuda13x numba-cuda[cu12]==0.20.1 cuda-core"
]
},
{
"cell_type": "markdown",
"id": "e0d3279c",
"metadata": {},
"source": [
"## NVSHMEM4Py Numba-CUDA Device API Overview\n",
"\n",
"Device APIs allow developers to write GPU-initiated, one-sided operations from `@numba.cuda.jit` kernels. Users who need fine-grained, low-latency inter-GPU communication entirely from device code are encouraged to use them. These APIs are available via the `nvshmem.device.numba` namespace.\n",
"\n",
"### Features\n",
"\n",
"- Querying\n",
"- Remote Memory Access (RMA)\n",
"- Signal Operations\n",
"- Atomic Memory Operations (AMO)\n",
"- Collectives\n",
"- Synchronization\n",
"- Memory Mapping (direct device loads/stores)\n",
"\n",
"### Pythonic Interface\n",
"\n",
"Unlike the C/C++ APIs, the Numba device APIs are Pythonic in that they accept the `numba.types.Array` type. Certain APIs (such as RMA) omit the transfer element size, as it is deduced from the input array size. Users who need to specify a transfer size should slice the input array to create a view before passing it as an argument. The APIs are data type-aware, so users only need to ensure that the operand arrays have the same data type.\n",
"\n",
"### Thread Scope Variants\n",
"\n",
"Most of these APIs provide `_warp` or `_block` variants, which provide different levels of thread granularity. For example, the `put` API has `put_block` and `put_warp` variants. When used, all threads within the designated scope must receive the same arguments. They are frequently used for these purposes:\n",
"\n",
"- Using `put_block` instead of `put` allows the GPU to copy data with all threads of the same block in parallel if two PEs are connected via a point-to-point connection. If they are connected via a remote connection, only a single GPU thread is used to initialize the copy instruction.\n",
"- The `block` and `warp` variants of collectives allow threads of different granularity levels to perform reductions across PEs."
]
},
{
"cell_type": "markdown",
"id": "0dd374c0",
"metadata": {},
"source": [
"## Example: Ring-Allreduce in Python\n",
"\n",
"The ring-allreduce example performs an allreduce operation using a ring algorithm. The algorithm is separated into two phases: the reduction phase, and the broadcast phase. The example demonstrates use of device side APIs like `put_signal_nbi` and `signal_wait`.\n",
"\n",
"### Problem Setup\n",
"\n",
"Each PE has a local `src` data array initialized to `my_pe() + 1` to indicate its uniqueness. It also has an empty `dst` array of the same size as `src` to hold reduced data. Finally, there's an integral `signal` array to hold signals sent from other PEs.\n",
"\n",
"The following image shows the initial setup of elements on 4 PEs. Signal and chunking is not represented for simplicity.\n",
"\n",
"![NVSHMEM4Py Device API Overview](assets/1.png)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9d252b9",
"metadata": {},
"outputs": [],
"source": [
"from mpi4py import MPI\n",
"from cuda.core.experimental import Device, system, Stream\n",
"\n",
"from numba import cuda, uint64\n",
"\n",
"import nvshmem\n",
"import nvshmem.bindings\n",
"from nvshmem.core import SignalOp, ComparisonType\n",
"from nvshmem.core.device.numba import put_signal_nbi, signal_wait, my_pe, n_pes\n",
"\n",
"@cuda.jit\n",
"def ring_reduce(dst, src, nreduce, signal, chunk_size):\n",
" # Numba-CUDA constructs to setup thread-wise variables\n",
" mype = my_pe()\n",
" npes = n_pes()\n",
" peer = (mype + 1) % npes\n",
"\n",
" thread_id = cuda.threadIdx.x\n",
" num_threads = cuda.blockDim.x\n",
" num_blocks = cuda.gridDim.x\n",
" block_idx = cuda.blockIdx.x\n",
" elems_per_block = nreduce // num_blocks\n",
"\n",
" if elems_per_block * (block_idx + 1) > nreduce:\n",
" return"
]
},
{
"cell_type": "markdown",
"id": "035dfe05",
"metadata": {},
"source": [
"### Reduction Phase\n",
"\n",
"Initially, PE0 sends its local data to the next PE. Once finished, it increments PE1's signal flag by 1.\n",
"\n",
"![Reduction-1](images/chapter-nvshmem4py-device/2.png)\n",
"\n",
"Meanwhile, PE1 was waiting for an update to the signal flag. Once received, it indicates that PE0 has sent its data. It now performs a local compute.\n",
"\n",
"![Reduction-1](images/chapter-nvshmem4py-device/3.png)\n",
"\n",
"Once compute finishes, PE1 sends the data to the next PE. The next PE waits for the signal, and then performs local compute. It iterates to the last PE.\n",
"\n",
"![Reduction-1](images/chapter-nvshmem4py-device/4.png)\n",
"\n",
"On the last PE, once the compute finishes, it sends the result to PE0. Notice this time the result is already the final reduced result. PE0 is waiting for a signal to be updated. After receiving, it enters the broadcast phase.\n",
"\n",
"![Reduction-1](images/chapter-nvshmem4py-device/5.png)\n",
"\n",
"Each Cooperative Thread Array (CTA) handles a \"chunk\" of data within its assigned range for each iteration. Each chunk is handled independently from other chunks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24e1152e",
"metadata": {},
"outputs": [],
"source": [
" init_offset = block_idx * elems_per_block \n",
" signal_block = signal[block_idx:block_idx+1]\n",
" num_chunks = elems_per_block // chunk_size\n",
"\n",
" starts = range(init_offset, init_offset+elems_per_block, chunk_size)\n",
" ends = range(init_offset+chunk_size, init_offset+elems_per_block+chunk_size, chunk_size)\n",
" # Reduce phase\n",
" for chunk, (start, end) in enumerate(zip(starts, ends)):\n",
" src_block = src[start:end]\n",
" dst_block = dst[start:end]\n",
" if mype != 0:\n",
" if thread_id == 0:\n",
" signal_wait(signal_block, ComparisonType.CMP_GE, chunk + 1)\n",
" \n",
" cuda.syncthreads()\n",
" for i in range(thread_id, chunk_size, num_threads):\n",
" dst_block[i] = dst_block[i] + src_block[i]\n",
" cuda.syncthreads()\n",
" \n",
" if thread_id == 0:\n",
" src_data = src_block if mype == 0 else dst_block\n",
" put_signal_nbi(dst_block, src_data, \n",
" signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)"
]
},
{
"cell_type": "markdown",
"id": "cce6cdba",
"metadata": {},
"source": [
"### Broadcast Phase\n",
"\n",
"The broadcast phase is kicked off by the last PE's `put` instruction to PE0. Once PE0 receives the final result, a chain of `put`s are invoked following the PE order. Afterward, all PEs possess the final computed result.\n",
"\n",
"![Broadcast-1](images/chapter-nvshmem4py-device/6.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d7c3357",
"metadata": {},
"outputs": [],
"source": [
" # Broadcast phase\n",
" if thread_id == 0:\n",
" for chunk, (start, end) in enumerate(zip(starts, ends)):\n",
" dst_block = dst[start:end]\n",
" if mype < npes - 1: # Last pe already has the final result\n",
" expected_val = (chunk + 1) if mype == 0 else (num_chunks + chunk + 1)\n",
" signal_wait(signal_block, ComparisonType.CMP_GE, expected_val)\n",
" \n",
" if mype < npes - 2:\n",
" put_signal_nbi(dst_block, dst_block,\n",
" signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)"
]
},
{
"cell_type": "markdown",
"id": "150a02dc",
"metadata": {},
"source": [
"## Full Code Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "414db538",
"metadata": {},
"outputs": [],
"source": [
"from mpi4py import MPI\n",
"from cuda.core.experimental import Device, system, Stream\n",
"\n",
"from numba import cuda, uint64\n",
"\n",
"import nvshmem\n",
"import nvshmem.bindings\n",
"from nvshmem.core import SignalOp, ComparisonType\n",
"from nvshmem.core.device.numba import put_signal_nbi, signal_wait, my_pe, n_pes\n",
"\n",
"@cuda.jit\n",
"def ring_reduce(dst, src, nreduce, signal, chunk_size):\n",
" mype = my_pe()\n",
" npes = n_pes()\n",
" peer = (mype + 1) % npes\n",
"\n",
" thread_id = cuda.threadIdx.x\n",
" num_threads = cuda.blockDim.x\n",
" num_blocks = cuda.gridDim.x\n",
" block_idx = cuda.blockIdx.x\n",
" elems_per_block = nreduce // num_blocks\n",
"\n",
" # Change src, dst, nreduce, signal to what this block is going to process\n",
" # Each CTA will work independently\n",
" if elems_per_block * (block_idx + 1) > nreduce:\n",
" return\n",
" \n",
" # Adjust pointers for this block\n",
" init_offset = block_idx * elems_per_block\n",
" \n",
" signal_block = signal[block_idx:block_idx+1]\n",
"\n",
" num_chunks = elems_per_block // chunk_size\n",
"\n",
" starts = range(init_offset, init_offset+elems_per_block, chunk_size)\n",
" ends = range(init_offset+chunk_size, init_offset+elems_per_block+chunk_size, chunk_size)\n",
" # Reduce phase\n",
" for chunk, (start, end) in enumerate(zip(starts, ends)):\n",
" src_block = src[start:end]\n",
" dst_block = dst[start:end]\n",
" if mype != 0:\n",
" if thread_id == 0:\n",
" signal_wait(signal_block, ComparisonType.CMP_GE, chunk + 1)\n",
" \n",
" cuda.syncthreads()\n",
" for i in range(thread_id, chunk_size, num_threads):\n",
" dst_block[i] = dst_block[i] + src_block[i]\n",
" cuda.syncthreads()\n",
" \n",
" if thread_id == 0:\n",
" src_data = src_block if mype == 0 else dst_block\n",
" put_signal_nbi(dst_block, src_data, \n",
" signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)\n",
"\n",
" # if signal is printed here, it will be 0 for first and last PE, num_chunks for other PEs.\n",
"\n",
" # Broadcast phase\n",
" if thread_id == 0:\n",
" for chunk, (start, end) in enumerate(zip(starts, ends)):\n",
" dst_block = dst[start:end]\n",
" if mype < npes - 1: # Last pe already has the final result\n",
" expected_val = (chunk + 1) if mype == 0 else (num_chunks + chunk + 1)\n",
" signal_wait(signal_block, ComparisonType.CMP_GE, expected_val)\n",
" \n",
" if mype < npes - 2:\n",
" put_signal_nbi(dst_block, dst_block,\n",
" signal_block, uint64(1), SignalOp.SIGNAL_ADD, peer)\n",
" \n",
"\n",
"# Initialize MPI and NVSHMEM\n",
"local_rank_per_node = MPI.COMM_WORLD.Get_rank() % system.num_devices\n",
"dev = Device(local_rank_per_node)\n",
"dev.set_current()\n",
"\n",
"nb_stream = cuda.stream() # WAR: Numba-CUDA takes numba stream object or int\n",
"cu_stream_ref = Stream.from_handle(nb_stream.handle.value)\n",
"\n",
"nvshmem.core.init(\n",
" device=dev,\n",
" uid=None,\n",
" rank=None,\n",
" nranks=None,\n",
" mpi_comm=MPI.COMM_WORLD,\n",
" initializer_method=\"mpi\",\n",
")\n",
"\n",
"mype = nvshmem.bindings.my_pe()\n",
"npes = nvshmem.bindings.n_pes()\n",
"\n",
"# Test parameters\n",
"nreduce = 1024\n",
"\n",
"num_blocks = 32\n",
"elems_per_block = nreduce // num_blocks\n",
"num_chunk_per_block = 4\n",
"chunk_size = elems_per_block // num_chunk_per_block\n",
"\n",
"threads_per_block = 512 \n",
"\n",
"# Allocate arrays\n",
"src = nvshmem.core.array((nreduce,), dtype=\"int32\")\n",
"dst = nvshmem.core.array((nreduce,), dtype=\"int32\")\n",
"signal = nvshmem.core.array((num_blocks,), dtype=\"uint64\")\n",
"\n",
"# Initialize data\n",
"for i in range(nreduce):\n",
" src[i] = mype + 1\n",
"\n",
"dst[:] = 0\n",
"\n",
"# Initialize signal\n",
"for i in range(num_blocks):\n",
" signal[i] = 0\n",
"\n",
"# Launch kernel\n",
"ring_reduce[num_blocks, threads_per_block, nb_stream, 0](dst, src, nreduce, signal, chunk_size)\n",
"\n",
"nvshmem.core.barrier(nvshmem.core.Teams.TEAM_WORLD, stream=cu_stream_ref)\n",
"dev.sync()\n",
"\n",
"# Check results\n",
"expected_result = sum(range(1, npes + 1))\n",
"for i in range(nreduce):\n",
" assert dst[i] == expected_result, f\"PE {mype}: Mismatch at index {i}: got {dst[i]}, expected {expected_result}\"\n",
"print(f\"PE {mype}: Ring allreduce test passed\")\n",
"\n",
"# Clean up\n",
"nvshmem.core.free_array(src)\n",
"nvshmem.core.free_array(dst)\n",
"nvshmem.core.free_array(signal)\n",
"nvshmem.core.finalize()\n"
]
},
{
"cell_type": "markdown",
"id": "14ce7aaf",
"metadata": {},
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}