diff --git a/Hybrid-EP_Implementation.md b/Hybrid-EP_Implementation.md index 1ea07c89..fae9ec7a 100644 --- a/Hybrid-EP_Implementation.md +++ b/Hybrid-EP_Implementation.md @@ -192,16 +192,34 @@ export RDMA_CORE_HOME=/path/to/rdma-core # Path to your RDMA core installation export TORCH_ARCH_LIST="9.0;10.0" # Adjust based on your GPU architecture pip install . ``` + +> RDMA Core requirement: install `rdma-core` v60.0 ([reference](https://github.com/linux-rdma/rdma-core/tree/v60.0)), and the latest release is also recommended ([linux-rdma/rdma-core](https://github.com/linux-rdma/rdma-core.git)). +Example: +```bash +git clone https://github.com/linux-rdma/rdma-core.git +cd rdma-core +git checkout tags/v60.0 +sh build.sh +export RDMA_CORE_HOME=/path/to/rdma-core/build +``` -### Quick Start +Hybrid EP’s RDMA topology probing relies on `libnvidia-ml.so.1`. During Dockerfile builds, compile against the NVML stubs (for example, those shipped in `libnvidia-ml-dev`), then at runtime launch the container with `--gpus all` or a Kubernetes device plugin so that the NVIDIA container runtime injects the host’s real NVML library and prevents driver/library mismatches. -> **⚠️ Important Note for RDMA Inter-node Configuration** -> Currently, the RDMA inter-node kernel implementation requires manual specification of nic names for each GPU. You need to provide the mapping between GPUs and their corresponding IB device names via the `--ib-dev-name-list` parameter. See `tests/test_hybrid_ep.py` for detailed usage examples. -> In addition, when using the RDMA part, after setting num-tokens-per-rank during initialization, all subsequent communications must use the same value. Currently, dynamic sequence length is not supported. -> -> **Automatic topology detection will be supported soon.** -> **Dynamic sequence length will be supported soon.** +Example: +```bash +RUN apt-get update && \ + apt-get install -y --no-install-recommends libnvidia-ml-dev +RUN git clone -b hybrid_ep https://github.com/deepseek-ai/DeepEP.git +ENV HYBRID_EP_MULTINODE=1 +RUN cd DeepEP && \ + TORCH_CUDA_ARCH_LIST="9.0 10.0" MAX_JOBS=8 pip install --no-build-isolation . && \ + apt-get purge -y libnvidia-ml-dev && \ + apt-get autoremove -y && \ + rm -rf /var/lib/apt/lists/* +``` + +### Quick Start Refer to `tests/test_hybrid_ep.py` for comprehensive usage examples including: - Multi-node configuration @@ -209,6 +227,16 @@ Refer to `tests/test_hybrid_ep.py` for comprehensive usage examples including: - Inter-node testing scenarios - Performance benchmarking setups +**Explicitly configure `num_of_hybrid_ep_ranks_per_nvlink_domain` (default 8, representing the number of Hybrid-EP ranks that participate in the same Hybrid-EP communication within a single NVLink domain, this value is critical for MNNVL case) and `USE_MNNVL` (default disabled/False) either via uppercase environment variables or by passing arguments to `HybridEPBuffer.__init__`. In multi-node NVLink deployments you must enable `USE_MNNVL=1`.** + +Example configuration on EP64, MNNVL: +- Environment variables: + ``` + export NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=64 + export USE_MNNVL=1 + ``` +- Python init: `HybridEPBuffer(..., num_of_hybrid_ep_ranks_per_nvlink_domain=64, use_mnnvl=True)` + ### Important Configuration Note Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can modify these parameters via `HybridEPBuffer.init_config()` or by setting proper environment variables (see `deep_ep/hybrid_ep_buffer.py`) to achieve better performance/usability: @@ -264,7 +292,6 @@ Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can mo - Comprehensive performance improvements ### 🚧 Upcoming Features -- **Automatic Topology Detection**: Automatic detection of GPU-NIC mapping for RDMA inter-node communication, eliminating the need for manual `--ib-dev-name-list` configuration - **Low Latency Mode**: Enhanced performance for latency-critical workloads - Performance optimization diff --git a/csrc/hybrid_ep/backend/NCCL_LICENSE.txt b/csrc/hybrid_ep/backend/NCCL_LICENSE.txt new file mode 100644 index 00000000..d9de63aa --- /dev/null +++ b/csrc/hybrid_ep/backend/NCCL_LICENSE.txt @@ -0,0 +1,38 @@ + Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National + Laboratory, the U.S. Department of Energy, nor the names of their + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + The U.S. Department of Energy funded the development of this software + under subcontract 7078610 with Lawrence Berkeley National Laboratory. + + +This code also includes files from the NVIDIA Tools Extension SDK project. + +See: + + https://github.com/NVIDIA/NVTX + +for more information and license details. diff --git a/csrc/hybrid_ep/backend/hybrid_ep_backend.cuh b/csrc/hybrid_ep/backend/hybrid_ep_backend.cuh index 87df22e9..96c61e3b 100644 --- a/csrc/hybrid_ep/backend/hybrid_ep_backend.cuh +++ b/csrc/hybrid_ep/backend/hybrid_ep_backend.cuh @@ -92,6 +92,8 @@ template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. @@ -100,6 +102,10 @@ struct dispatch_kernel_dynamic_shared_memory_buffer_tconsumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; // Shared memory mr info for dispatch. (Mr info can be cached in shared memory, while qp info can't be cached.) alignas(8) dispatch_memory_region_info_t dispatch_memory_region_info[NUM_OF_NODES - 1]; // Num of tx messages. @@ -115,12 +121,18 @@ template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; // Shared memory attn_to_rdma_map buffer, Should be 16B alignment. alignas(16) bool attn_to_rdma_map_buffer[NUM_OF_TOKENS_PER_CHUNK * (NUM_OF_NODES - 1)]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; // Shared memory mr info for dispatch. (Mr info can be cached in shared memory, while qp info can't be cached.) alignas(8) dispatch_memory_region_info_t dispatch_memory_region_info[NUM_OF_NODES - 1]; // Num of tx messages. @@ -136,12 +148,18 @@ template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_scaling_factor_buffer[NUM_OF_STAGES][HIDDEN_DIM / 128]; // Shared memory attn_to_rdma_map buffer, Should be 16B alignment. alignas(16) bool attn_to_rdma_map_buffer[NUM_OF_TOKENS_PER_CHUNK * (NUM_OF_NODES - 1)]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; // Shared memory mr info for dispatch. (Mr info can be cached in shared memory, while qp info can't be cached.) alignas(8) dispatch_memory_region_info_t dispatch_memory_region_info[NUM_OF_NODES - 1]; // Num of tx messages. @@ -157,10 +175,16 @@ template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory attn_to_rdma_map buffer, Should be 16B alignment. alignas(16) bool attn_to_rdma_map_buffer[NUM_OF_TOKENS_PER_CHUNK * (NUM_OF_NODES - 1)]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; // Shared memory mr info for dispatch. (Mr info can be cached in shared memory, while qp info can't be cached.) alignas(8) dispatch_memory_region_info_t dispatch_memory_region_info[NUM_OF_NODES - 1]; // Num of tx messages. @@ -176,12 +200,18 @@ template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_scaling_factor_buffer[NUM_OF_STAGES][HIDDEN_DIM / 128]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; }; template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; }; template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. alignas(16) float intra_node_scaling_factor_buffer[NUM_OF_STAGES][HIDDEN_DIM / 128]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; }; template{ // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory ping-pong buffer for sparse_to_dense map for token data chunks. Should be 128B alignment for optimal perf for TMA. + alignas(128) int32_t sparse_to_dense_map_buffer[2][NUM_OF_TOKENS_PER_CHUNK][NUM_OF_RANKS_PER_NODE]; // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; + // Shared memory mbarrier that protect sparse_to_dense map. Should be 8B alignment(natural alignment). + alignas(8) uint64_t sparse_to_dense_map_mbarrier_buffer[2]; + // Shared memory mbarrier that perform sync within S2G warp group. Should be 8B alignment(natural alignment). + alignas(8) uint64_t S2G_group_mbarrier_buffer; }; template= 2 for inter-node reduction warp group, +// RDMA warp group currently only contains 1 warp so does not use named bar yet, if it need to use, it should use 2 + NUM_OF_DATA_PIPELINE_PER_BLOCK. inline __device__ void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id = 0) { asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); } @@ -455,8 +498,8 @@ template( - doca_gpu_dev_verbs_qp_get_cq_sq(qp), - smem_inter_node_num_of_write_per_node_ptr[INTER_NODE_GROUP::thread_rank()]); - assert(status >= 0); + uint32_t wc_num_to_poll = smem_inter_node_num_of_write_per_node_ptr[INTER_NODE_GROUP::thread_rank()]; + if (wc_num_to_poll > 0) { + int status = doca_gpu_dev_verbs_poll_cq( + doca_gpu_dev_verbs_qp_get_cq_sq(qp), wc_num_to_poll); + assert(status >= 0); + } } } #endif // Device function for intra-node G2S warp for dispatch kernel. There can be only 1 intra-node G2S warp per CUDA block! -template(&smem_buffer_ptr->sparse_to_dense_map_buffer[sparse_to_dense_map_stage][0][0]), + reinterpret_cast(sparse_to_dense_map_load_base_addr), + (uint32_t)(current_chunk_size * NUM_OF_RANKS_PER_NODE * sizeof(int32_t)), + &smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[sparse_to_dense_map_stage]); + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[sparse_to_dense_map_stage], + (uint32_t)(current_chunk_size * NUM_OF_RANKS_PER_NODE * sizeof(int32_t))); + } + } // Loop through all data chunk. Data(chunk) parallel between multiple CUDA blocks. for(int i = blockIdx.x; i < num_of_chunks_per_rank; i += NUM_OF_BLOCKS){ // How many rdma_to_attn load iter for this chunk. @@ -861,15 +957,61 @@ inline __device__ void S2G_warp_group_device_function(const int local_rank, current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; } for(int j = 0; j < NUM_OF_NODES; j++){ + // All S2G warps(threads) need to sync to make sure all of them have finished consuming the sparse_to_dense map for the last chunk before prefetching the sparse_to_dense map for next chunk. + // Equal to arrive_and_wait. But arrive_and_wait can only used for whole warps. + uint64_t state_token = cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->S2G_group_mbarrier_buffer); + while(!cuda::ptx::mbarrier_try_wait(&smem_buffer_ptr->S2G_group_mbarrier_buffer, state_token)){} + + // First warp(thread) will prefetch sparse_to_dense map for next chunk. + if(INTRA_NODE_S2G_GROUP::warp_rank() == 0){ + // Calculate next chunk id for this CUDA block to prefetch sparse_to_dense map for next chunk. + int next_chunk_id; + int next_node_id; + int next_node_iter = j + 1; + if(next_node_iter < NUM_OF_NODES){ + next_chunk_id = i; + next_node_id = node_rank >= next_node_iter ? node_rank - next_node_iter : node_rank + NUM_OF_NODES - next_node_iter; + }else{ + next_chunk_id = i + NUM_OF_BLOCKS; + next_node_id = node_rank; + } + + // If next chunk exist, load the sparse_to_dense map for next chunk. + if(next_chunk_id < num_of_chunks_per_rank){ + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && next_chunk_id == num_of_chunks_per_rank - 1){ + current_chunk_size = remainder_chunk_size; + }else{ + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + // sparse_to_dense map load base addr. + const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (next_node_id * num_of_tokens_per_rank + next_chunk_id * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; + // Load the sparse_to_dense map for the next chunk. + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->sparse_to_dense_map_buffer[sparse_to_dense_map_stage ^ 1][0][0]), + reinterpret_cast(sparse_to_dense_map_load_base_addr), + (uint32_t)(current_chunk_size * NUM_OF_RANKS_PER_NODE * sizeof(int32_t)), + &smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[sparse_to_dense_map_stage ^ 1]); + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[sparse_to_dense_map_stage ^ 1], + (uint32_t)(current_chunk_size * NUM_OF_RANKS_PER_NODE * sizeof(int32_t))); + } + } + // The current node been processed. For each chunk id, node_id order is local_node, local_node - 1, local_node - 2, ......, local_node + 1 and will wrap around. int node_id = node_rank >= j ? node_rank - j : node_rank + NUM_OF_NODES - j; // Store every token and its properties from Shared to Global. Only store tokens that is needed by this node. const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + (node_id * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); - const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (node_id * num_of_tokens_per_rank + i * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; + // Wait for sparse_to_dense map ready in smem for current chunk. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[sparse_to_dense_map_stage], sparse_to_dense_map_parity)){} - //#pragma unroll for(int k = 0; k < num_of_routing_info_load_iter_for_current_chunk; k++){ rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[k]; #pragma unroll @@ -882,14 +1024,14 @@ inline __device__ void S2G_warp_group_device_function(const int local_rank, bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + n); if(token_needed_by_this_node){ const sparse_to_dense_map_load_t* sparse_to_dense_map_load_addr = reinterpret_cast - (sparse_to_dense_map_load_base_addr + (k * NUM_OF_TOKENS_PER_LOAD_ITER + n) * NUM_OF_RANKS_PER_NODE); + (&smem_buffer_ptr->sparse_to_dense_map_buffer[sparse_to_dense_map_stage][k * NUM_OF_TOKENS_PER_LOAD_ITER + n][0]); // Wait until token entry within the shared memory has been produced. while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0], producer_parity)){} // This token entry will be multicast to all ranks within this node which need this token and its properties. // The current implementation do the multicast by issue each unicast separately(we call it a unicast group). If NVLS can be used, we should use it here. - #pragma unroll - for(int m = 0; m < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_INPUT_TOKEN; m++){ + // Multicast of a src token will be ditributed to multiple S2G threads. + for(int m = INTRA_NODE_S2G_GROUP::warp_rank(); m < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_INPUT_TOKEN; m += INTRA_NODE_S2G_GROUP::warp_size()){ // Load sparse_to_dense_map. sparse_to_dense_map_load_t sparse_to_dense_map_data = sparse_to_dense_map_load_addr[m]; #pragma unroll @@ -932,10 +1074,19 @@ inline __device__ void S2G_warp_group_device_function(const int local_rank, } // Commit the previous issued S2G TMA instructions for the same shared memory token entry to a bulk async copy group. cuda::ptx::cp_async_bulk_commit_group(); - // Wait for previous commited TMA instructions to finish reading the shared memory, so the shared memory can be reused by the producer. - cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>{}); - // Notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. - cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_mbarrier_buffer[stage][1]); + // Add 1 more in-flight S2G token entry to the counter. + in_flight_s2g += 1; + // If in-flight S2G token entry count has exceeded the expectation, release the 1 oldest token entry for the producer. + if(in_flight_s2g > NUM_OF_IN_FLIGHT_S2G){ + // Wait for all TMA S2G instructions for the 1 oldest token entry to finish reading the shared memory, so the token entry can be reused by the producer. + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t{}); + // Reduce 1 in-flight S2G token entry from the counter. + in_flight_s2g -= 1; + // Notify the producer warp to load next token entry to the oldest token entry as the shared memory can be reused. + int notify_stage = (stage - NUM_OF_IN_FLIGHT_S2G) >= 0 ? (stage - NUM_OF_IN_FLIGHT_S2G) : (stage - NUM_OF_IN_FLIGHT_S2G + NUM_OF_STAGES); + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_mbarrier_buffer[notify_stage][1]); + } + // Goto next token entry in shared memory. stage += 1; if(stage == NUM_OF_STAGES){ @@ -945,38 +1096,14 @@ inline __device__ void S2G_warp_group_device_function(const int local_rank, } } } - } - } - // All S2G TMA operations for all tokens assigned to this CUDA block have been issued. - // If the synchronization for output buffer for current rank is on host-side(i.e. cudaStreamSynchronize + MPI_Barrier etc.), then all CUDA block can exit. - // The result of output buffer for current rank is not ready when the dipatch kernel is completed, a Barrier within the node is needed. - // Otherwise, the S2G warp of the first CUDA block must wait for all writes to the local output buffer complete before exit. So kernel completion means the output buffers for current rank is ready. - /*if constexpr(DEVICE_SIDE_SYNC){ - // Wait for all previous issued TMA instructions to complete writing to remote global memory. - cuda::ptx::cp_async_bulk_wait_group(cuda::ptx::n32_t<0>{}); - // Atomically add 1 to the remote flag on remote ranks within the node to notify the remote rank. - for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ - // red.release.sys.global.add.u32 [a], 1; - asm volatile("red.release.sys.global.add.u32 [%0], %1;" - : - : "l"(__cvta_generic_to_global(&remote_write_completion_flags[i][local_rank])) , "n"(1) - : "memory"); - } - if(blockIdx.x == 0){ - // Wait for all flags on local rank to reach the expected value before exit. - for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ - uint32_t flag_data = 0; - do{ - flag_data = 0; - // Need a strong system-scope load to observe peer ranks' Atomic result. - asm volatile("ld.relaxed.sys.global.u32 %0, [%1];" - : "=r"(flag_data) - : "l"(__cvta_generic_to_global(&remote_write_completion_flags[local_rank][i])) - : "memory"); - }while(flag_data != expected_flag_value); + // Before goto next chunk, go to next sparse_to_dense map stage. + sparse_to_dense_map_stage += 1; + if(sparse_to_dense_map_stage == 2){ + sparse_to_dense_map_stage = 0; + sparse_to_dense_map_parity ^= 1; } } - }*/ + } } } @@ -1154,6 +1281,7 @@ template= NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE, "The size of intra-node reduction warp group must not be smaller than prob size."); @@ -1223,6 +1353,7 @@ inline __device__ void intra_node_red_warp_group_device_function(const int node_ int chunk_id = i / (NUM_OF_NODES - 1); // Which node this chunk belongs to in output rdma reduction buffers. int rdma_remote_node_id = node_id > node_rank ? node_id - 1 : node_id; + int rdma_intra_node_red_id = rdma_remote_node_id * MAX_NUM_OF_TOKENS_PER_RANK + chunk_id * NUM_OF_TOKENS_PER_CHUNK; // How many rdma_to_attn load iter for this chunk. int num_of_routing_info_load_iter_for_current_chunk; // How many token for this chunk. @@ -1238,12 +1369,10 @@ inline __device__ void intra_node_red_warp_group_device_function(const int node_ const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + (node_id * rdma_to_attn_map_size_per_node + chunk_id * NUM_OF_TOKENS_PER_CHUNK)); - uint16_t* rdma_intra_node_red_token_base_ptr = rdma_intra_node_red_token + (rdma_remote_node_id * num_of_tokens_per_rank + chunk_id * NUM_OF_TOKENS_PER_CHUNK) * HIDDEN_DIM; + uint16_t* rdma_intra_node_red_token_base_ptr = rdma_intra_node_red_token + rdma_intra_node_red_id * HIDDEN_DIM; float* rdma_intra_node_red_prob_base_ptr; if constexpr(BACKWARD_COMBINE){ - rdma_intra_node_red_prob_base_ptr = rdma_intra_node_red_prob + - (rdma_remote_node_id * num_of_tokens_per_rank + chunk_id * NUM_OF_TOKENS_PER_CHUNK) * - (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + rdma_intra_node_red_prob_base_ptr = rdma_intra_node_red_prob + rdma_intra_node_red_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); } // How many dst token entry of current chunk have been in-flight. @@ -1316,10 +1445,12 @@ inline __device__ void intra_node_red_warp_group_device_function(const int node_ #pragma unroll for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ int element_id = (n * INTRA_NODE_RED_GROUP::size()) + INTRA_NODE_RED_GROUP::thread_rank(); - __nv_bfloat162 src_data = load_token_base_ptr[element_id]; - float2 src_data_fp32 = __bfloat1622float2(src_data); - acc_token_fp32[n].x += src_data_fp32.x; - acc_token_fp32[n].y += src_data_fp32.y; + if(element_id < NUM_OF_BF16X2_ELEMENTS_PER_TOKEN){ + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[n].x += src_data_fp32.x; + acc_token_fp32[n].y += src_data_fp32.y; + } } if constexpr(BACKWARD_COMBINE){ @@ -1374,8 +1505,10 @@ inline __device__ void intra_node_red_warp_group_device_function(const int node_ #pragma unroll for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ int element_id = (n * INTRA_NODE_RED_GROUP::size()) + INTRA_NODE_RED_GROUP::thread_rank(); - // Convert accumulated token back to BF16 and store the result back to shared memory token entry. - store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + if(element_id < NUM_OF_BF16X2_ELEMENTS_PER_TOKEN){ + // Convert accumulated token back to BF16 and store the result back to shared memory token entry. + store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + } } // Store the prob(optional). @@ -1564,6 +1697,7 @@ inline __device__ void inter_node_N2N_warp_group_device_function(const int node_ token_idx_in_chunk < NUM_OF_TOKENS_PER_CHUNK; token_idx_in_chunk += INTER_NODE_RDMA_GROUP::size()) { int64_t token_idx = token_idx_in_chunk + chunk_id * NUM_OF_TOKENS_PER_CHUNK; + int64_t local_token_idx = rdma_remote_node_id * MAX_NUM_OF_TOKENS_PER_RANK + token_idx; bool need_write = false; if (token_idx_in_chunk < token_range) { need_write = rdma_to_attn_map[token_idx_in_chunk + chunk_base_token_idx]; @@ -1581,8 +1715,7 @@ inline __device__ void inter_node_N2N_warp_group_device_function(const int node_ DOCA_GPUNETIO_IB_MLX5_WQE_CTRL_CQ_UPDATE, 0, smem_mr_info_ptr[rdma_remote_node_id].token_raddr + token_idx * HIDDEN_DIM * sizeof(uint16_t), smem_mr_info_ptr[rdma_remote_node_id].token_rkey, - smem_mr_info_ptr[rdma_remote_node_id].token_laddr + (rdma_remote_node_id * num_of_tokens_per_rank + token_idx) * HIDDEN_DIM * sizeof(uint16_t), - smem_mr_info_ptr[rdma_remote_node_id].token_lkey, + smem_mr_info_ptr[rdma_remote_node_id].token_laddr + local_token_idx * HIDDEN_DIM * sizeof(uint16_t), smem_mr_info_ptr[rdma_remote_node_id].token_lkey, HIDDEN_DIM * sizeof(uint16_t)); if constexpr(BACKWARD_COMBINE) { my_wqe_idx += write_cnt; @@ -1592,8 +1725,7 @@ inline __device__ void inter_node_N2N_warp_group_device_function(const int node_ DOCA_GPUNETIO_IB_MLX5_WQE_CTRL_CQ_UPDATE, 0, smem_mr_info_ptr[rdma_remote_node_id].prob_raddr + token_idx * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float), smem_mr_info_ptr[rdma_remote_node_id].prob_rkey, - smem_mr_info_ptr[rdma_remote_node_id].prob_laddr + (rdma_remote_node_id * num_of_tokens_per_rank + token_idx) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float), - smem_mr_info_ptr[rdma_remote_node_id].prob_lkey, + smem_mr_info_ptr[rdma_remote_node_id].prob_laddr + local_token_idx * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float), smem_mr_info_ptr[rdma_remote_node_id].prob_lkey, (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)); } } @@ -1622,21 +1754,25 @@ inline __device__ void inter_node_N2N_warp_group_device_function(const int node_ } if (INTER_NODE_RDMA_GROUP::thread_rank() < NUM_OF_NODES - 1) { struct doca_gpu_dev_verbs_qp *qp = d_qps_gpu[block_offset + INTER_NODE_RDMA_GROUP::thread_rank()]; - int status = doca_gpu_dev_verbs_poll_cq( - doca_gpu_dev_verbs_qp_get_cq_sq(qp), - smem_inter_node_num_of_write_per_node_ptr[INTER_NODE_RDMA_GROUP::thread_rank()]); - assert(status >= 0); + uint32_t wc_num_to_poll = smem_inter_node_num_of_write_per_node_ptr[INTER_NODE_RDMA_GROUP::thread_rank()]; + if (wc_num_to_poll > 0) { + int status = doca_gpu_dev_verbs_poll_cq( + doca_gpu_dev_verbs_qp_get_cq_sq(qp), wc_num_to_poll); + assert(status >= 0); + } } token_consumer_parity ^= 1; } #endif -// Device function for inter-node G2S warp for combine kernel. There can be only 1 such warp per CUDA block! +// Device function for inter-node G2S warp for combine kernel. template; - static_assert(NUM_OF_TOKENS_PER_GROUP == sizeof(rdma_to_attn_map_load_t), "Current implementation requires NUM_OF_TOKENS_PER_GROUP to be 1/2/4/8/16."); - - //constexpr int NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); - //constexpr int NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); // Load sparse_to_dense_map according to the NUM_OF_RANKS_PER_NODE. using sparse_to_dense_map_load_t = Copy_t; @@ -1682,17 +1819,21 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; // How many chunks per rank. Including full chunks and the remainder chunk. const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + const int max_num_of_chunks_per_rank = ((MAX_NUM_OF_TOKENS_PER_RANK - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; // Total number of chunks to process in the output buffer(attn buffer). output buffer(attn buffer) will only have 1 rank's tokens. const int total_num_of_chunks = num_of_chunks_per_rank; // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Starting and ending index within G2S FIFO for this warp(pipeline). + const int starting_G2S_index = NUM_OF_STAGES_G2S_PER_WARP * INTER_NODE_G2S_GROUP::warp_rank(); + const int ending_G2S_index = NUM_OF_STAGES_G2S_PER_WARP * (INTER_NODE_G2S_GROUP::warp_rank() + 1); // Token stage id and phase. - int token_stage = 0; + int token_stage = starting_G2S_index; uint32_t token_consumer_parity = 1; - // Only 1 thread within the intra-node G2S warp will be active, other threads will just exit. + // Only 1 thread within each inter-node G2S warp will be active, other threads will just exit. if(elect_sync(~0)){ // Iterate through all chunks. All chunks will assign to all CUDA block. for(int i = 0; i < total_num_of_chunks; i++){ @@ -1707,8 +1848,8 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ num_of_token_groups_for_current_chunk = NUM_OF_TOKEN_GROUPS_PER_CHUNK; current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; } - const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + - (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + + const bool* rdma_to_attn_map_load_base_addr = rdma_to_attn_map + (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK); const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (node_rank * num_of_tokens_per_rank + i * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; const bool* attn_to_rdma_map_load_base_addr = attn_to_rdma_map + (i * NUM_OF_TOKENS_PER_CHUNK) * (NUM_OF_NODES - 1); @@ -1723,10 +1864,9 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ // Iterate through all token groups within this chunk which assign to this CUDA block. for(int j = blockIdx.x; j < num_of_token_groups_for_current_chunk; j += NUM_OF_BLOCKS){ - rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; // Iterate through all dst(output) tokens within this token group. - #pragma unroll - for(int k = 0; k < NUM_OF_TOKENS_PER_GROUP; k++){ + // Assign each dst token to each G2S warp(pipeline) using a round-robin fasion. + for(int k = INTER_NODE_G2S_GROUP::warp_rank(); k < NUM_OF_TOKENS_PER_GROUP; k += INTER_NODE_G2S_GROUP::warp_size()){ int current_token_id = j * NUM_OF_TOKENS_PER_GROUP + k; // If the current token is out-of-bound, then just end this load iter. if(current_token_id >= current_chunk_size){ @@ -1736,7 +1876,7 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ // Accumulate local tokens first, then rdma tokens. // Check whether this dst token is needed by this(local) node. If not needed, just skip local accumulation. - bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + bool token_needed_by_this_node = rdma_to_attn_map_load_base_addr[current_token_id]; // If this dst token is needed by this node, load the sparse_to_dense map and load the local src token for this dst token. if(token_needed_by_this_node){ const sparse_to_dense_map_load_t* sparse_to_dense_map_load_addr = reinterpret_cast @@ -1803,8 +1943,8 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ // Goto next token entry in shared memory. token_stage += 1; - if(token_stage == NUM_OF_STAGES_G2S){ - token_stage = 0; + if(token_stage == ending_G2S_index){ + token_stage = starting_G2S_index; token_consumer_parity ^= 1; } } @@ -1824,7 +1964,7 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ if(attn_to_rdma_map_load_addr[rdma_buffer_tile_id]){ // If the current chunk is not ready yet, wait for related rdma inter-node group buffer chunks ready first. if(rdma_flag_clear[n - 1] == false){ - const uint64_t* flag_location = rdma_inter_node_group_flags + (rdma_buffer_tile_id * num_of_chunks_per_rank + i); + const uint64_t* flag_location = rdma_inter_node_group_flags + (rdma_buffer_tile_id * max_num_of_chunks_per_rank + i); uint64_t rdma_flag = 0; do{ rdma_flag = 0; @@ -1843,7 +1983,7 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ // Load the src token from this rdma inter-node group buffer chunk to shared memory entry. uint32_t total_tx_size = 0; const uint16_t* rdma_inter_node_group_token_load_addr = rdma_inter_node_group_token + - (rdma_buffer_tile_id * num_of_tokens_per_rank + + (rdma_buffer_tile_id * MAX_NUM_OF_TOKENS_PER_RANK + i * NUM_OF_TOKENS_PER_CHUNK + j * NUM_OF_TOKENS_PER_GROUP + k) * HIDDEN_DIM; cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, @@ -1857,7 +1997,7 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ if constexpr(BACKWARD_COMBINE){ const float* rdma_inter_node_group_prob_load_addr = rdma_inter_node_group_prob + - (rdma_buffer_tile_id * num_of_tokens_per_rank + + (rdma_buffer_tile_id * MAX_NUM_OF_TOKENS_PER_RANK + i * NUM_OF_TOKENS_PER_CHUNK + j * NUM_OF_TOKENS_PER_GROUP + k) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); @@ -1881,8 +2021,8 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ // Goto next token entry in shared memory. token_stage += 1; - if(token_stage == NUM_OF_STAGES_G2S){ - token_stage = 0; + if(token_stage == ending_G2S_index){ + token_stage = starting_G2S_index; token_consumer_parity ^= 1; } } @@ -1891,11 +2031,22 @@ inline __device__ void inter_node_G2S_warp_group_device_function(const int node_ } } } + #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE + // Update residue flags. + int residue_flag_count = max_num_of_chunks_per_rank - num_of_chunks_per_rank; + for (int node_id = blockIdx.x; node_id < NUM_OF_NODES - 1; node_id += gridDim.x) { + uint64_t *residue_flag_base_ptr = rdma_inter_node_group_flags + (node_id * max_num_of_chunks_per_rank + num_of_chunks_per_rank); + for (int flag_id = INTER_NODE_G2S_GROUP::thread_rank(); flag_id < residue_flag_count; flag_id += INTER_NODE_G2S_GROUP::size()) { + residue_flag_base_ptr[flag_id] = *expected_flag_value; + } + } + #endif // HYBRID_EP_BUILD_MULTINODE_ENABLE } // Device function for inter-node reduction warp group for combine kernel. template; - static_assert(NUM_OF_TOKENS_PER_GROUP == sizeof(rdma_to_attn_map_load_t), "Current implementation requires NUM_OF_TOKENS_PER_GROUP to be 1/2/4/8/16."); // Processing token using BF16x2 intruction, HIDDEN_DIM must be multiple of 2. static_assert(HIDDEN_DIM % 2 == 0, "HIDDEN_DIM must be multiple of 2."); - constexpr int NUM_OF_ELEMENT_PER_THREAD = (HIDDEN_DIM / 2) / INTER_NODE_RED_GROUP::size(); + constexpr int NUM_OF_BF16X2_ELEMENTS_PER_TOKEN = HIDDEN_DIM / 2; + //static_assert((HIDDEN_DIM / 2) % NUM_OF_THREADS_PER_PIPELINE == 0, "HIDDEN_DIM / 2 must be multiple of NUM_OF_THREADS_PER_PIPELINE, we may relax this if it is the problem."); + constexpr int NUM_OF_ELEMENT_PER_THREAD = ((NUM_OF_BF16X2_ELEMENTS_PER_TOKEN - 1) / NUM_OF_THREADS_PER_PIPELINE) + 1; // Processing prob using fp32. - constexpr int NUM_OF_PROB_VEC_ELEMENT_PER_THREAD = ((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE - 1) / INTER_NODE_RED_GROUP::size()) + 1; + constexpr int NUM_OF_PROB_VEC_ELEMENT_PER_THREAD = ((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE - 1) / NUM_OF_THREADS_PER_PIPELINE) + 1; //static_assert(INTER_NODE_RED_GROUP::size() >= NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE, "The size of inter-node reduction warp group must not be smaller than prob size."); // The inter node reduction warp group of each CUDA block produce a token group of a chunk at a time. Token groups of each chunk assigned to each CUDA block in interleave pattern. @@ -1945,12 +2107,22 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Pipeline rank and thread/warp rank within the pipeline for this thread. + const int pipeline_rank = INTER_NODE_RED_GROUP::thread_rank() / NUM_OF_THREADS_PER_PIPELINE; + const int thread_rank_within_pipeline = INTER_NODE_RED_GROUP::thread_rank() % NUM_OF_THREADS_PER_PIPELINE; + const int warp_rank_within_pipeline = thread_rank_within_pipeline / WARP_SIZE; + // Starting and ending index within G2S FIFO for this pipeline. + const int starting_G2S_index = NUM_OF_STAGES_G2S_PER_PIPELINE * pipeline_rank; + const int ending_G2S_index = NUM_OF_STAGES_G2S_PER_PIPELINE * (pipeline_rank + 1); // Src token stage id and phase. - int token_stage = 0; + int token_stage = starting_G2S_index; uint32_t token_producer_parity = 0; + // Starting and ending index within S2G FIFO for this pipeline. + const int starting_S2G_index = NUM_OF_STAGES_S2G_PER_PIPELINE * pipeline_rank; + const int ending_S2G_index = NUM_OF_STAGES_S2G_PER_PIPELINE * (pipeline_rank + 1); // Dst token stage id. - int dst_token_stage = 0; + int dst_token_stage = starting_S2G_index; // Iterate through all chunks. All chunks will assign to all CUDA block. for(int i = 0; i < total_num_of_chunks; i++){ @@ -1965,8 +2137,8 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ num_of_token_groups_for_current_chunk = NUM_OF_TOKEN_GROUPS_PER_CHUNK; current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; } - const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + - (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + + const bool* rdma_to_attn_map_load_base_addr = rdma_to_attn_map + (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK); const bool* attn_to_rdma_map_load_base_addr = attn_to_rdma_map + (i * NUM_OF_TOKENS_PER_CHUNK) * (NUM_OF_NODES - 1); uint16_t* attn_output_token_base_ptr = attn_output_token + (i * NUM_OF_TOKENS_PER_CHUNK) * HIDDEN_DIM; float* attn_output_prob_base_ptr; @@ -1975,10 +2147,9 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ } // Iterate through all token groups within this chunk which assign to this CUDA block. for(int j = blockIdx.x; j < num_of_token_groups_for_current_chunk; j += NUM_OF_BLOCKS){ - rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; // Iterate through all dst(output) tokens within this token group. - #pragma unroll - for(int k = 0; k < NUM_OF_TOKENS_PER_GROUP; k++){ + // Assign each dst token to each pipeline using a round-robin fasion. + for(int k = pipeline_rank; k < NUM_OF_TOKENS_PER_GROUP; k += NUM_OF_DATA_PIPELINE_PER_BLOCK){ int current_token_id = j * NUM_OF_TOKENS_PER_GROUP + k; // If the current token is out-of-bound, then just end this load iter. if(current_token_id >= current_chunk_size){ @@ -2007,7 +2178,7 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ } // Check whether this dst token is needed by this(local) node. If not needed, just skip local accumulation. - bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + bool token_needed_by_this_node = rdma_to_attn_map_load_base_addr[current_token_id]; // If this dst token is needed by this node, load the local src token from shared memory and accumulate them. if(token_needed_by_this_node){ // End reduction group flag. @@ -2023,27 +2194,29 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ } // Wait until current src token ready in shared memory. - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], token_producer_parity)){} } } - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); // Accumulate token and prob(optional). #pragma unroll for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ - int element_id = (n * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); - __nv_bfloat162 src_data = load_token_base_ptr[element_id]; - float2 src_data_fp32 = __bfloat1622float2(src_data); - acc_token_fp32[n].x += src_data_fp32.x; - acc_token_fp32[n].y += src_data_fp32.y; + int element_id = (n * NUM_OF_THREADS_PER_PIPELINE) + thread_rank_within_pipeline; + if(element_id < NUM_OF_BF16X2_ELEMENTS_PER_TOKEN){ + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[n].x += src_data_fp32.x; + acc_token_fp32[n].y += src_data_fp32.y; + } } if constexpr(BACKWARD_COMBINE){ #pragma unroll for(int n = 0; n < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; n++){ - int element_id = INTER_NODE_RED_GROUP::thread_rank() + n * INTER_NODE_RED_GROUP::size(); + int element_id = thread_rank_within_pipeline + n * NUM_OF_THREADS_PER_PIPELINE; if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ float src_data = load_prob_base_ptr[element_id]; acc_prob[0][n] += src_data; @@ -2054,10 +2227,10 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Check flag for last src token. last_local_node_src_token = smem_buffer_ptr->inter_node_flag_G2S_buffer[token_stage]; - // Make sure all warp group have finished loading the token entry and accumulate it to the register accumulator. + // Make sure all threads within the pipeline have finished loading the token entry and accumulate it to the register accumulator. // Then notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1]); } @@ -2065,8 +2238,8 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Goto next src token entry. token_stage += 1; - if(token_stage == NUM_OF_STAGES_G2S){ - token_stage = 0; + if(token_stage == ending_G2S_index){ + token_stage = starting_G2S_index; token_producer_parity ^= 1; } @@ -2091,27 +2264,29 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ load_prob_base_ptr = &smem_buffer_ptr->inter_node_prob_G2S_buffer[token_stage][0]; } // Wait until current src token ready in shared memory. - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], token_producer_parity)){} } } - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); // Accumulate token and prob(optional). #pragma unroll for(int m = 0; m < NUM_OF_ELEMENT_PER_THREAD; m++){ - int element_id = (m * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); - __nv_bfloat162 src_data = load_token_base_ptr[element_id]; - float2 src_data_fp32 = __bfloat1622float2(src_data); - acc_token_fp32[m].x += src_data_fp32.x; - acc_token_fp32[m].y += src_data_fp32.y; + int element_id = (m * NUM_OF_THREADS_PER_PIPELINE) + thread_rank_within_pipeline; + if(element_id < NUM_OF_BF16X2_ELEMENTS_PER_TOKEN){ + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[m].x += src_data_fp32.x; + acc_token_fp32[m].y += src_data_fp32.y; + } } if constexpr(BACKWARD_COMBINE){ #pragma unroll for(int m = 0; m < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; m++){ - int element_id = INTER_NODE_RED_GROUP::thread_rank() + m * INTER_NODE_RED_GROUP::size(); + int element_id = thread_rank_within_pipeline + m * NUM_OF_THREADS_PER_PIPELINE; if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ acc_prob[n][m] = load_prob_base_ptr[element_id]; } @@ -2120,10 +2295,10 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Inter-node token does not need flag. - // Make sure all warp group have finished loading the token entry and accumulate it to the register accumulator. + // Make sure all threads within the pipeline have finished loading the token entry and accumulate it to the register accumulator. // Then notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1]); } @@ -2131,8 +2306,8 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Goto next src token entry. token_stage += 1; - if(token_stage == NUM_OF_STAGES_G2S){ - token_stage = 0; + if(token_stage == ending_G2S_index){ + token_stage = starting_G2S_index; token_producer_parity ^= 1; } } @@ -2147,21 +2322,23 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ store_prob_base_ptr = &smem_buffer_ptr->inter_node_prob_S2G_buffer[dst_token_stage][0]; } - // Let the TMA thread to wait for previously issued TMA S2G operations finish reading this entry. - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + // Select the TMA thread within the pipeline to wait for previously issued TMA S2G operations finish reading this entry. + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ - cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t{}); + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t{}); } } - // Make sure all threads within the red warp group have wait for previously issued TMA S2G operations finish reading this entry before storing new data to this entry. - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + // Make sure all threads within the pipeline have wait for previously issued TMA S2G operations finish reading this entry before storing new data to this entry. + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); // Store the token. #pragma unroll for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ - int element_id = (n * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); - // Convert accumulated token back to BF16 and store the result back to shared memory token entry. - store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + int element_id = (n * NUM_OF_THREADS_PER_PIPELINE) + thread_rank_within_pipeline; + if(element_id < NUM_OF_BF16X2_ELEMENTS_PER_TOKEN){ + // Convert accumulated token back to BF16 and store the result back to shared memory token entry. + store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + } } // Store the prob(optional). @@ -2172,7 +2349,7 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ int element_base_id = attn_prob_output_node_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); #pragma unroll for(int m = 0; m < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; m++){ - int element_id = INTER_NODE_RED_GROUP::thread_rank() + m * INTER_NODE_RED_GROUP::size(); + int element_id = thread_rank_within_pipeline + m * NUM_OF_THREADS_PER_PIPELINE; if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ store_prob_base_ptr[element_base_id + element_id] = acc_prob[n][m]; } @@ -2183,11 +2360,11 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Make sure the shared memory stored by current thread is visible by async proxy. cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); - // Make sure all threads within the red warp group have finished storing the current token entry and making it visible to async proxy. - arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + // Make sure all threads within the pipeline have finished storing the current token entry and making it visible to async proxy. + arrive_and_wait(NUM_OF_THREADS_PER_PIPELINE, 2 + pipeline_rank); - // Let the TMA thread to issue S2G TMA operations for current token entry. - if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + // Select the TMA thread within the pipeline to issue S2G TMA operations for current token entry. + if(warp_rank_within_pipeline == 0){ if(elect_sync(~0)){ uint16_t* current_token_addr = attn_output_token_base_ptr + (j * NUM_OF_TOKENS_PER_GROUP + k) * HIDDEN_DIM; // Store the token from shared to global output. @@ -2214,8 +2391,8 @@ inline __device__ void inter_node_red_warp_group_device_function(const int node_ // Goto next dst token entry. dst_token_stage += 1; - if(dst_token_stage == NUM_OF_STAGES_S2G){ - dst_token_stage = 0; + if(dst_token_stage == ending_S2G_index){ + dst_token_stage = starting_S2G_index; } } } @@ -2279,10 +2456,13 @@ template // Each CUDA block of dispatch kernel has 3 warp groups and has the following layout: // 1. inter-node warp group(i.e. RDMA N2N warp group, 1 warp, only valid for multinode scenario) 2. intra-node G2S warp group(i.e. NVL G2S warp group, 1 warp). -// 3. intra-node S2G warp group(i.e. NVL S2G warp group, 1 warp). Total 2 or 3 warps per CUDA block/SM. +// 3. intra-node S2G warp group(i.e. NVL S2G warp group, 2(multinode scenario)-3(single-node scenario) warps). Total 4 warps per CUDA block/SM. __launch_bounds__(INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTRA_NODE_S2G_GROUP::size(), 1) __global__ void dispatch_kernel(const __grid_constant__ dispatch_kernel_param_t param) { - // Compile-time check. For now, 1 G2S and 1 S2G warp should be enough. + // Compile-time check. #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE static_assert(INTER_NODE_GROUP::size() == 32, "Dispatch kernel only support 1 N2N warp currently."); #endif static_assert(INTRA_NODE_G2S_GROUP::size() == 32, "Dispatch kernel only support 1 G2S warp currently."); - static_assert(INTRA_NODE_S2G_GROUP::size() == 32, "Dispatch kernel only support 1 S2G warp currently."); // The token and its properties should meet size and alignment requirement. // Currently, we use TMA to copy prob data, which need at least 16B size and alignment(which requires expert per node to be multiple of 4). // We need to add padding or not using TMA for prob, if we want to support other scenario. static_assert((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * sizeof(float)) % 16 == 0, "Currently, expert per node must be multiple of 4(So the prob for each token is multiple of 16B) to make TMA work."); - // If FP8 token is used, HIDDEN_DIM must be multiple of 512 to make scaling factor multiple of 16B to make TMA work. - static_assert(((HIDDEN_DIM / 128) * sizeof(float)) % 16 == 0, "Currently, scaling factor per token must be multiple of 16B."); - + static_assert((HIDDEN_DIM * sizeof(TOKEN_DATA_TYPE)) % 16 == 0, "Currently, the size of token must be multiple of 16B to make TMA work."); + if constexpr(std::is_same::value){ + // If FP8 token is used, HIDDEN_DIM must be multiple of 128 for scaling factor usage. + static_assert(HIDDEN_DIM % 128 == 0, "HIDDEN_DIM must be multiple of 128 for scaling factor"); + // If FP8 token is used, HIDDEN_DIM must be multiple of 512 to make scaling factor multiple of 16B to make TMA work. + static_assert(((HIDDEN_DIM / 128) * sizeof(float)) % 16 == 0, "Currently, scaling factor per token must be multiple of 16B."); + } // Shared memory used over 48KB, should use dynamic shared memory. extern __shared__ uint8_t smem_bytes[]; @@ -2320,8 +2503,13 @@ __global__ void dispatch_kernel(const __grid_constant__ dispatch_kernel_param_t< for(int i = 0; i < NUM_OF_STAGES; i++){ // Initialize mbarrier cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_buffer[i][0], 1); - cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_buffer[i][1], 1); + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_buffer[i][1], INTRA_NODE_S2G_GROUP::warp_size()); } + // Initialize sparse_to_dense map mbarrier. + cuda::ptx::mbarrier_init(&smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[0], 1); + cuda::ptx::mbarrier_init(&smem_buffer_ptr->sparse_to_dense_map_mbarrier_buffer[1], 1); + // Initialize S2G warp group mbarrier. + cuda::ptx::mbarrier_init(&smem_buffer_ptr->S2G_group_mbarrier_buffer, INTRA_NODE_S2G_GROUP::warp_size()); // Make mbarriers initialization visible to async proxy(TMA). cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); } @@ -2338,20 +2526,22 @@ __global__ void dispatch_kernel(const __grid_constant__ dispatch_kernel_param_t< // Inter-node warps groups. if constexpr(NUM_OF_NODES != 1){ N2N_warp_group_device_function - - (param.node_rank, param.num_of_tokens_per_rank, param.attn_to_rdma_map, param.d_qps_gpu, param.mr_info, smem_buffer_ptr); + + (param.node_rank, param.num_of_tokens_per_rank, param.attn_to_rdma_map, reinterpret_cast(param.d_qps_gpu), reinterpret_cast(param.mr_info), smem_buffer_ptr); } #endif }else if(threadIdx_x_int < INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size()){ // Intra-node G2S warp groups. G2S_warp_group_device_function - - (param.node_rank, param.num_of_tokens_per_rank, param.expected_rdma_flag_value, param.rdma_to_attn_map, param.attn_input_token, param.attn_input_prob, param.attn_input_token_scaling_factor, param.rdma_inter_node_group_token, - param.rdma_inter_node_group_prob, param.rdma_inter_node_group_scaling_factor, param.rdma_inter_node_group_flags, smem_buffer_ptr); + + (param.node_rank, param.num_of_tokens_per_rank, param.expected_rdma_flag_value, param.rdma_to_attn_map, + param.attn_input_token, param.attn_input_prob, param.attn_input_token_scaling_factor, param.rdma_inter_node_group_token, + param.rdma_inter_node_group_prob, param.rdma_inter_node_group_scaling_factor, param.rdma_inter_node_group_flags, smem_buffer_ptr); }else if(threadIdx_x_int < INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTRA_NODE_S2G_GROUP::size()){ // Intra-node S2G warp groups. S2G_warp_group_device_function - + (param.local_rank, param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.sparse_to_dense_map, param.expert_output_token, param.expert_output_prob, param.expert_output_scaling_factor, smem_buffer_ptr); }else{ @@ -2369,6 +2559,8 @@ template // Each CUDA block of combine kernel has 5 warp groups and has the following layout: -// 1. intra-node reduction warp group(4 warps, only valid for multinode scenario). 2. inter-node reduction warp group(4 warps). -// 3. intra-node G2S warp group(1 warp, only valid for multinode scenario). 4. inter-node G2S warp group(1 warp). 5. inter-node N2N rdma warp group(1 warp, only valid for multinode scenario). -// Total 5 or 11 warps per CUDA block/SM. +// 1. intra-node reduction warp group(4 warps, only valid for multinode scenario). 2. inter-node reduction warp group(4 warps, 1 pipeline for multinode scenario, 2 pipeline otherwise). +// 3. intra-node G2S warp group(1 warp, only valid for multinode scenario). 4. inter-node G2S warp group(1 warp for multinode scenario, 2 warps otherwise). 5. inter-node N2N rdma warp group(1 warp, only valid for multinode scenario). +// Total 6(single-node) or 11(multi-node) warps per CUDA block/SM. __launch_bounds__(INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size() + INTER_NODE_RDMA_GROUP::size(), 1) __global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t param) { - // Compile-time check. For now, 1 G2S and 1 S2G warp should be enough. + // Compile-time check. #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE static_assert(INTRA_NODE_G2S_GROUP::size() == 32, "Combine kernel only support 1 INTRA_NODE_G2S warp currently."); -#endif static_assert(INTER_NODE_G2S_GROUP::size() == 32, "Combine kernel only support 1 INTER_NODE_G2S warp currently."); +#endif // The token and its properties should meet size and alignment requirement. // Currently, we use TMA to copy prob data, which need at least 16B size and alignment(which requires expert per node to be multiple of 4). // We need to add padding or not using TMA for prob, if we want to support other scenario. static_assert((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * sizeof(float)) % 16 == 0, "Currently, expert per node must be multiple of 4(So the prob for each token is multiple of 16B) to make TMA work."); + static_assert((HIDDEN_DIM * sizeof(uint16_t)) % 16 == 0, "Currently, the size of token must be multiple of 16B to make TMA work."); static_assert(MAX_NUM_OF_TOKENS_PER_RANK % NUM_OF_TOKENS_PER_CHUNK == 0, "MAX_NUM_OF_TOKENS_PER_RANK must be multiple of NUM_OF_TOKENS_PER_CHUNK."); constexpr int MAX_NUM_OF_CHUNKS_PER_RANK = MAX_NUM_OF_TOKENS_PER_RANK / NUM_OF_TOKENS_PER_CHUNK; @@ -2412,6 +2605,7 @@ __global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t pa extern __shared__ uint8_t smem_bytes[]; using cur_smem_t = combine_kernel_dynamic_shared_memory_buffer_t ; + cur_smem_t* smem_buffer_ptr = reinterpret_cast(smem_bytes); // Let first thread of each CUDA block initialize the mbarrier. @@ -2448,14 +2642,14 @@ __global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t pa // Intra-node reduction warp group. if constexpr(NUM_OF_NODES != 1){ intra_node_red_warp_group_device_function - + (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.rdma_intra_node_red_token, param.rdma_intra_node_red_prob, smem_buffer_ptr); } }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size()){ // Inter-node reduction warp group. inter_node_red_warp_group_device_function - (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.attn_to_rdma_map, param.attn_output_token, param.attn_output_prob, smem_buffer_ptr); }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size()){ @@ -2468,7 +2662,7 @@ __global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t pa }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size()){ // Inter-node G2S warp group. inter_node_G2S_warp_group_device_function - (param.node_rank, param.num_of_tokens_per_rank, param.expected_rdma_flag_value, param.rdma_to_attn_map, param.attn_to_rdma_map, param.sparse_to_dense_map, param.expert_input_token, param.expert_input_prob, param.rdma_inter_node_group_token, param.rdma_inter_node_group_prob, param.rdma_inter_node_group_flags, smem_buffer_ptr); @@ -2478,7 +2672,7 @@ __global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t pa if constexpr(NUM_OF_NODES != 1){ inter_node_N2N_warp_group_device_function - (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.d_qps_gpu, param.mr_info, smem_buffer_ptr); + (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, reinterpret_cast(param.d_qps_gpu), reinterpret_cast(param.mr_info), smem_buffer_ptr); } #endif }else{ @@ -2983,6 +3177,8 @@ public: typename TOKEN_DATA_TYPE, // Number of token entry in the shared memory. int NUM_OF_STAGES, + // Number of in-flight S2G token entry in the shared memory, must be smaller than NUM_OF_STAGES. + int NUM_OF_IN_FLIGHT_S2G, // The size of token chunk used in dispatch kernel. int NUM_OF_TOKENS_PER_CHUNK, // Grid size for dispatch kernel(1:1 block:SM mapping). @@ -2997,16 +3193,17 @@ public: #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE using INTER_NODE_GROUP = warp_group<1, 0>; using INTRA_NODE_G2S_GROUP = warp_group<1, 1>; - using INTRA_NODE_S2G_GROUP = warp_group<1, 2>; + using INTRA_NODE_S2G_GROUP = warp_group<2, 2>; #else using INTER_NODE_GROUP = warp_group<0, 0>; using INTRA_NODE_G2S_GROUP = warp_group<1, 0>; - using INTRA_NODE_S2G_GROUP = warp_group<1, 1>; + using INTRA_NODE_S2G_GROUP = warp_group<3, 1>; #endif // The shared memory needed by the dispatch kernel. - using dispatch_kernel_smem_t = dispatch_kernel_dynamic_shared_memory_buffer_t; + using dispatch_kernel_smem_t = dispatch_kernel_dynamic_shared_memory_buffer_t; // The dispatch kernel to be launched. - const auto dispatch_kernel_ptr = dispatch_kernel; // Configure dynamic shared memory for the dispatch kernel. @@ -3063,19 +3260,22 @@ public: using INTRA_NODE_G2S_GROUP = warp_group<1, 8>; using INTER_NODE_G2S_GROUP = warp_group<1, 9>; using INTER_NODE_RDMA_GROUP = warp_group<1, 10>; + constexpr int NUM_OF_DATA_PIPELINE_PER_BLOCK = 1; #else using INTRA_NODE_RED_GROUP = warp_group<0, 0>; using INTER_NODE_RED_GROUP = warp_group<4, 0>; using INTRA_NODE_G2S_GROUP = warp_group<0, 4>; - using INTER_NODE_G2S_GROUP = warp_group<1, 4>; - using INTER_NODE_RDMA_GROUP = warp_group<0, 5>; + using INTER_NODE_G2S_GROUP = warp_group<2, 4>; + using INTER_NODE_RDMA_GROUP = warp_group<0, 6>; + constexpr int NUM_OF_DATA_PIPELINE_PER_BLOCK = 2; #endif + static_assert(INTER_NODE_G2S_GROUP::warp_size() == NUM_OF_DATA_PIPELINE_PER_BLOCK, "Inter-node G2S warp group pipeline and inter-node red warp group pipeline mismatch."); // The shared memory needed by the combine kernel. using combine_kernel_smem_t = combine_kernel_dynamic_shared_memory_buffer_t; // The combine kernel to be launched. - const auto combine_kernel_ptr = combine_kernel; diff --git a/csrc/hybrid_ep/backend/ibvcore.h b/csrc/hybrid_ep/backend/ibvcore.h new file mode 100644 index 00000000..2984a797 --- /dev/null +++ b/csrc/hybrid_ep/backend/ibvcore.h @@ -0,0 +1,221 @@ +/************************************************************************* + * Copyright (c) 2016-2025, NVIDIA CORPORATION. All rights reserved. + * + * See NCCL_LICENSE.txt for license information + ************************************************************************/ + + #pragma once + + #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE + #include + #include + #include + #include + #include "utils.cuh" + + namespace hybrid_ep { + namespace { + static const int IB_GID_INDEX = -1; + static const int IB_ROUTABLE_FLID_GID_INDEX = 1; + static const int IB_ROCE_VERSION_NUM = 2; + static const sa_family_t DEFAULT_FAMILY = AF_INET; + static const char NCCL_IB_ADDR_RANGE[] = { 0 }; + + int ncclIbExtractFlid(union ibv_gid *gid) { + return ntohs(*((uint16_t*)((uintptr_t)(gid->raw) + 4))); + } + + static void* envIbAddrRange(sa_family_t af, int* mask) { + *mask = 0; + static struct in_addr addr; + static struct in6_addr addr6; + void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6; + const char* env = NCCL_IB_ADDR_RANGE; + if (NULL == env || strlen(env) == 0) { + return NULL; + } + // INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env); + char addrString[128] = { 0 }; + snprintf(addrString, 128, "%s", env); + char *addrStrPtr = addrString; + char *maskStrPtr = strstr(addrString, "/"); + if (NULL == maskStrPtr) { + return NULL; + } + *(maskStrPtr++) = '\0'; + if (inet_pton(af, addrStrPtr, ret) == 0) { + // INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + return NULL; + } + *mask = (int)strtol(maskStrPtr, NULL, 10); + if (af == AF_INET && *mask > 32) { + // INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } else if (af == AF_INET6 && *mask > 128) { + // INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } + return ret; + } + + sa_family_t getGidAddrFamily(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + bool isIpV4Mapped = ((a->s6_addr32[0] | a->s6_addr32[1]) | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL; + bool isIpV4MappedMulticast = (a->s6_addr32[0] == htonl(0xff0e0000) && ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL)); + return (isIpV4Mapped || isIpV4MappedMulticast) ? AF_INET : AF_INET6; + } + + bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) { + struct in_addr *base = NULL; + struct in6_addr *base6 = NULL; + struct in6_addr *addr6 = NULL;; + if (af == AF_INET) { + base = (struct in_addr *)prefix; + } else { + base6 = (struct in6_addr *)prefix; + } + addr6 = (struct in6_addr *)gid->raw; + #define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1))) + int i = 0; + while (prefixlen > 0 && i < 4) { + if (af == AF_INET) { + int mask = NETMASK(prefixlen); + if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) { + break; + } + prefixlen = 0; + break; + } else { + if (prefixlen >= 32) { + if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) { + break; + } + prefixlen -= 32; + ++i; + } else { + int mask = NETMASK(prefixlen); + if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) { + break; + } + prefixlen = 0; + } + } + } + return (prefixlen == 0) ? true : false; + } + + bool configuredGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + int trailer = (a->s6_addr32[1] | a->s6_addr32[2] | a->s6_addr32[3]); + if (((a->s6_addr32[0] | trailer) == 0UL) || ((a->s6_addr32[0] == htonl(0xfe800000)) && (trailer == 0UL))) { + return false; + } + return true; + } + + bool linkLocalGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + if (a->s6_addr32[0] == htonl(0xfe800000) && a->s6_addr32[1] == 0UL) { + return true; + } + return false; + } + + bool validGid(union ibv_gid* gid) { + return (configuredGid(gid) && !linkLocalGid(gid)); + } + + ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) { + char gidRoceVerStr[16] = { 0 }; + char roceTypePath[PATH_MAX] = { 0 }; + snprintf(roceTypePath, sizeof(roceTypePath), "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); + int fd = open(roceTypePath, O_RDONLY); + if (fd == -1) { + // WARN("NET/IB: open failed in ncclIbRoceGetVersionNum: %s", strerror(errno)); + return ncclSystemError; + } + int ret = read(fd, gidRoceVerStr, 15); + close(fd); + if (ret == -1) { + // In containerized environments, read could return EINVAL if the GID index is not mapped to the + // container sysfs. In this case return ncclSuccess and let the caller move to next GID index. + if (errno == EINVAL) return ncclSuccess; + // WARN("NET/IB: read failed in ncclIbRoceGetVersionNum: %s", strerror(errno)); + return ncclSystemError; + } + if (strlen(gidRoceVerStr)) { + if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) { + *version = 1; + } else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) { + *version = 2; + } + } + return ncclSuccess; + } + + ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) { + union ibv_gid gid, gidCandidate; + CALL_CHECK(ibv_query_gid(context, portNum, *gidIndex, &gid)); + CALL_CHECK(ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate)); + sa_family_t usrFam = af; + sa_family_t gidFam = getGidAddrFamily(&gid); + sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate); + bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate); + if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) { + *gidIndex = gidIndexCandidate; + } else { + if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) { + return ncclSuccess; + } + int usrRoceVer = roceVer; + int gidRoceVerNum, gidRoceVerNumCandidate = -1; + const char* deviceName = ibv_get_device_name(context->device); + NCCL_CHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum)); + NCCL_CHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate)); + if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) { + *gidIndex = gidIndexCandidate; + } + } + + return ncclSuccess; + } + + } + + static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) { + int gidTblLen = portAttr->gid_tbl_len; + //for IB, choose GID Index that will have routable FLID if present + if (portAttr->link_layer == IBV_LINK_LAYER_INFINIBAND) { + union ibv_gid gid; + int routableGidIndex = IB_ROUTABLE_FLID_GID_INDEX; + if (routableGidIndex < gidTblLen) { + CALL_CHECK(ibv_query_gid(context, portNum, routableGidIndex, &gid)); + if (ncclIbExtractFlid(&gid) != 0) { + *gidIndex = routableGidIndex; + return ncclSuccess; + } + } + *gidIndex = 0; + return ncclSuccess; + } + //for ROCE + *gidIndex = IB_GID_INDEX; + if (*gidIndex >= 0) { + return ncclSuccess; + } + sa_family_t userAddrFamily = DEFAULT_FAMILY; + int userRoceVersion = IB_ROCE_VERSION_NUM; + int prefixlen; + void *prefix = envIbAddrRange(userAddrFamily, &prefixlen); + *gidIndex = 0; + for (int gidIndexNext = 1; gidIndexNext < gidTblLen; ++gidIndexNext) { + NCCL_CHECK(ncclUpdateGidIndex(context, portNum, userAddrFamily, prefix, prefixlen, userRoceVersion, gidIndexNext, gidIndex)); + } + + return ncclSuccess; + } + } //namespace hybrid_ep + #endif //HYBRID_EP_BUILD_MULTINODE_ENABLE + \ No newline at end of file diff --git a/csrc/hybrid_ep/backend/topo_detection.cuh b/csrc/hybrid_ep/backend/topo_detection.cuh new file mode 100644 index 00000000..db905ee5 --- /dev/null +++ b/csrc/hybrid_ep/backend/topo_detection.cuh @@ -0,0 +1,1657 @@ +/************************************************************************* + * Copyright (c) 2016-2025, NVIDIA CORPORATION. All rights reserved. + * + * See NCCL_LICENSE.txt for license information + ************************************************************************/ + + #pragma once + + #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE + #if defined(__x86_64__) + #include + #endif + #include + #include + #include + #include + #include + #include + #include + #include + #include "nvml.h" + #include "utils.cuh" + + namespace hybrid_ep { + namespace { + static constexpr char HOSTID_FILE[] = "/proc/sys/kernel/random/boot_id"; + static constexpr size_t BUSID_SIZE = sizeof("0000:00:00.0"); + static constexpr size_t BUSID_REDUCED_SIZE = sizeof("0000:00"); + static constexpr size_t CPU_SET_N_U32 = sizeof(cpu_set_t) / sizeof(uint32_t); + static constexpr size_t MAX_STR_LEN = 255; + static constexpr size_t MAX_ATTR_COUNT = 16; + static constexpr size_t MAX_SUBS = 128; + static constexpr size_t MAXCHANNELS = 32; + static constexpr uint32_t NODE_TYPE_NONE = 0; + static constexpr uint32_t NODE_TYPE_OPEN = 1; + static constexpr uint32_t NODE_TYPE_CLOSE = 2; + static constexpr uint32_t NODE_TYPE_SINGLE = 3; + static constexpr uint32_t GPU = 0; + static constexpr uint32_t PCI = 1; + static constexpr uint32_t NVS = 2; + static constexpr uint32_t CPU = 3; // Actually NUMA domains + static constexpr uint32_t NIC = 4; + static constexpr uint32_t NET = 5; + static constexpr uint32_t NCCL_TOPO_NODE_TYPES = 6; + static constexpr size_t NCCL_TOPO_XML_MAX_NODES = 256; + static constexpr size_t NCCL_GRAPH_XML_MAX_NODES = 4096; + static constexpr uint32_t NCCL_TOPO_MAX_LINKS = 128; + static constexpr uint32_t NCCL_TOPO_MAX_NODES = 576; + static constexpr uint32_t NCCL_TOPO_MAX_HOPS = NCCL_TOPO_MAX_NODES * NCCL_TOPO_NODE_TYPES; + + static constexpr uint32_t LINK_LOC = 0; + static constexpr uint32_t LINK_NVL = 1; + // Skipping 2 for PATH_NVB + static constexpr uint32_t LINK_PCI = 3; + // Skipping 4 for PATH_PXB + // Skipping 5 for PATH_PXN + // Skipping 6 for PATH_PHB + static constexpr uint32_t LINK_SYS = 7; + static constexpr uint32_t LINK_NET = 8; + + // Local (myself) + static constexpr uint32_t PATH_LOC = 0; + // Connection traversing NVLink + static constexpr uint32_t PATH_NVL = 1; + // Connection through NVLink using an intermediate GPU + static constexpr uint32_t PATH_NVB = 2; + // Connection traversing at most a single PCIe bridge + static constexpr uint32_t PATH_PIX = 3; + // Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) + static constexpr uint32_t PATH_PXB = 4; + // Connection between a GPU and a NIC using an intermediate GPU. Used to enable rail-local, aggregated network send/recv operations. + static constexpr uint32_t PATH_PXN = 5; + // Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) + static constexpr uint32_t PATH_PHB = 6; + // Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) + static constexpr uint32_t PATH_SYS = 7; + // Connection through the network + static constexpr uint32_t PATH_NET = 8; + // Disconnected + static constexpr uint32_t PATH_DIS = 9; + + static constexpr float LOC_BW = 5000.0; + static constexpr float SM60_NVLINK_BW = 18.0; + static constexpr float SM70_NVLINK_BW = 20.0; + static constexpr float SM80_NVLINK_BW = 20.0; + static constexpr float SM90_NVLINK_BW = 20.6; + static constexpr float SM86_NVLINK_BW = 12.0; + static constexpr float PCI_BW = 12.0; // PCI Gen3 x16 + static constexpr float QPI_BW = 6.0; + static constexpr float AMD_BW = 16.0; + static constexpr float SKL_QPI_BW = 10.0; + static constexpr float ZPI_BW = 6.0; + static constexpr float YONGFENG_ZPI_BW = 9.0; + static constexpr float P9_BW = 32.0; + static constexpr float ARM_BW = 6.0; + static constexpr float NET_BW = 12.0; // 100Gbit + + static constexpr uint32_t NCCL_TOPO_CPU_ARCH_X86 = 1; + static constexpr uint32_t NCCL_TOPO_CPU_ARCH_POWER = 2; + static constexpr uint32_t NCCL_TOPO_CPU_ARCH_ARM = 3; + static constexpr uint32_t NCCL_TOPO_CPU_VENDOR_INTEL = 1; + static constexpr uint32_t NCCL_TOPO_CPU_VENDOR_AMD = 2; + static constexpr uint32_t NCCL_TOPO_CPU_VENDOR_ZHAOXIN = 3; + static constexpr uint32_t NCCL_TOPO_CPU_TYPE_BDW = 1; + static constexpr uint32_t NCCL_TOPO_CPU_TYPE_SKL = 2; + static constexpr uint32_t NCCL_TOPO_CPU_TYPE_YONGFENG = 1; + + static constexpr uint32_t NCCL_TOPO_CPU_INTEL_BDW = 1; + static constexpr uint32_t NCCL_TOPO_CPU_INTEL_SKL = 2; + + static constexpr int32_t NCCL_TOPO_UNDEF = -1; + + static int ibvWidths[] = { 1, 4, 8, 12, 2 }; + static int ibvSpeeds[] = { + 2500, /* SDR */ + 5000, /* DDR */ + 10000, /* QDR */ + 10000, /* QDR */ + 14000, /* FDR */ + 25000, /* EDR */ + 50000, /* HDR */ + 100000 /* NDR */ }; + + struct ncclXmlNode { + char name[MAX_STR_LEN+1]; + struct { + char key[MAX_STR_LEN+1]; + char value[MAX_STR_LEN+1]; + } attrs[MAX_ATTR_COUNT+1]; // Need an extra one to consume extra params + int nAttrs; + int type; + struct ncclXmlNode* parent; + struct ncclXmlNode* subs[MAX_SUBS]; + int nSubs; + }; + + struct ncclXml { + int maxIndex, maxNodes; + struct ncclXmlNode nodes[1]; + }; + + struct ncclTopoNode; + struct ncclTopoLink { + int type; + float bw; + struct ncclTopoNode* remNode; + }; + + struct ncclTopoLinkList { + struct ncclTopoLink* list[NCCL_TOPO_MAX_HOPS]; + int count; + float bw; + int type; + }; + + struct ncclTopoNode { + int type; + int64_t id; + // Type specific data + union { + struct { + int dev; // NVML dev number + int rank; + int cudaCompCap; + int gdrSupport; + }gpu; + struct { + int dev; // Plugin dev number + uint64_t asic; + int port; + float bw; + float latency; + int gdrSupport; + int collSupport; + int maxChannels; + int localGpu; + const char *name; + }net; + struct { + int arch; + int vendor; + int model; + cpu_set_t affinity; + }cpu; + struct { + uint64_t device; + }pci; + }; + int nlinks; + struct ncclTopoLink links[NCCL_TOPO_MAX_LINKS]; + // Pre-computed paths to GPUs and NICs + struct ncclTopoLinkList* paths[NCCL_TOPO_NODE_TYPES]; + // Used during search + uint64_t used; + }; + + struct ncclTopoNodeList { + struct ncclTopoNode* list[NCCL_TOPO_MAX_NODES]; + int count; + }; + + struct ncclTopoNodeSet { + int count; + struct ncclTopoNode nodes[NCCL_TOPO_MAX_NODES]; + }; + + struct ncclTopoSystem { + int systemId; + uint64_t hostHashes[NCCL_TOPO_MAX_NODES]; + int nHosts; + struct ncclTopoNodeSet nodes[NCCL_TOPO_NODE_TYPES]; + float maxBw; + float totalBw; + }; + + struct kvDict { + const char* str; + int value; + }; + + uint64_t NCCL_TOPO_ID_SYSTEM_ID(uint64_t id) {return id >> 56;} + uint64_t NCCL_TOPO_ID(int systemid, int localid) {return ((int64_t)systemid << 56) + localid;} + struct kvDict kvDictCpuArch[] = { { "x86_64", NCCL_TOPO_CPU_ARCH_X86 }, { "arm64", NCCL_TOPO_CPU_ARCH_ARM }, { "ppc64", NCCL_TOPO_CPU_ARCH_POWER }, { NULL, 0 } }; + struct kvDict kvDictCpuVendor[] = { { "GenuineIntel", NCCL_TOPO_CPU_VENDOR_INTEL }, { "AuthenticAMD", NCCL_TOPO_CPU_VENDOR_AMD }, { "CentaurHauls", NCCL_TOPO_CPU_VENDOR_ZHAOXIN }, { " Shanghai ", NCCL_TOPO_CPU_VENDOR_ZHAOXIN }, { NULL, 0 } }; + struct kvDict kvDictPciClass[] = { { "0x060400", PCI }, { "0x068000", NVS }, { "0x068001", CPU }, { "0x03", GPU }, { "0x02", NIC }, { NULL, PCI /* Default fallback value */ } }; + struct kvDict kvDictPciGen[] = { + { "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, { "16 GT/s", 120 }, { "32 GT/s", 240 }, /* Kernel 5.6 and earlier */ + { "2.5 GT/s PCIe", 15 }, { "5.0 GT/s PCIe", 30 }, { "8.0 GT/s PCIe", 60 }, { "16.0 GT/s PCIe", 120 }, { "32.0 GT/s PCIe", 240 }, { "64.0 GT/s PCIe", 480 }, + { NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane + + ncclResult_t kvConvertToInt(const char* str, int* value, struct kvDict* dict) { + struct kvDict* d = dict; + while (d->str) { + if (strncmp(str, d->str, strlen(d->str)) == 0) { + *value = d->value; + return ncclSuccess; + } + d++; + } + // INFO(NCCL_GRAPH, "KV Convert to int : could not find value of '%s' in dictionary, falling back to %d", str, d->value); + *value = d->value; + return ncclSuccess; + } + + int firstBitSet(int val, int max) { + int i = 0; + while (inodes[t].count; i++) { + if (system->nodes[t].nodes[i].id == id) { + *path = node->paths[t]+i; + return ncclSuccess; + } + } + // WARN("Could not find node of type %d id %lx", t, id); + return ncclInternalError; + } + + int isHex(char c) { + return ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')); + } + + int hexToInt(char c) { + int v = c - '0'; + if (v < 0) return -1; + if (v > 9) v = 10 + c - 'a'; + if ((v < 0) || (v > 15)) return -1; + return v; + } + + int checkBDFFormat(char* bdf) { + if (bdf[4] != ':' || bdf[7] != ':' || bdf[10] != '.') return 0; + if (isHex(bdf[0]) == 0 || isHex(bdf[1]) == 0 || isHex(bdf[2]) == 0 || isHex(bdf[3]) == 0 || + isHex(bdf[5]) == 0 || isHex(bdf[6]) == 0 || isHex(bdf[8]) == 0 || isHex(bdf[9]) == 0 || + isHex(bdf[11]) == 0) return 0; + return 1; + } + + void memcpylower(char* dst, const char* src, const size_t size) { + for (int i=0; i= '0' && c <= '9') || + (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f')) { + hexStr[hexOffset++] = busId[i]; + } else break; + } + hexStr[hexOffset] = '\0'; + *id = strtol(hexStr, NULL, 16); + return ncclSuccess; + } + + size_t xmlMemSize(int maxNodes) { + return offsetof(struct ncclXml, nodes) + sizeof(struct ncclXmlNode)*maxNodes; + } + + ncclResult_t xmlAddNode(struct ncclXml* xml, struct ncclXmlNode* parent, const char* subName, struct ncclXmlNode** sub) { + if (xml->maxIndex == xml->maxNodes) { + // WARN("Error : too many XML nodes (max %d)", xml->maxNodes); + return ncclInternalError; + } + struct ncclXmlNode* s = xml->nodes+xml->maxIndex++; + s->nSubs = 0; + s->nAttrs = 0; + *sub = s; + s->parent = parent; + if (parent) { + if (parent->nSubs == MAX_SUBS) { + // WARN("Error : too many XML subnodes (max %d)", MAX_SUBS); + return ncclInternalError; + } + parent->subs[parent->nSubs++] = s; + } + strncpy(s->name, subName, MAX_STR_LEN); + s->name[MAX_STR_LEN] = '\0'; + return ncclSuccess; + } + + ncclResult_t xmlAlloc(struct ncclXml** xml, int maxNodes) { + char* mem; + NCCL_CHECK(ncclCalloc(&mem, xmlMemSize(maxNodes))); + *xml = (struct ncclXml*)mem; + (*xml)->maxNodes = maxNodes; + return ncclSuccess; + } + + ncclResult_t xmlGetAttrIndex(struct ncclXmlNode* node, const char* attrName, int* index) { + *index = -1; + const int nAttrs = node->nAttrs; + for (int a=0; aattrs[a].key, attrName, MAX_STR_LEN) == 0) { + *index = a; + return ncclSuccess; + } + } + return ncclSuccess; + } + + ncclResult_t xmlGetAttr(struct ncclXmlNode* node, const char* attrName, const char** value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + *value = index == -1 ? NULL : node->attrs[index].value; + return ncclSuccess; + } + + ncclResult_t xmlGetAttrStr(struct ncclXmlNode* node, const char* attrName, const char** value) { + NCCL_CHECK(xmlGetAttr(node, attrName, value)); + if (*value == NULL) { + // WARN("Attribute %s of node %s not found", attrName, node->name); + return ncclInternalError; + } + return ncclSuccess; + } + + ncclResult_t xmlGetAttrInt(struct ncclXmlNode* node, const char* attrName, int* value) { + const char* str; + NCCL_CHECK(xmlGetAttrStr(node, attrName, &str)); + *value = strtol(str, NULL, 0); + return ncclSuccess; + } + + ncclResult_t xmlGetAttrIntDefault(struct ncclXmlNode* node, const char* attrName, int* value, int defaultValue) { + const char* str; + NCCL_CHECK(xmlGetAttr(node, attrName, &str)); + *value = str ? strtol(str, NULL, 0) : defaultValue; + return ncclSuccess; + } + + ncclResult_t xmlGetAttrLong(struct ncclXmlNode* node, const char* attrName, int64_t* value) { + const char* str; + NCCL_CHECK(xmlGetAttrStr(node, attrName, &str)); + *value = strtol(str, NULL, 0); + return ncclSuccess; + } + + ncclResult_t xmlGetAttrFloat(struct ncclXmlNode* node, const char* attrName, float* value) { + const char* str; + NCCL_CHECK(xmlGetAttrStr(node, attrName, &str)); + *value = strtof(str, NULL); + return ncclSuccess; + } + + ncclResult_t xmlGetSub(struct ncclXmlNode* node, const char* subName, struct ncclXmlNode** sub) { + *sub = NULL; + for (int s=0; snSubs; s++) { + if (strcmp(node->subs[s]->name, subName) == 0) { + *sub = node->subs[s]; + return ncclSuccess; + } + } + return ncclSuccess; + } + + ncclResult_t xmlGetSubKv(struct ncclXmlNode* node, const char* subName, struct ncclXmlNode** sub, const char* attrName, const char* attrValue) { + *sub = NULL; + for (int s=0; snSubs; s++) { + struct ncclXmlNode* subNode = node->subs[s]; + if (strcmp(subNode->name, subName) == 0) { + const char* value; + NCCL_CHECK(xmlGetAttr(subNode, attrName, &value)); + if (value && strcmp(value, attrValue) == 0) { + *sub = node->subs[s]; + return ncclSuccess; + } + } + } + return ncclSuccess; + } + + ncclResult_t xmlFindNextTag(struct ncclXml* xml, const char* tagName, struct ncclXmlNode* prev, struct ncclXmlNode** node) { + *node = NULL; + for (int i=prev-xml->nodes+1; imaxIndex; i++) { + struct ncclXmlNode* n = xml->nodes+i; + if (strcmp(n->name, tagName) == 0) { + *node = n; + return ncclSuccess; + } + } + return ncclSuccess; + } + + ncclResult_t xmlFindTag(struct ncclXml* xml, const char* tagName, struct ncclXmlNode** node) { + *node = NULL; + for (int i=0; imaxIndex; i++) { + struct ncclXmlNode* n = xml->nodes+i; + if (strcmp(n->name, tagName) == 0) { + *node = n; + return ncclSuccess; + } + } + return ncclSuccess; + } + + ncclResult_t xmlFindTagKv(struct ncclXml* xml, const char* tagName, struct ncclXmlNode** node, const char* attrName, const char* attrValue) { + *node = NULL; + for (int i=0; imaxIndex; i++) { + struct ncclXmlNode* n = xml->nodes+i; + if (strcmp(n->name, tagName) == 0) { + const char* value; + NCCL_CHECK(xmlGetAttr(n, attrName, &value)); + if (value && strcmp(value, attrValue) == 0) { + *node = n; + return ncclSuccess; + } + } + } + return ncclSuccess; + } + + ncclResult_t xmlInitAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + snprintf(node->attrs[index].value, MAX_STR_LEN, "%d", value); + } + return ncclSuccess; + } + + ncclResult_t xmlInitAttrUint64(struct ncclXmlNode* node, const char* attrName, const uint64_t value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + snprintf(node->attrs[index].value, MAX_STR_LEN, "0x%lx", value); + } + return ncclSuccess; + } + + ncclResult_t xmlInitAttrFloat(struct ncclXmlNode* node, const char* attrName, const float value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + snprintf(node->attrs[index].value, MAX_STR_LEN, "%f", value); + } + return ncclSuccess; + } + + ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, const char* value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + node->attrs[index].key[MAX_STR_LEN] = '\0'; + } + strncpy(node->attrs[index].value, value, MAX_STR_LEN); + node->attrs[index].value[MAX_STR_LEN] = '\0'; + return ncclSuccess; + } + + ncclResult_t xmlSetAttrIfUnset(struct ncclXmlNode* node, const char* attrName, const char* value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index != -1) return ncclSuccess; + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + node->attrs[index].key[MAX_STR_LEN] = '\0'; + strncpy(node->attrs[index].value, value, MAX_STR_LEN); + node->attrs[index].value[MAX_STR_LEN] = '\0'; + return ncclSuccess; + } + + ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + node->attrs[index].key[MAX_STR_LEN] = '\0'; + } + snprintf(node->attrs[index].value, MAX_STR_LEN, "%d", value); + node->attrs[index].value[MAX_STR_LEN] = '\0'; + return ncclSuccess; + } + + ncclResult_t xmlSetAttrLong(struct ncclXmlNode* node, const char* attrName, const int64_t value) { + int index; + NCCL_CHECK(xmlGetAttrIndex(node, attrName, &index)); + if (index == -1) { + index = node->nAttrs++; + strncpy(node->attrs[index].key, attrName, MAX_STR_LEN); + node->attrs[index].key[MAX_STR_LEN] = '\0'; + } + snprintf(node->attrs[index].value, MAX_STR_LEN, "%#lx", value); + node->attrs[index].value[MAX_STR_LEN] = '\0'; + return ncclSuccess; + } + + ncclResult_t ncclTopoGetStrFromSys(const char* path, const char* fileName, char* strValue) { + char filePath[PATH_MAX]; + sprintf(filePath, "%s/%s", path, fileName); + int offset = 0; + FILE* file; + if ((file = fopen(filePath, "r")) != NULL) { + while (feof(file) == 0 && ferror(file) == 0 && offset < MAX_STR_LEN) { + int len = fread(strValue+offset, 1, MAX_STR_LEN-offset, file); + offset += len; + } + fclose(file); + } + if (offset == 0) { + strValue[0] = '\0'; + // INFO(NCCL_GRAPH, "Topology detection : could not read %s, ignoring", filePath); + } else { + strValue[offset-1] = '\0'; + } + return ncclSuccess; + } + + ncclResult_t ncclTopoGetSubsystem(const char* sysPath, char* subSys) { + char subSysPath[PATH_MAX]; + sprintf(subSysPath, "%s/subsystem", sysPath); + char* path = realpath(subSysPath, NULL); + if (path == NULL) { + subSys[0] = '\0'; + } else { + int offset; + for (offset = strlen(path); offset > 0 && path[offset] != '/'; offset--); + strcpy(subSys, path+offset+1); + free(path); + } + return ncclSuccess; + } + + ncclResult_t ncclTopoSetAttrFromSys(struct ncclXmlNode* pciNode, const char* path, const char* fileName, const char* attrName) { + char strValue[MAX_STR_LEN]; + NCCL_CHECK(ncclTopoGetStrFromSys(path, fileName, strValue)); + if (strValue[0] != '\0') { NCCL_CHECK(xmlSetAttr(pciNode, attrName, strValue)); } + // TRACE(NCCL_GRAPH, "Read from sys %s/%s -> %s=%s", path, fileName, attrName, strValue); + return ncclSuccess; + } + + ncclResult_t ncclTopoGetXmlFromCpu(struct ncclXmlNode* cpuNode, struct ncclXml* xml) { + int index; + NCCL_CHECK(xmlGetAttrIndex(cpuNode, "affinity", &index)); + if (index == -1) { + const char* numaId; + NCCL_CHECK(xmlGetAttr(cpuNode, "numaid", &numaId)); + if (numaId == NULL) { + // WARN("GetXmlFromCpu : could not find CPU numa ID."); + return ncclInternalError; + } + // Set affinity + char cpumaskPath[] = "/sys/devices/system/node/node0000"; + sprintf(cpumaskPath, "/sys/devices/system/node/node%s", numaId); + NCCL_CHECK(ncclTopoSetAttrFromSys(cpuNode, cpumaskPath, "cpumap", "affinity")); + } + NCCL_CHECK(xmlGetAttrIndex(cpuNode, "arch", &index)); + if (index == -1) { + // Fill CPU type / vendor / model + #if defined(__PPC__) + NCCL_CHECK(xmlSetAttr(cpuNode, "arch", "ppc64")); + #elif defined(__aarch64__) + NCCL_CHECK(xmlSetAttr(cpuNode, "arch", "arm64")); + #elif defined(__x86_64__) + NCCL_CHECK(xmlSetAttr(cpuNode, "arch", "x86_64")); + #endif + } + + #if defined(__x86_64__) + NCCL_CHECK(xmlGetAttrIndex(cpuNode, "vendor", &index)); + if (index == -1) { + union { + struct { + // CPUID 0 String register order + uint32_t ebx; + uint32_t edx; + uint32_t ecx; + }; + char vendor[12]; + } cpuid0; + + [[maybe_unused]] unsigned unused; + __cpuid(0, unused, cpuid0.ebx, cpuid0.ecx, cpuid0.edx); + char vendor[13]; + strncpy(vendor, cpuid0.vendor, 12); + vendor[12] = '\0'; + NCCL_CHECK(xmlSetAttr(cpuNode, "vendor", vendor)); + } + NCCL_CHECK(xmlGetAttrIndex(cpuNode, "familyid", &index)); + if (index == -1) { + union { + struct { + unsigned steppingId:4; + unsigned modelId:4; + unsigned familyId:4; + unsigned processorType:2; + unsigned resv0:2; + unsigned extModelId:4; + unsigned extFamilyId:8; + unsigned resv1:4; + }; + uint32_t val; + } cpuid1; + [[maybe_unused]] unsigned unused; + __cpuid(1, cpuid1.val, unused, unused, unused); + int familyId = cpuid1.familyId + (cpuid1.extFamilyId << 4); + int modelId = cpuid1.modelId + (cpuid1.extModelId << 4); + NCCL_CHECK(xmlSetAttrInt(cpuNode, "familyid", familyId)); + NCCL_CHECK(xmlSetAttrInt(cpuNode, "modelid", modelId)); + } + #endif + return ncclSuccess; + } + + ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvmlDev, struct ncclXml* xml, struct ncclXmlNode** gpuNodeRet) { + struct ncclXmlNode* gpuNode = NULL; + NCCL_CHECK(xmlGetSub(pciNode, "gpu", &gpuNode)); + if (gpuNode == NULL) NCCL_CHECK(xmlAddNode(xml, pciNode, "gpu", &gpuNode)); + int index = -1; + int dev = -1; + NCCL_CHECK(xmlGetAttrIndex(gpuNode, "dev", &index)); + if (index == -1) { + CALL_CHECK(nvmlDeviceGetIndex(nvmlDev, (unsigned int*)&dev)); + NCCL_CHECK(xmlSetAttrInt(gpuNode, "dev", dev)); + } + NCCL_CHECK(xmlGetAttrInt(gpuNode, "dev", &dev)); + if (dev == -1) { *gpuNodeRet = NULL; return ncclSuccess; } + NCCL_CHECK(xmlGetAttrIndex(gpuNode, "sm", &index)); + if (index == -1) { + int cudaMajor, cudaMinor; + if (nvmlDev == NULL) { + cudaDeviceProp devProp; + CUDA_CHECK(cudaGetDeviceProperties(&devProp, dev)); + cudaMajor = devProp.major; cudaMinor = devProp.minor; + } else { + CALL_CHECK(nvmlDeviceGetCudaComputeCapability(nvmlDev, &cudaMajor, &cudaMinor)); + } + NCCL_CHECK(xmlSetAttrInt(gpuNode, "sm", cudaMajor*10+cudaMinor)); + } + int sm; + NCCL_CHECK(xmlGetAttrInt(gpuNode, "sm", &sm)); + struct ncclXmlNode* nvlNode = NULL; + NCCL_CHECK(xmlGetSub(gpuNode, "nvlink", &nvlNode)); + if (nvlNode == NULL) { + // NVML NVLink detection + int maxNvLinks = (sm < 60) ? 0 : (sm < 70) ? 4 : (sm < 80) ? 6 : (sm < 90) ? 12 : 18; + if (maxNvLinks > 0 && nvmlDev == NULL) { + // WARN("No NVML device handle. Skipping nvlink detection."); + maxNvLinks = 0; + } + for (int l=0; l= 11080 + if (sm >= 90) { + nvmlFieldValue_t fv; + fv.fieldId = NVML_FI_DEV_NVLINK_GET_STATE; + fv.scopeId = l; + // fv.value will contain NV_FEATURE_ENABLED or NV_FEATURE_DISABLED + if ((nvmlDeviceGetFieldValues(nvmlDev, 1, &fv) == NVML_SUCCESS) && (fv.nvmlReturn == NVML_SUCCESS)) + isActive = (nvmlEnableState_t) fv.value.uiVal; + } else /* FALLTHRU to GetNvLinkState if before SM90 */ + #endif + { + (void) nvmlDeviceGetNvLinkState(nvmlDev, l, &isActive); + } + if (isActive != NVML_FEATURE_ENABLED) continue; + // Try to figure out what's on the other side of the NVLink + nvmlPciInfo_t remoteProc; + if (nvmlDeviceGetNvLinkRemotePciInfo(nvmlDev, l, &remoteProc) != NVML_SUCCESS) continue; + // Make a lower case copy of the bus ID for calling ncclDeviceType + // PCI system path is in lower case + char* p = remoteProc.busId; + char lowerId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + for (int c=0; c= 11080 + struct ncclXmlNode* c2cNode = NULL; + NCCL_CHECK(xmlGetSub(gpuNode, "c2c", &c2cNode)); + if (c2cNode == NULL) { + if (sm >= 90) { + int c2cLinksCount = 0; + nvmlFieldValue_t fv; + fv.fieldId = NVML_FI_DEV_C2C_LINK_COUNT; + if ((nvmlDeviceGetFieldValues(nvmlDev, 1, &fv) == NVML_SUCCESS) && (fv.nvmlReturn == NVML_SUCCESS)) { + c2cLinksCount = fv.value.uiVal; + int bw = 0; + int count = 0; + for (int l=0; l 0) { + NCCL_CHECK(xmlAddNode(xml, gpuNode, "c2c", &c2cNode)); + NCCL_CHECK(xmlSetAttrInt(c2cNode, "bw", bw)); + NCCL_CHECK(xmlSetAttrInt(c2cNode, "count", count)); + } + } + } + } + #endif + // Fill target classes + for (int s=0; snSubs; s++) { + struct ncclXmlNode* sub = gpuNode->subs[s]; + if (strcmp(sub->name, "nvlink") != 0) continue; + int index; + NCCL_CHECK(xmlGetAttrIndex(sub, "tclass", &index)); + if (index == -1) { + const char* busId; + NCCL_CHECK(xmlGetAttr(sub, "target", &busId)); + char* path; + // ncclDebugNoWarn = NCCL_GRAPH; + getPciPath(busId, &path); + // ncclDebugNoWarn = 0; + if (path == NULL || strcmp(busId, "fffffff:ffff:ff") == 0) { + // Remote NVLink device is not visible inside this VM. Assume NVSwitch. + NCCL_CHECK(xmlSetAttr(sub, "tclass", "0x068000")); + } else { + NCCL_CHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass")); + free(path); + } + } + } + *gpuNodeRet = gpuNode; + return ncclSuccess; + } + + ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* xml) { + // Fill info, then parent + const char* busId; + NCCL_CHECK(xmlGetAttr(pciNode, "busid", &busId)); + char* path = NULL; + // ncclDebugNoWarn = NCCL_GRAPH; + getPciPath(busId, &path); + // ncclDebugNoWarn = 0; + if (path) { + NCCL_CHECK(ncclTopoSetAttrFromSys(pciNode, path, "class", "class")); + } + int index; + // ncclDebugNoWarn = NCCL_GRAPH; + NCCL_CHECK(xmlGetAttrIndex(pciNode, "vendor", &index)); + if (index == -1) { + if (path) ncclTopoSetAttrFromSys(pciNode, path, "vendor", "vendor"); + } + NCCL_CHECK(xmlGetAttrIndex(pciNode, "device", &index)); + if (index == -1) { + if (path) ncclTopoSetAttrFromSys(pciNode, path, "device", "device"); + } + NCCL_CHECK(xmlGetAttrIndex(pciNode, "subsystem_vendor", &index)); + if (index == -1) { + if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_vendor", "subsystem_vendor"); + } + NCCL_CHECK(xmlGetAttrIndex(pciNode, "subsystem_device", &index)); + if (index == -1) { + if (path) ncclTopoSetAttrFromSys(pciNode, path, "subsystem_device", "subsystem_device"); + } + // ncclDebugNoWarn = 0; + NCCL_CHECK(xmlGetAttrIndex(pciNode, "link_speed", &index)); + if (index == -1) { + if (path) { + char deviceSpeedStr[MAX_STR_LEN]; + float deviceSpeed = FLT_MAX; + NCCL_CHECK(ncclTopoGetStrFromSys(path, "max_link_speed", deviceSpeedStr)); + sscanf(deviceSpeedStr, "%f GT/s", &deviceSpeed); + char portSpeedStr[MAX_STR_LEN]; + float portSpeed = FLT_MAX; + NCCL_CHECK(ncclTopoGetStrFromSys(path, "../max_link_speed", portSpeedStr)); + sscanf(portSpeedStr, "%f GT/s", &portSpeed); + NCCL_CHECK(xmlSetAttr(pciNode, "link_speed", portSpeed < deviceSpeed ? portSpeedStr : deviceSpeedStr)); + } else { + NCCL_CHECK(xmlSetAttr(pciNode, "link_speed", "")); + } + } + NCCL_CHECK(xmlGetAttrIndex(pciNode, "link_width", &index)); + if (index == -1) { + if (path) { + char strValue[MAX_STR_LEN]; + NCCL_CHECK(ncclTopoGetStrFromSys(path, "max_link_width", strValue)); + int deviceWidth = strtol(strValue, NULL, 0); + NCCL_CHECK(ncclTopoGetStrFromSys(path, "../max_link_width", strValue)); + int portWidth = strtol(strValue, NULL, 0); + NCCL_CHECK(xmlSetAttrInt(pciNode, "link_width", std::min(deviceWidth,portWidth))); + } else { + NCCL_CHECK(xmlSetAttr(pciNode, "link_width", "")); + } + } + struct ncclXmlNode* parent = pciNode->parent; + if (parent == NULL) { + if (path) { + // Save that for later in case next step is a CPU + char numaIdStr[MAX_STR_LEN]; + NCCL_CHECK(ncclTopoGetStrFromSys(path, "numa_node", numaIdStr)); + + // Go up one level in the PCI tree. Rewind two "/" and follow the upper PCI + // switch, or stop if we reach a CPU root complex. + int slashCount = 0; + int parentOffset; + for (parentOffset = strlen(path)-1; parentOffset>0; parentOffset--) { + if (path[parentOffset] == '/') { + slashCount++; + path[parentOffset] = '\0'; + int start = parentOffset - 1; + while (start>0 && path[start] != '/') start--; + // Check whether the parent path looks like "BBBB:BB:DD.F" or not. + if (checkBDFFormat(path+start+1) == 0) { + // This a CPU root complex. Create a CPU tag and stop there. + struct ncclXmlNode* topNode; + NCCL_CHECK(xmlFindTag(xml, "system", &topNode)); + NCCL_CHECK(xmlGetSubKv(topNode, "cpu", &parent, "numaid", numaIdStr)); + if (parent == NULL) { + NCCL_CHECK(xmlAddNode(xml, topNode, "cpu", &parent)); + NCCL_CHECK(xmlSetAttrLong(parent, "host_hash", getHostHash())); + NCCL_CHECK(xmlSetAttr(parent, "numaid", numaIdStr)); + } + } else if (slashCount == 2) { + // Continue on the upper PCI switch + for (int i = strlen(path)-1; i>0; i--) { + if (path[i] == '/') { + NCCL_CHECK(xmlFindTagKv(xml, "pci", &parent, "busid", path+i+1)); + if (parent == NULL) { + NCCL_CHECK(xmlAddNode(xml, NULL, "pci", &parent)); + NCCL_CHECK(xmlSetAttr(parent, "busid", path+i+1)); + } + break; + } + } + } + } + if (parent) break; + } + } else { + // No information on /sys, attach GPU to unknown CPU + NCCL_CHECK(xmlFindTagKv(xml, "cpu", &parent, "numaid", "-1")); + if (parent == NULL) { + struct ncclXmlNode* topNode; + NCCL_CHECK(xmlFindTag(xml, "system", &topNode)); + NCCL_CHECK(xmlAddNode(xml, topNode, "cpu", &parent)); + NCCL_CHECK(xmlSetAttrLong(parent, "host_hash", getHostHash())); + NCCL_CHECK(xmlSetAttr(parent, "numaid", "-1")); + NCCL_CHECK(ncclTopoGetXmlFromCpu(parent, xml)); + } + } + pciNode->parent = parent; + // Keep PCI sub devices ordered by PCI Bus ID (Issue #820) + int subIndex = parent->nSubs; + const char* newBusId; + NCCL_CHECK(xmlGetAttrStr(pciNode, "busid", &newBusId)); + for (int s=0; snSubs; s++) { + const char* busId; + NCCL_CHECK(xmlGetAttr(parent->subs[s], "busid", &busId)); + if (busId != NULL && strcmp(newBusId, busId) < 0) { subIndex = s; break; } + } + if (parent->nSubs == MAX_SUBS) { + // WARN("Error : XML parser is limited to %d subnodes", MAX_SUBS); + return ncclInternalError; + } + for (int s = parent->nSubs; s > subIndex; s--) parent->subs[s] = parent->subs[s-1]; + parent->subs[subIndex] = pciNode; + parent->nSubs++; + } + if (strcmp(parent->name, "pci") == 0) { + NCCL_CHECK(ncclTopoGetXmlFromSys(parent, xml)); + } else if (strcmp(parent->name, "cpu") == 0) { + NCCL_CHECK(ncclTopoGetXmlFromCpu(parent, xml)); + } + free(path); + return ncclSuccess; + } + + ncclResult_t ncclGetSystemId(struct ncclTopoSystem* system, struct ncclXmlNode* xmlCpu, int* systemIdPtr) { + const char* hostHashStr; + NCCL_CHECK(xmlGetAttr(xmlCpu, "host_hash", &hostHashStr)); + uint64_t hostHash = hostHashStr ? strtoull(hostHashStr, NULL, 16) : 0; + int systemId; + for (systemId=0; systemIdnHosts; systemId++) if (system->hostHashes[systemId] == hostHash) break; + if (systemId == system->nHosts) system->hostHashes[system->nHosts++] = hostHash; + *systemIdPtr = systemId; + return ncclSuccess; + } + + ncclResult_t ncclTopoGetLocal(struct ncclTopoSystem* system, int type, int index, int resultType, + int* locals, int* localCount, int* pathType) { + int minType = PATH_DIS; + float maxBw = 0; + int count = 0; + struct ncclTopoLinkList* paths = system->nodes[type].nodes[index].paths[resultType]; + if (paths == NULL) { *localCount = 0; return ncclInternalError; } + for (int i=0; inodes[resultType].count; i++) { + if (paths[i].bw > maxBw || (paths[i].bw == maxBw && paths[i].type < minType)) { + maxBw = paths[i].bw; + minType = paths[i].type; + if (pathType) *pathType = minType; + count = 0; + } + if (paths[i].bw == maxBw && paths[i].type == minType) { + if (count == NCCL_TOPO_MAX_NODES) { + return ncclInternalError; + } + locals[count++] = i; + } + } + *localCount = count; + // int minType = PATH_DIS; + // float maxBw = 0; + // int count = 0; + // NCCL_CHECK(ncclCalloc(locals, system->nodes[resultType].count)); + // struct ncclTopoLinkList* paths = system->nodes[type].nodes[index].paths[resultType]; + // for (int i=0; inodes[resultType].count; i++) { + // if (paths[i].bw > maxBw || (paths[i].bw == maxBw && paths[i].type < minType)) { + // maxBw = paths[i].bw; + // minType = paths[i].type; + // if (pathType) *pathType = minType; + // count = 0; + // } + // if (paths[i].bw == maxBw && paths[i].type == minType) (*locals)[count++] = i; + // } + // *localCount = count; + return ncclSuccess; + } + + ncclResult_t ncclTopoGetInterCpuBw(struct ncclTopoNode* cpu, float* bw) { + *bw = LOC_BW; + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_POWER) { + *bw = P9_BW; + return ncclSuccess; + } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_ARM) { + *bw = ARM_BW; + return ncclSuccess; + } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_INTEL) { + *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_SKL ? SKL_QPI_BW : QPI_BW; + } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_AMD) { + *bw = AMD_BW; + } + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86 && cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_ZHAOXIN) { + *bw = cpu->cpu.model == NCCL_TOPO_CPU_TYPE_YONGFENG ? YONGFENG_ZPI_BW : ZPI_BW; + } + return ncclSuccess; + } + + ncclResult_t ncclTopoConnectNodes(struct ncclTopoNode* node, struct ncclTopoNode* remNode, int type, float bw) { + // Aggregate links into higher bw for NVLink + struct ncclTopoLink* link; + for (link = node->links; link - node->links != NCCL_TOPO_MAX_LINKS && link->remNode; link++) { + if (link->remNode == remNode && link->type == type) break; + } + if (link - node->links == NCCL_TOPO_MAX_LINKS) { + // WARN("Error : too many Topo links (max %d)", NCCL_TOPO_MAX_LINKS); + return ncclInternalError; + } + if (link->remNode == NULL) node->nlinks++; + link->type = type; + link->remNode = remNode; + link->bw += bw; + + // Sort links in BW descending order + struct ncclTopoLink linkSave; + memcpy(&linkSave, link, sizeof(struct ncclTopoLink)); + while (link != node->links) { + if ((link-1)->bw >= linkSave.bw) break; + memcpy(link, link-1, sizeof(struct ncclTopoLink)); + link--; + } + memcpy(link, &linkSave, sizeof(struct ncclTopoLink)); + return ncclSuccess; + } + + ncclResult_t ncclTopoConnectCpus(struct ncclTopoSystem* system) { + // And connect all CPU nodes together + for (int n=0; nnodes[CPU].count; n++) { + struct ncclTopoNode* cpu1 = system->nodes[CPU].nodes+n; + for (int p=0; pnodes[CPU].count; p++) { + struct ncclTopoNode* cpu2 = system->nodes[CPU].nodes+p; + if (n == p || (NCCL_TOPO_ID_SYSTEM_ID(cpu1->id) != NCCL_TOPO_ID_SYSTEM_ID(cpu2->id))) continue; + float bw; + NCCL_CHECK(ncclTopoGetInterCpuBw(cpu1, &bw)); + NCCL_CHECK(ncclTopoConnectNodes(cpu1, cpu2, LINK_SYS, bw)); + } + } + return ncclSuccess; + } + + ncclResult_t ncclTopoCreateNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id) { + if (system->nodes[type].count == NCCL_TOPO_MAX_NODES) { + // WARN("Error : tried to create too many nodes of type %d", type); + return ncclInternalError; + } + struct ncclTopoNode* n = system->nodes[type].nodes+system->nodes[type].count; + system->nodes[type].count++; + n->type = type; + n->id = id; + if (type == GPU) { + // Create link to itself (used in some corner cases) + n->nlinks=1; + n->links[0].type = LINK_LOC; + n->links[0].remNode = n; + n->links[0].bw = LOC_BW; + n->gpu.dev = NCCL_TOPO_UNDEF; + n->gpu.rank = NCCL_TOPO_UNDEF; + n->gpu.cudaCompCap = NCCL_TOPO_UNDEF; + } else if (type == CPU) { + n->cpu.arch = NCCL_TOPO_UNDEF; + n->cpu.vendor = NCCL_TOPO_UNDEF; + n->cpu.model = NCCL_TOPO_UNDEF; + } else if (type == NET) { + n->net.asic = 0ULL; + n->net.port = NCCL_TOPO_UNDEF; + n->net.bw = 0.0; + n->net.latency = 0.0; + } + *node = n; + return ncclSuccess; + } + + ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id) { + for (int i=0; inodes[type].count; i++) { + if (system->nodes[type].nodes[i].id == id) { + *node = system->nodes[type].nodes+i; + return ncclSuccess; + } + } + return ncclSuccess; + } + + ncclResult_t ncclTopoGetPciNode(struct ncclXml* xml, const char* busId, struct ncclXmlNode** pciNode) { + NCCL_CHECK(xmlFindTagKv(xml, "pci", pciNode, "busid", busId)); + if (*pciNode == NULL) { + NCCL_CHECK(xmlAddNode(xml, NULL, "pci", pciNode)); + NCCL_CHECK(xmlSetAttr(*pciNode, "busid", busId)); + } + return ncclSuccess; + } + + ncclResult_t ncclTopoAddGpu(struct ncclXmlNode* xmlGpu, struct ncclTopoSystem* system, struct ncclTopoNode* gpu) { + NCCL_CHECK(xmlGetAttrInt(xmlGpu, "sm", &gpu->gpu.cudaCompCap)); + NCCL_CHECK(xmlGetAttrInt(xmlGpu, "rank", &gpu->gpu.rank)); + NCCL_CHECK(xmlGetAttrInt(xmlGpu, "dev", &gpu->gpu.dev)); + NCCL_CHECK(xmlGetAttrInt(xmlGpu, "gdr", &gpu->gpu.gdrSupport)); + // Do not go any further, nvlinks will be added in a second pass + return ncclSuccess; + } + + ncclResult_t ncclTopoAddNet(struct ncclXmlNode* xmlNet, struct ncclTopoSystem* system, struct ncclTopoNode* nic, int systemId) { + int dev; + NCCL_CHECK(xmlGetAttrInt(xmlNet, "dev", &dev)); + struct ncclTopoNode* net; + NCCL_CHECK(ncclTopoCreateNode(system, &net, NET, NCCL_TOPO_ID(systemId, dev))); + net->net.dev = dev; + const char* str; + NCCL_CHECK(xmlGetAttr(xmlNet, "guid", &str)); + if (str) sscanf(str, "0x%lx", &net->net.asic); + else net->net.asic = dev; + // ncclDebugNoWarn = NCCL_GRAPH; + int mbps; + NCCL_CHECK(xmlGetAttrIntDefault(xmlNet, "speed", &mbps, 0)); + if (mbps <= 0) mbps = 10000; // Some NICs define speed = -1 + net->net.bw = mbps / 8000.0; + if (xmlGetAttrFloat(xmlNet, "latency", &net->net.latency) != ncclSuccess) net->net.latency = 0; + NCCL_CHECK(xmlGetAttrIntDefault(xmlNet, "port", &net->net.port, 0)); + NCCL_CHECK(xmlGetAttrIntDefault(xmlNet, "gdr", &net->net.gdrSupport, 0)); + NCCL_CHECK(xmlGetAttrIntDefault(xmlNet, "maxconn", &net->net.maxChannels, MAXCHANNELS)); + NCCL_CHECK(xmlGetAttrIntDefault(xmlNet, "coll", &net->net.collSupport, 0)); + NCCL_CHECK(xmlGetAttrStr(xmlNet, "name", &net->net.name)); + // ncclDebugNoWarn = 0; + NCCL_CHECK(ncclTopoConnectNodes(nic, net, LINK_NET, net->net.bw)); + NCCL_CHECK(ncclTopoConnectNodes(net, nic, LINK_NET, net->net.bw)); + return ncclSuccess; + } + + ncclResult_t ncclTopoAddNic(struct ncclXmlNode* xmlNic, struct ncclTopoSystem* system, struct ncclTopoNode* nic, int systemId) { + for (int s=0; snSubs; s++) { + struct ncclXmlNode* xmlNet = xmlNic->subs[s]; + if (strcmp(xmlNet->name, "net") != 0) continue; + int index; + NCCL_CHECK(xmlGetAttrIndex(xmlNet, "dev", &index)); + if (index == -1) continue; + NCCL_CHECK(ncclTopoAddNet(xmlNet, system, nic, systemId)); + } + return ncclSuccess; + } + + ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* system, struct ncclTopoNode* parent, int systemId) { + const char* str; + int type; + NCCL_CHECK(xmlGetAttrStr(xmlPci, "class", &str)); + NCCL_CHECK(kvConvertToInt(str, &type, kvDictPciClass)); + int64_t busId; + NCCL_CHECK(xmlGetAttrStr(xmlPci, "busid", &str)); + NCCL_CHECK(busIdToInt64(str, &busId)); + struct ncclTopoNode* node = NULL; + struct ncclXmlNode* xmlGpu = NULL; + NCCL_CHECK(xmlGetSub(xmlPci, "gpu", &xmlGpu)); + if (xmlGpu != NULL) { + type = GPU; + int index; + NCCL_CHECK(xmlGetAttrIndex(xmlGpu, "rank", &index)); + if (index == -1) return ncclSuccess; + NCCL_CHECK(ncclTopoCreateNode(system, &node, type, NCCL_TOPO_ID(systemId, busId))); + NCCL_CHECK(ncclTopoAddGpu(xmlGpu, system, node)); + } + struct ncclXmlNode* xmlNic = NULL; + NCCL_CHECK(xmlGetSub(xmlPci, "nic", &xmlNic)); + if (xmlNic != NULL) { + type = NIC; + // Ignore sub device ID and merge multi-port NICs into one PCI device. + busId &= 0xfffffffffffffff0; + struct ncclTopoNode* nicNode = NULL; + int64_t id = NCCL_TOPO_ID(systemId, busId); + NCCL_CHECK(ncclTopoGetNode(system, &nicNode, type, id)); + if (nicNode == NULL) { + NCCL_CHECK(ncclTopoCreateNode(system, &nicNode, type, id)); + node = nicNode; // Connect it to parent later on + } + NCCL_CHECK(ncclTopoAddNic(xmlNic, system, nicNode, systemId)); + } else if (type == PCI) { + NCCL_CHECK(ncclTopoCreateNode(system, &node, type, NCCL_TOPO_ID(systemId, busId))); + NCCL_CHECK(xmlGetAttr(xmlPci, "vendor", &str)); + if (str) node->pci.device += strtol(str, NULL, 0) << 48; + NCCL_CHECK(xmlGetAttr(xmlPci, "device", &str)); + if (str) node->pci.device += strtol(str, NULL, 0) << 32; + NCCL_CHECK(xmlGetAttr(xmlPci, "subsystem_vendor", &str)); + if (str) node->pci.device += strtol(str, NULL, 0) << 16; + NCCL_CHECK(xmlGetAttr(xmlPci, "subsystem_device", &str)); + if (str) node->pci.device += strtol(str, NULL, 0); + for (int s=0; snSubs; s++) { + struct ncclXmlNode* xmlSubPci = xmlPci->subs[s]; + NCCL_CHECK(ncclTopoAddPci(xmlSubPci, system, node, systemId)); + } + } + if (node) { + int width, speed; + NCCL_CHECK(xmlGetAttrInt(xmlPci, "link_width", &width)); + NCCL_CHECK(xmlGetAttrStr(xmlPci, "link_speed", &str)); + + // Manage cases where speed was not indicated in /sys + if (width == 0) width = 16; + NCCL_CHECK(kvConvertToInt(str, &speed, kvDictPciGen)); // Values in 100Mbps, per lane (we want GB/s in the end) + + NCCL_CHECK(ncclTopoConnectNodes(node, parent, LINK_PCI, width*speed/80.0)); + NCCL_CHECK(ncclTopoConnectNodes(parent, node, LINK_PCI, width*speed/80.0)); + } + return ncclSuccess; + } + + ncclResult_t ncclTopoAddCpu(struct ncclXmlNode* xmlCpu, struct ncclTopoSystem* system) { + int numaId; + NCCL_CHECK(xmlGetAttrInt(xmlCpu, "numaid", &numaId)); + int systemId; + NCCL_CHECK(ncclGetSystemId(system, xmlCpu, &systemId)); + struct ncclTopoNode* cpu; + NCCL_CHECK(ncclTopoCreateNode(system, &cpu, CPU, NCCL_TOPO_ID(systemId, numaId))); + const char* str; + NCCL_CHECK(xmlGetAttr(xmlCpu, "affinity", &str)); + if (str != NULL) { + NCCL_CHECK(ncclStrToCpuset(str, &cpu->cpu.affinity)); + } + + NCCL_CHECK(xmlGetAttrStr(xmlCpu, "arch", &str)); + NCCL_CHECK(kvConvertToInt(str, &cpu->cpu.arch, kvDictCpuArch)); + if (cpu->cpu.arch == NCCL_TOPO_CPU_ARCH_X86) { + NCCL_CHECK(xmlGetAttrStr(xmlCpu, "vendor", &str)); + NCCL_CHECK(kvConvertToInt(str, &cpu->cpu.vendor, kvDictCpuVendor)); + if (cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_INTEL) { + int familyId, modelId; + NCCL_CHECK(xmlGetAttrInt(xmlCpu, "familyid", &familyId)); + NCCL_CHECK(xmlGetAttrInt(xmlCpu, "modelid", &modelId)); + cpu->cpu.model = (familyId == 6 && modelId >= 0x55) ? NCCL_TOPO_CPU_TYPE_SKL : NCCL_TOPO_CPU_INTEL_BDW; + } else if (cpu->cpu.vendor == NCCL_TOPO_CPU_VENDOR_ZHAOXIN) { + int familyId, modelId; + NCCL_CHECK(xmlGetAttrInt(xmlCpu, "familyid", &familyId)); + NCCL_CHECK(xmlGetAttrInt(xmlCpu, "modelid", &modelId)); + if (familyId == 7 && modelId == 0x5B) cpu->cpu.model = NCCL_TOPO_CPU_TYPE_YONGFENG; + } + } + for (int s=0; snSubs; s++) { + struct ncclXmlNode* node = xmlCpu->subs[s]; + if (strcmp(node->name, "pci") == 0) NCCL_CHECK(ncclTopoAddPci(node, system, cpu, systemId)); + if (strcmp(node->name, "nic") == 0) { + struct ncclTopoNode* nic = NULL; + NCCL_CHECK(ncclTopoGetNode(system, &nic, NIC, 0)); + if (nic == NULL) { + NCCL_CHECK(ncclTopoCreateNode(system, &nic, NIC, NCCL_TOPO_ID(systemId, 0))); + NCCL_CHECK(ncclTopoConnectNodes(cpu, nic, LINK_PCI, LOC_BW)); + NCCL_CHECK(ncclTopoConnectNodes(nic, cpu, LINK_PCI, LOC_BW)); + } + NCCL_CHECK(ncclTopoAddNic(node, system, nic, systemId)); + } + } + return ncclSuccess; + } + + ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode) { + struct ncclXmlNode* node; + NCCL_CHECK(ncclTopoGetPciNode(xml, busId, &node)); + NCCL_CHECK(xmlSetAttrIfUnset(node, "class", "0x03")); + NCCL_CHECK(ncclTopoGetXmlFromSys(node, xml)); + nvmlDevice_t nvmlDev; + CALL_CHECK(nvmlDeviceGetHandleByPciBusId(busId, &nvmlDev)); + NCCL_CHECK(ncclTopoGetXmlFromGpu(node, nvmlDev, xml, gpuNode)); + return ncclSuccess; + } + + ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode) { + NCCL_CHECK(xmlFindTagKv(xml, "net", netNode, "name", netName)); + if (*netNode != NULL) return ncclSuccess; + const char* pciSysPath = pciPath; + if (pciSysPath) { + char subSystem[PATH_MAX]; + NCCL_CHECK(ncclTopoGetSubsystem(pciSysPath, subSystem)); + // This is not a PCI device (virtual, usb, ...). + if (strcmp(subSystem, "pci") != 0) { + // INFO(NCCL_GRAPH, "Topology detection: network path %s is not a PCI device (%s). Attaching to first CPU", pciSysPath, subSystem); + pciSysPath = NULL; + } + } + struct ncclXmlNode* parent = NULL; + if (pciSysPath) { + int offset; + for (offset=strlen(pciSysPath)-1; pciSysPath[offset] != '/'; offset--); + char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + strcpy(busId, pciSysPath+offset+1); + NCCL_CHECK(ncclTopoGetPciNode(xml, busId, &parent)); + NCCL_CHECK(xmlSetAttrIfUnset(parent, "class", "0x02")); + NCCL_CHECK(ncclTopoGetXmlFromSys(parent, xml)); + } else { + // Virtual NIC, no PCI device, attach to first CPU + NCCL_CHECK(xmlFindTag(xml, "cpu", &parent)); + } + struct ncclXmlNode* nicNode = NULL; + NCCL_CHECK(xmlGetSub(parent, "nic", &nicNode)); + if (nicNode == NULL) { + NCCL_CHECK(xmlAddNode(xml, parent, "nic", &nicNode)); + } + // We know that this net does not exist yet (we searched for it at the + // beginning of this function), so we can add it. + NCCL_CHECK(xmlAddNode(xml, nicNode, "net", netNode)); + NCCL_CHECK(xmlSetAttr(*netNode, "name", netName)); + return ncclSuccess; + } + + ncclResult_t ncclTopoSetPaths(struct ncclTopoNode* baseNode, struct ncclTopoSystem* system) { + if (baseNode->paths[baseNode->type] == NULL) { + NCCL_CHECK(ncclCalloc(baseNode->paths+baseNode->type, system->nodes[baseNode->type].count)); + } + // breadth-first search to set all paths to that node in the system + struct ncclTopoNodeList nodeList; + struct ncclTopoNodeList nextNodeList; + nodeList.count = 1; nodeList.list[0] = baseNode; + nextNodeList.count = 0; + struct ncclTopoLinkList* basePath; + NCCL_CHECK(getPath(system, baseNode, baseNode->type, baseNode->id, &basePath)); + basePath->count = 0; + basePath->bw = LOC_BW; + basePath->type = PATH_LOC; + + while (nodeList.count) { + nextNodeList.count = 0; + for (int n=0; ntype, baseNode->id, &path)); + for (int l=0; lnlinks; l++) { + struct ncclTopoLink* link = node->links+l; + struct ncclTopoNode* remNode = link->remNode; + if (remNode->paths[baseNode->type] == NULL) { + NCCL_CHECK(ncclCalloc(remNode->paths+baseNode->type, system->nodes[baseNode->type].count)); + for (int i=0; inodes[baseNode->type].count; i++) remNode->paths[baseNode->type][i].type = PATH_DIS; + } + struct ncclTopoLinkList* remPath; + NCCL_CHECK(getPath(system, remNode, baseNode->type, baseNode->id, &remPath)); + float bw = std::min(path->bw, link->bw); + + // allow routing through a GPU only as 1 hop + if (node != baseNode && node->type == GPU && + (link->type != LINK_NVL || remNode->type != GPU || path->count > 1)) continue; + + if ((remPath->bw == 0 || remPath->count > path->count) && remPath->bw < bw) { + // Find reverse link + for (int l=0; lnlinks; l++) { + if (remNode->links[l].remNode == node && remNode->links[l].type == link->type) { + remPath->list[0] = remNode->links+l; + break; + } + } + if (remPath->list[0] == NULL) { + // WARN("Failed to find reverse path from remNode %d/%lx nlinks %d to node %d/%lx", + // remNode->type, remNode->id, remNode->nlinks, node->type, node->id); + return ncclInternalError; + } + // Copy the rest of the path + for (int i=0; icount; i++) remPath->list[i+1] = path->list[i]; + remPath->count = path->count + 1; + remPath->bw = bw; + // Start with path type = link type. PATH and LINK types are supposed to match. + // Don't consider LINK_NET as we only care about the NIC->GPU path. + int type = link->type == LINK_NET ? LINK_LOC : link->type; + // Differentiate between one and multiple PCI switches + if (node->type == PCI && remNode->type == PCI) type = PATH_PXB; + // Consider a path going through the CPU as PATH_PHB + if (link->type == LINK_PCI && (node->type == CPU || link->remNode->type == CPU)) type = PATH_PHB; + // Set 1 hop NVLink as NVB + if (node->type == GPU && path->type == PATH_NVL && type == PATH_NVL && remPath->count > 1) type = PATH_NVB; + remPath->type = std::max(path->type, type); + // Add to the list for the next iteration if not already in the list + int i; + for (i=0; inodes[GPU].count; i++) { + if (system->nodes[GPU].nodes[i].gpu.rank == rank) { + *index = i; + return ncclSuccess; + } + } + return ncclInternalError; + } + + ncclResult_t ncclTopoSort(struct ncclTopoNode* node, struct ncclTopoNode* upNode) { + // Shift all links to have upLink as last link + if (upNode) { + int l=0; + while (node->links[l].remNode != upNode) l++; + struct ncclTopoLink upLink; + memcpy(&upLink, node->links+l, sizeof(struct ncclTopoLink)); + while (node->links[l+1].remNode) { + memcpy(node->links+l, node->links+l+1, sizeof(struct ncclTopoLink)); + l++; + } + memcpy(node->links+l, &upLink, sizeof(struct ncclTopoLink)); + } + // Recursively sort the PCI tree + for (int l=0; lnlinks; l++) { + struct ncclTopoLink* link = node->links+l; + if (link->type == LINK_PCI && link->remNode != upNode) NCCL_CHECK(ncclTopoSort(link->remNode, node)); + } + return ncclSuccess; + } + + ncclResult_t ncclTopoSortSystem(struct ncclTopoSystem* system) { + for (int n=0; nnodes[CPU].count; n++) NCCL_CHECK(ncclTopoSort(system->nodes[CPU].nodes+n, NULL)); + return ncclSuccess; + } + + int construct_xml_stru(struct ncclXml **xml, const std::vector &local_rank_vec) { + nvmlInit(); + assert(local_rank_vec.size() < MAX_SUBS); + xmlAlloc(xml, NCCL_GRAPH_XML_MAX_NODES); + struct ncclXmlNode* top; + xmlAddNode(*xml, NULL, "system", &top); + for (int idx = 0; idx < local_rank_vec.size(); ++idx) { + char busIdStr[] = "00000000:00:00.0"; + cudaDeviceGetPCIBusId(busIdStr, sizeof(busIdStr), local_rank_vec[idx]); + struct ncclXmlNode *node; + ncclTopoFillGpu(*xml, busIdStr, &node); + xmlSetAttrInt(node, "rank", idx); + xmlSetAttrInt(node, "gdr", 1); + } + ibv_fork_init(); + int num_of_device; + struct ibv_device **dev_list; + struct ibv_device *ib_dev = nullptr; + dev_list = ibv_get_device_list(&num_of_device); + int nic_idx = 0; + for (; ib_dev = *dev_list; ++dev_list) { + struct ibv_context *context = ibv_open_device(ib_dev); + struct ibv_device_attr devAttr; + memset(&devAttr, 0, sizeof(devAttr)); + ibv_query_device(context, &devAttr); + int port_cnt = devAttr.phys_port_cnt; + for (int port_idx = 1; port_idx <= port_cnt; ++port_idx) { + struct ibv_port_attr portAttr; + ibv_query_port(context, port_idx, &portAttr); + if (portAttr.state != IBV_PORT_ACTIVE) continue; + if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND + && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) continue; + char *pciPath = nullptr; + ncclIbGetPciPath(ib_dev->name, &pciPath); + struct ncclXmlNode* netNode; + ncclTopoFillNet(*xml, pciPath, ib_dev->name, &netNode); + xmlSetAttrInt(netNode, "keep", 1); + xmlSetAttrInt(netNode, "dev", nic_idx++); + xmlInitAttrInt(netNode, "speed", ncclIbSpeed(portAttr.active_speed) * ncclIbWidth(portAttr.active_width)); + xmlInitAttrInt(netNode, "port", port_idx); + xmlInitAttrFloat(netNode, "latency", 0.0f); + xmlInitAttrUint64(netNode, "guid", devAttr.sys_image_guid); + xmlInitAttrInt(netNode, "gdr", 1); + break; + } + ibv_close_device(context); + } + // ibv_free_device_list(dev_list); + return 0; + } + + int construct_topo_stru(struct ncclXml* xml, struct ncclTopoSystem** topoSys) { + *topoSys = (ncclTopoSystem *)calloc(1, sizeof(ncclTopoSystem)); + [[maybe_unused]] struct ncclTopoSystem *system = *topoSys; + struct ncclXmlNode* topNode; + xmlFindTag(xml, "system", &topNode); + for (int s=0; snSubs; s++) { + struct ncclXmlNode* node = topNode->subs[s]; + if (strcmp(node->name, "cpu") == 0) ncclTopoAddCpu(node, *topoSys); + } + ncclTopoConnectCpus(*topoSys); + ncclTopoSortSystem(*topoSys); + return 0; + } + + int compute_paths(struct ncclTopoSystem *topoSys) { + for (int idx = 0; idx < topoSys->nodes[GPU].count; ++idx) { + ncclTopoSetPaths(topoSys->nodes[GPU].nodes + idx, topoSys); + } + for (int idx = 0; idx < topoSys->nodes[NET].count; ++idx) { + ncclTopoSetPaths(topoSys->nodes[NET].nodes + idx, topoSys); + } + return 0; + } + + int select_net(struct ncclTopoSystem* system, int gpuDev, const char **netStr) { + int idx = 0; + int localNets = 0; + int localGpus = 0; + int tempLocalGpus = 0; + int gpuIdxInVec = -1; + int localNetIndexes[NCCL_TOPO_MAX_NODES]; + int localGpuIndexes[NCCL_TOPO_MAX_NODES]; + int tempLocalGpuIndexes[NCCL_TOPO_MAX_NODES]; + int netPathType = 0; + int gpuPathType = 0; + ncclTopoRankToIndex(system, gpuDev, &idx); + ncclTopoGetLocal(system, GPU, idx, NET, localNetIndexes, &localNets, &netPathType); + for (int idx = 0; idx < localNets; ++idx) { + ncclTopoGetLocal(system, NET, localNetIndexes[idx], GPU, tempLocalGpuIndexes, &tempLocalGpus, &gpuPathType); + for (int new_idx = 0; new_idx < tempLocalGpus; ++new_idx) { + for (int curr_idx = 0; curr_idx < localGpus; ++curr_idx) { + if (localGpuIndexes[curr_idx] == tempLocalGpuIndexes[new_idx]) continue; + } + localGpuIndexes[localGpus++] = tempLocalGpuIndexes[new_idx]; + } + } + for (int idx = 0; idx < localGpus; ++idx) { + if (gpuDev == localGpuIndexes[idx]) gpuIdxInVec = idx; + } + *netStr = system->nodes[NET].nodes[localNetIndexes[gpuIdxInVec % localNets]].net.name; + return 0; + } + } //namespace + + inline int get_nic(const std::vector &gpu_idx_vec, int gpu_idx, const char **netName) { + struct ncclXml *xml; + construct_xml_stru(&xml, gpu_idx_vec); + struct ncclTopoSystem *topoSys; + construct_topo_stru(xml, &topoSys); + compute_paths(topoSys); + select_net(topoSys, gpu_idx, netName); + return 0; + } + } //namespace hybrid_ep + #endif //HYBRID_EP_BUILD_MULTINODE_ENABLE + \ No newline at end of file diff --git a/csrc/hybrid_ep/backend/utils.cuh b/csrc/hybrid_ep/backend/utils.cuh index 39862b9e..95e3eea5 100644 --- a/csrc/hybrid_ep/backend/utils.cuh +++ b/csrc/hybrid_ep/backend/utils.cuh @@ -289,4 +289,53 @@ inline void print_ptr_info(void* p) { CUresult r = cuMemGetAddressRange(&base, &size, reinterpret_cast(p)); fprintf(stderr, "alloc_base=%p, alloc_size=%zu bytes\n", reinterpret_cast(base), size); } -} \ No newline at end of file +} + +/* Error type */ +typedef enum { + ncclSuccess = 0, + ncclUnhandledCudaError = 1, + ncclSystemError = 2, + ncclInternalError = 3, + ncclInvalidArgument = 4, + ncclInvalidUsage = 5, + ncclRemoteError = 6, + ncclInProgress = 7, + ncclNumResults = 8 +} ncclResult_t; + +#define NCCL_CHECK(call) \ + do { \ + ncclResult_t RES = call; \ + if (RES != ncclSuccess && RES != ncclInProgress) { \ + /* Print the back trace*/ \ + fprintf(stderr, "%s:%d -> %d\n", __FILE__, __LINE__, RES); \ + return RES; \ + } \ + } while (0) + +template +ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) { + void* p = malloc(nelem * sizeof(T)); + if (p == NULL) { + // WARN("Failed to malloc %ld bytes", nelem*sizeof(T)); + return ncclSystemError; + } + // INFO(NCCL_ALLOC, "%s:%d malloc Size %ld pointer %p", filefunc, line, + // nelem*sizeof(T), p); + memset(p, 0, nelem * sizeof(T)); + *ptr = (T*)p; + return ncclSuccess; +} +#define ncclCalloc(...) ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__) + +#define CALL_CHECK(call) \ + do { \ + int result = call; \ + if (result != 0) { \ + fprintf(stderr, "file=%s, line=%d, call='%s', returned=%d.\n", \ + __FILE__, __LINE__, #call, result); \ + abort(); \ + } \ + } while(0) + diff --git a/csrc/hybrid_ep/config.cuh b/csrc/hybrid_ep/config.cuh index 0a041a63..c22fe7ba 100644 --- a/csrc/hybrid_ep/config.cuh +++ b/csrc/hybrid_ep/config.cuh @@ -20,6 +20,7 @@ struct BufferConfig { int num_of_blocks_preprocessing_api; int num_of_blocks_dispatch_api; int num_of_blocks_combine_api; + int num_of_blocks_permute_api; int num_of_tokens_per_chunk_dispatch_api; int num_of_tokens_per_chunk_combine_api; @@ -28,7 +29,11 @@ struct BufferConfig { */ bool is_valid(){ bool valid = true; - valid &= (hidden_dim % 512 == 0); + if (token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { + valid &= (hidden_dim % 512 == 0); // Make TMA work in scaling factor. + } else { + valid &= (hidden_dim % 16 == 0); // Make TMA work. + } valid &= ((num_of_experts_per_rank * num_of_ranks_per_node) % 4 == 0); valid &= (num_of_ranks_per_node % 2 == 0); return valid; @@ -51,12 +56,14 @@ struct HybridEpConfigInstance { */ int num_of_threads_per_block_preprocessing_api; int num_of_blocks_preprocessing_api; + int num_of_blocks_permute_api; /* * Dispatch API Config */ APP_TOKEN_DATA_TYPE token_data_type; int num_of_stages_dispatch_api; + int num_of_in_flight_s2g_dispatch_api; int num_of_tokens_per_chunk_dispatch_api; int num_of_blocks_dispatch_api; bool forward_dispatch_api; @@ -79,7 +86,11 @@ struct HybridEpConfigInstance { */ bool is_valid(){ bool valid = true; - valid &= (hidden_dim % 512 == 0); + if (token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { + valid &= (hidden_dim % 512 == 0); // Make TMA work in scaling factor. + } else { + valid &= (hidden_dim % 16 == 0); // Make TMA work. + } valid &= ((num_of_experts_per_rank * num_of_ranks_per_node) % 4 == 0); valid &= (num_of_ranks_per_node % 2 == 0); return valid; diff --git a/csrc/hybrid_ep/executor/executor.cu b/csrc/hybrid_ep/executor/executor.cu index 89b2fed5..dbe2fda6 100644 --- a/csrc/hybrid_ep/executor/executor.cu +++ b/csrc/hybrid_ep/executor/executor.cu @@ -30,8 +30,9 @@ Executor::metadata_preprocess_core( auto attn_to_rdma_map = torch::empty({num_of_tokens_per_rank, config.num_of_nodes - 1}, torch::dtype(torch::kBool).device(torch::kCUDA)); + // Put on the pinned memory auto num_of_tokens_for_experts = - torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + torch::empty({1}, torch::dtype(torch::kInt32).pinned_memory(true)); auto local_expert_routing_map = torch::empty( {num_of_tokens_per_rank * config.num_of_ranks_per_node * config.num_of_nodes, config.num_of_experts_per_rank}, torch::dtype(torch::kBool).device(torch::kCUDA)); @@ -48,7 +49,7 @@ Executor::metadata_preprocess_core( return std::make_tuple(sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map, num_of_tokens_for_experts, local_expert_routing_map); } -void Executor::dispatch_preprocess(HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args) { +std::tuple Executor::dispatch_preprocess(HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args) { nvtxRangePushA("dispatch_preprocess in hybrid-ep"); if(config.num_of_nodes > 1) { #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE @@ -69,7 +70,44 @@ void Executor::dispatch_preprocess(HybridEpConfigInstance config, DispatchBuffer dispatch_buffers.attn_input_prob = (config.forward_dispatch_api) ? args.probs.data_ptr() : nullptr; dispatch_buffers.attn_input_scaling_factor = (config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8) ? args.scaling_factor.data_ptr() : nullptr; } + + torch::Tensor row_id_map; + torch::Tensor tokens_per_expert; + + if(args.enable_permute) { + if(args.row_id_map.has_value()){ + assert(args.num_permuted_tokens >= 0); + row_id_map = args.row_id_map.value(); + } else { + assert(args.local_expert_routing_map.has_value()); + std::tie(row_id_map, tokens_per_expert) = permute_processing( + args.local_expert_routing_map.value().data_ptr(), + args.num_dispatched_tokens_tensor.value(), + args.max_num_dispatched_tokens, + config.num_of_experts_per_rank, + args.pad_multiple, + config.num_of_blocks_preprocessing_api, + args.stream + ); + args.row_id_map = row_id_map; + + // If we want to put the tokens_per_expert/num_dispatched_tokens_tensor can be used in the host, we need to synchronize the stream. + if (!args.non_blocking) { + cudaStreamSynchronize(args.stream); + if (args.num_permuted_tokens < 0) { + const int32_t* tokens_per_expert_ptr = tokens_per_expert.data_ptr(); + int64_t num_permuted_tokens = 0; + for (int i = 0; i < config.num_of_experts_per_rank; ++i) { + num_permuted_tokens += static_cast(tokens_per_expert_ptr[i]); + } + args.num_permuted_tokens = num_permuted_tokens; + } + } + } + } nvtxRangePop(); // End of dispatch_preprocess nvtx range + + return std::make_tuple(row_id_map, tokens_per_expert); } template void Executor::dispatch_core(HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args); @@ -114,8 +152,8 @@ void Executor::dispatch_core(HybridEpConfigInstance config, DispatchBuffers& dis param.expected_rdma_flag_value = dispatch_buffers.expected_rdma_flag_value; param.expected_intra_node_flag_value = dispatch_buffers.expected_intra_node_flag_value; #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE - param.d_qps_gpu = dispatch_buffers.d_qps_gpu; - param.mr_info = dispatch_buffers.mr_info; + param.d_qps_gpu = reinterpret_cast(dispatch_buffers.d_qps_gpu); + param.mr_info = reinterpret_cast(dispatch_buffers.mr_info); #endif // Launch kernel @@ -123,7 +161,7 @@ void Executor::dispatch_core(HybridEpConfigInstance config, DispatchBuffers& dis nvtxRangePop(); // End of dispatch_core nvtx range } -std::tuple, c10::optional, torch::Tensor, torch::Tensor> +std::tuple, c10::optional > Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args) { nvtxRangePushA("dispatch_postprocess in hybrid-ep"); @@ -132,91 +170,63 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d torch::Tensor dispatched_tokens; c10::optional dispatched_probs; c10::optional dispatched_scaling_factor; - // Possible ouput from the permute part - torch::Tensor row_id_map, tokens_per_expert; - - if(args.num_dispatched_tokens == 0 ) { - // Fast return empty tensors if there are no tokens to dispatch - dispatched_tokens = torch::empty({0, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA)); - if(config.forward_dispatch_api) { - dispatched_probs = torch::empty({0}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); - } - if(config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { - dispatched_scaling_factor = torch::empty({0, config.hidden_dim / 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); - } - row_id_map = torch::empty({0, config.num_of_experts_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - tokens_per_expert = torch::zeros({config.num_of_experts_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor, row_id_map, tokens_per_expert); - } if(args.enable_permute) { // Use permute kernel to avoid standalone D2D memory copy assert(args.num_dispatched_tokens_tensor.has_value()); - - int num_dispatched_tokens = args.num_dispatched_tokens; - int num_permuted_tokens = args.num_permuted_tokens; - torch::Tensor num_dispatched_tokens_tensor = args.num_dispatched_tokens_tensor.value(); - - if (args.row_id_map.has_value()) { - // The row_id_map is valid, which means that the cached model is used. - // Then we will use the values in args directly. - assert(args.num_permuted_tokens >= 0); - row_id_map = args.row_id_map.value(); - } else { - // Otherwise, we will compute the row_id_map/tokens_per_expert by preprocessing kernel. - assert(args.local_expert_routing_map.has_value()); - - std::tie(row_id_map, tokens_per_expert) = permute_processing( - args.local_expert_routing_map.value().data_ptr(), num_dispatched_tokens_tensor, - num_dispatched_tokens, config.num_of_experts_per_rank, args.pad_multiple, args.stream); + assert(args.row_id_map.has_value()); + assert(args.num_permuted_tokens >= 0); - // If use pre-allocated sync-free mode, we use the value in args directly. - // otherwise, we will compute the num_permuted_tokens by summing the tokens_per_expert. - if (num_permuted_tokens < 0) { - if (args.use_host_meta) { - auto host_opts = tokens_per_expert.options().device(torch::kCPU).pinned_memory(true); - torch::Tensor tokens_per_expert_pinned = torch::empty(tokens_per_expert.sizes(), host_opts); - tokens_per_expert_pinned.copy_(tokens_per_expert, /*non_blocking=*/false); - tokens_per_expert = tokens_per_expert_pinned; - } - num_permuted_tokens = tokens_per_expert.sum().item(); - } + // Prepare the arguments for the permute kernel + PermuteArgs permute_args; + permute_args.tokens_ptr = reinterpret_cast(dispatch_buffers.expert_output_token); + permute_args.probs_ptr = reinterpret_cast(dispatch_buffers.expert_output_prob); + permute_args.scaling_factor_ptr = reinterpret_cast(dispatch_buffers.expert_output_scaling_factor); + permute_args.row_id_map = args.row_id_map.value(); + permute_args.hidden_size = config.hidden_dim; + permute_args.scales_per_token = config.hidden_dim / 128; + permute_args.num_dispatched_token_tensor = args.num_dispatched_tokens_tensor.value(); + permute_args.num_permuted_token = args.num_permuted_tokens; + permute_args.num_ranks_per_node = config.num_of_ranks_per_node; + permute_args.num_of_local_experts = config.num_of_experts_per_rank; + permute_args.pad_multiple = args.pad_multiple; + permute_args.local_rank = local_rank; + permute_args.use_fp8 = config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8; + permute_args.with_probs = config.forward_dispatch_api; + permute_args.token_options = args.hidden.options(); + permute_args.stream = args.stream; + permute_args.num_of_blocks_permute_api = config.num_of_blocks_permute_api; + + if(config.token_data_type == APP_TOKEN_DATA_TYPE::UINT16) { + std::tie(dispatched_tokens, dispatched_scaling_factor, dispatched_probs) = + permute_launcher(permute_args); + } else if (config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { + std::tie(dispatched_tokens, dispatched_scaling_factor, dispatched_probs) = + permute_launcher(permute_args); + }else { + throw std::runtime_error("Unsupported token data type: " + type_to_string(config.token_data_type)); } - - if (config.token_data_type == APP_TOKEN_DATA_TYPE::UINT16) { - std::tie(dispatched_tokens, dispatched_scaling_factor, dispatched_probs) = permute_launcher( - reinterpret_cast(dispatch_buffers.expert_output_token), - reinterpret_cast(dispatch_buffers.expert_output_prob), - reinterpret_cast(dispatch_buffers.expert_output_scaling_factor), row_id_map, - config.hidden_dim, config.hidden_dim / 128, local_rank, config.num_of_ranks_per_node, - config.num_of_experts_per_rank, num_dispatched_tokens_tensor, num_dispatched_tokens, - num_permuted_tokens, args.pad_multiple, - false, // use_fp8 - config.forward_dispatch_api, args.hidden.options(), args.stream); - - } else { - std::tie(dispatched_tokens, dispatched_scaling_factor, dispatched_probs) = permute_launcher( - reinterpret_cast(dispatch_buffers.expert_output_token), - reinterpret_cast(dispatch_buffers.expert_output_prob), - reinterpret_cast(dispatch_buffers.expert_output_scaling_factor), row_id_map, - config.hidden_dim, config.hidden_dim / 128, local_rank, config.num_of_ranks_per_node, - config.num_of_experts_per_rank, num_dispatched_tokens_tensor, num_dispatched_tokens, - num_permuted_tokens, args.pad_multiple, - true, // use_fp8 - config.forward_dispatch_api, args.hidden.options(), args.stream); - } }else { // D2D copy the result to the pytorch tensor + int num_dispatched_tokens = 0; + if (args.num_dispatched_tokens >= 0) { + num_dispatched_tokens = args.num_dispatched_tokens; + } else { + num_dispatched_tokens = args.num_dispatched_tokens_tensor.value().item(); + } size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type); - dispatched_tokens = torch::empty({args.num_dispatched_tokens, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA)); - auto res_sz = args.num_dispatched_tokens * config.hidden_dim * sizeof_token_data_type; + dispatched_tokens = torch::empty( + {num_dispatched_tokens, config.hidden_dim}, + torch::dtype(args.hidden.dtype()).device(torch::kCUDA) + ); + auto res_sz = num_dispatched_tokens * config.hidden_dim * sizeof_token_data_type; CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), dispatch_buffers.expert_output_token, res_sz, cudaMemcpyDeviceToDevice, args.stream)); if(config.forward_dispatch_api) { - dispatched_probs = torch::empty({args.num_dispatched_tokens, + dispatched_probs = torch::empty({num_dispatched_tokens, config.num_of_experts_per_rank * config.num_of_ranks_per_node}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); - auto probs_sz = args.num_dispatched_tokens * config.num_of_experts_per_rank * config.num_of_ranks_per_node * sizeof(float); + auto probs_sz = num_dispatched_tokens * config.num_of_experts_per_rank * config.num_of_ranks_per_node * sizeof(float); CUDA_CHECK(cudaMemcpyAsync(dispatched_probs.value().data_ptr(), dispatch_buffers.expert_output_prob, probs_sz, cudaMemcpyDeviceToDevice, args.stream)); @@ -224,10 +234,10 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d if(config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { dispatched_scaling_factor = torch::empty({ - args.num_dispatched_tokens, + num_dispatched_tokens, config.hidden_dim / 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); - auto scaling_factor_sz = args.num_dispatched_tokens * config.hidden_dim / 128 * sizeof(float); + auto scaling_factor_sz = num_dispatched_tokens * config.hidden_dim / 128 * sizeof(float); CUDA_CHECK(cudaMemcpyAsync(dispatched_scaling_factor.value().data_ptr(), dispatch_buffers.expert_output_scaling_factor, scaling_factor_sz, cudaMemcpyDeviceToDevice, args.stream)); @@ -235,7 +245,7 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d } nvtxRangePop(); // End of dispatch_postprocess nvtx range - return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor, row_id_map, tokens_per_expert); + return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor); } void Executor::combine_preprocess(HybridEpConfigInstance config, CombineBuffers& combine_buffers, CombineArgs& args) { @@ -252,15 +262,28 @@ void Executor::combine_preprocess(HybridEpConfigInstance config, CombineBuffers& // If args.num_dispatched_tokens >= 0, which means that the sync-free model is used. // Otherwise, we will use the values in args.num_dispatched_tokens_tensor. if (num_dispatched_tokens < 0) { + // Synchronize the stream to get the real num_dispatched_tokens from the pinned memory. + cudaStreamSynchronize(args.stream); num_dispatched_tokens = num_dispatched_tokens_tensor.item(); } - unpermute_launcher( - args.hidden, args.probs, reinterpret_cast(combine_buffers.expert_input_token), - reinterpret_cast(combine_buffers.expert_input_prob), args.row_id_map.value(), - config.num_of_experts_per_rank, num_dispatched_tokens_tensor, num_dispatched_tokens, - args.pad_multiple, config.hidden_dim, local_rank, config.num_of_ranks_per_node, - config.backward_combine_api, args.stream); + UnpermuteArgs unpermute_args; + unpermute_args.permuted_tokens = args.hidden; + unpermute_args.permuted_probs = args.probs; + unpermute_args.tokens_ptr = reinterpret_cast(combine_buffers.expert_input_token); + unpermute_args.probs_ptr = reinterpret_cast(combine_buffers.expert_input_prob); + unpermute_args.row_id_map = args.row_id_map.value(); + unpermute_args.num_of_local_experts = config.num_of_experts_per_rank; + unpermute_args.num_dispatched_tokens_tensor = num_dispatched_tokens_tensor; + unpermute_args.pad_multiple = args.pad_multiple; + unpermute_args.hidden_size = config.hidden_dim; + unpermute_args.local_rank = local_rank; + unpermute_args.num_ranks_per_node = config.num_of_ranks_per_node; + unpermute_args.with_probs = config.backward_combine_api; + unpermute_args.stream = args.stream; + unpermute_args.num_of_blocks_permute_api = config.num_of_blocks_permute_api; + + unpermute_launcher(unpermute_args); }else{ // Copy the input tensor to the input buffer @@ -319,8 +342,8 @@ void Executor::combine_core(HybridEpConfigInstance config, CombineBuffers& combi param.expected_intra_node_flag_value = combine_buffers.expected_intra_node_flag_value; #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE - param.d_qps_gpu = combine_buffers.d_qps_gpu; - param.mr_info = combine_buffers.mr_info; + param.d_qps_gpu = reinterpret_cast(combine_buffers.d_qps_gpu); + param.mr_info = reinterpret_cast(combine_buffers.mr_info); #endif // Launch kernel @@ -334,3 +357,4 @@ void Executor::combine_postprocess(HybridEpConfigInstance config, CombineBuffers // No postprocess is needed for the combine kernel now. nvtxRangePop(); // End of combine_postprocess nvtx range } + diff --git a/csrc/hybrid_ep/executor/executor.cuh b/csrc/hybrid_ep/executor/executor.cuh index 916f0496..97793abd 100644 --- a/csrc/hybrid_ep/executor/executor.cuh +++ b/csrc/hybrid_ep/executor/executor.cuh @@ -26,15 +26,16 @@ public: torch::Tensor attn_to_rdma_map; c10::optional num_dispatched_tokens_tensor; // Used in the permute c10::optional local_expert_routing_map; // Used in the permute - // Used in the sync-free permute + int64_t num_dispatched_tokens = -1; - // Cached permute + // Used in the permute case, use up-bound to avoid synchronization to get the real num_dispatched_tokens from the pinned memory + int64_t max_num_dispatched_tokens = -1; c10::optional row_id_map; int64_t num_permuted_tokens = -1; // Misc int pad_multiple; // Used in the padding case of permute bool enable_permute = false; - bool use_host_meta = false; // If enable this, the produced num_dispatched_tokens will be put + bool non_blocking = false; // If enable this, the produced num_dispatched_tokens will be put // on the CPU pinned memory, and the tokens_per_expert will be put // on the CPU, which may reduce the times of the sync int64_t num_of_tokens_per_rank; // Dynamic sequence length @@ -72,12 +73,13 @@ public: int num_of_tokens_per_rank ); - void dispatch_preprocess( - HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args); // Now is empty op, will be filled with D2D in the inter-node case + std::tuple + dispatch_preprocess( + HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args); template void dispatch_core( HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args); - std::tuple, c10::optional, torch::Tensor, torch::Tensor> + std::tuple, c10::optional > dispatch_postprocess( HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args); @@ -86,11 +88,12 @@ public: void combine_core( HybridEpConfigInstance config, CombineBuffers& combine_buffers, CombineArgs& args); void combine_postprocess( - HybridEpConfigInstance config, CombineBuffers& combine_buffers, CombineArgs& args); // Now is empty op, will be filled with D2D in the inter-node case + HybridEpConfigInstance config, CombineBuffers& combine_buffers, CombineArgs& args); private: KernelCache kernel_cache; HybridEpConfigInstance config; int local_rank; int node_rank; -}; \ No newline at end of file +}; + diff --git a/csrc/hybrid_ep/extension/permute.cu b/csrc/hybrid_ep/extension/permute.cu index 051b7dc7..db2aa02f 100644 --- a/csrc/hybrid_ep/extension/permute.cu +++ b/csrc/hybrid_ep/extension/permute.cu @@ -4,57 +4,13 @@ #include "permute.cuh" template std::tuple, c10::optional> - permute_launcher(uint16_t* tokens_ptr, - float* probs_ptr, - float* scaling_factor_ptr, - torch::Tensor row_id_map, - int hidden_size, - int scales_per_token, - int local_rank, - int num_ranks_per_node, - int num_of_local_experts, - torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, - int num_permuted_token, - int pad_multiple, - bool use_fp8, - bool with_probs, - torch::TensorOptions token_options, - cudaStream_t stream); + permute_launcher(PermuteArgs args); template std::tuple, c10::optional> - permute_launcher(uint8_t* tokens_ptr, - float* probs_ptr, - float* scaling_factor_ptr, - torch::Tensor row_id_map, - int hidden_size, - int scales_per_token, - int local_rank, - int num_ranks_per_node, - int num_of_local_experts, - torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, - int num_permuted_token, - int pad_multiple, - bool use_fp8, - bool with_probs, - torch::TensorOptions token_options, - cudaStream_t stream); - - template void unpermute_launcher(torch::Tensor permuted_tokens, - c10::optional permuted_probs, - uint16_t* tokens_ptr, - float* probs_ptr, - torch::Tensor row_id_map, - int num_of_local_experts, - torch::Tensor num_dispatched_tokens_tensor, - int num_dispatched_tokens, - int pad_multiple, - int hidden_size, - int local_rank, - int num_ranks_per_node, - bool with_probs, - cudaStream_t stream); + permute_launcher(PermuteArgs args); + + template void unpermute_launcher(UnpermuteArgs args); + /** * @brief Permute the tokens to the experts * @param routing_map[in] shape: [num_dispatched_tokens, num_of_local_experts], @@ -261,38 +217,25 @@ std::tuple permute_processing( bool* routing_map, torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, + // Used in the permute case, use up-bound to avoid synchronization to get the real num_dispatched_tokens from the pinned memory + int max_num_dispatched_tokens, int num_of_local_experts, int pad_multiple, + int num_of_blocks, cudaStream_t stream) { constexpr int block_size = 256; const int warp_size = 32; - - // Get the number of SMs for the current device - int num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - // Leave 20 SMs for other kernels - int grid_size = num_sms - 20; - + assert(num_of_local_experts <= block_size); assert(num_of_local_experts > 0); - auto row_id_map = torch::empty({num_dispatched_tokens + pad_multiple, num_of_local_experts}, + auto row_id_map = torch::empty({max_num_dispatched_tokens + pad_multiple, num_of_local_experts}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); auto tokens_per_expert = torch::empty( - {num_of_local_experts}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); - - // If the size of the allocated dispatched tokens is 0, return the empty - // tensors - if (num_dispatched_tokens == 0) { - // Fill the tokens_per_expert with 0 if no tokens need to permute in the - // current rank - tokens_per_expert.zero_(); - row_id_map.zero_(); - return std::make_tuple(row_id_map, tokens_per_expert); - } + {num_of_local_experts}, torch::TensorOptions().dtype(torch::kInt32).pinned_memory(true)); // Construct the template buffers - int rows_workspace_1 = (num_dispatched_tokens + block_size - 1) / block_size; + int rows_workspace_1 = (max_num_dispatched_tokens + block_size - 1) / block_size; int rows_workspace_2 = (rows_workspace_1 + block_size - 1) / block_size; auto workspace1 = torch::empty({rows_workspace_1, num_of_local_experts}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); @@ -318,7 +261,7 @@ cudaFuncSetAttribute(permute_processing_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); - cudaLaunchCooperativeKernel(permute_processing_kernel, grid_size, + cudaLaunchCooperativeKernel(permute_processing_kernel, num_of_blocks, block_size, args, shared_mem_size, stream); return std::make_tuple(row_id_map, tokens_per_expert); @@ -346,96 +289,79 @@ int64_t tokens_per_block = blockDim.x / 128; int64_t extended_laned_id = threadIdx.x % 128; int64_t extended_warp_id = threadIdx.x / 128; - int64_t block_start = blockIdx.x * tokens_per_block; - int64_t token_id = block_start + extended_warp_id; int num_dispatched_tokens = *num_dispatched_tokens_ptr + pad_multiple; - - // Compute the offset for each expert, means the prefix sum of tokens per - // expert extern __shared__ int shmem_in_permute_kernel[]; int* expert_routing_map = shmem_in_permute_kernel; - // Load the current routing map - for (int i = threadIdx.x; i < num_of_local_experts * tokens_per_block; i += block_size) { - expert_routing_map[i] = (block_start + i / num_of_local_experts < num_dispatched_tokens) - ? row_id_map[block_start * num_of_local_experts + i] - : 0; - } - __syncthreads(); - - if (token_id >= num_dispatched_tokens) { // If the token is out of range, return - return; - } - - // Permute the tokens - int num_eles_per_float4 = sizeof(float4) / sizeof(DType); - int64_t hidden_size_fp4 = hidden_size / num_eles_per_float4; - float4* tokens_fp4 = reinterpret_cast(tokens); - float4* permuted_tokens_fp4 = reinterpret_cast(permuted_tokens); - for (int64_t i = 0; i < num_of_local_experts; i++) { - int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; - if (dest_token_id > 0) { - for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { - permuted_tokens_fp4[(dest_token_id - 1) * hidden_size_fp4 + j] = - tokens_fp4[token_id * hidden_size_fp4 + j]; - } - } else if (dest_token_id < 0) { - for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { - permuted_tokens_fp4[(-dest_token_id - 1) * hidden_size_fp4 + j] = {0.0f, 0.0f, 0.0f, 0.0f}; - } - } - } - - // If use fp8, permute the scaling factor - if (scaling_factor != nullptr) { - for (int64_t i = 0; i < num_of_local_experts; i++) { - int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; - if (dest_token_id > 0) { - for (int64_t j = extended_laned_id; j < scales_per_token; j += 128) { - permuted_scaling_factor[(dest_token_id - 1) * scales_per_token + j] = - scaling_factor[token_id * scales_per_token + j]; - } - } else if (dest_token_id < 0) { - for (int64_t j = extended_laned_id; j < scales_per_token; j += 128) { - permuted_scaling_factor[(-dest_token_id - 1) * scales_per_token + j] = 0; - } - } - } - } - - // If use probs, permute the probs - if (probs != nullptr) { - for (int64_t i = 0; i < num_of_local_experts; i++) { - int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; - if (dest_token_id > 0) { - permuted_probs[dest_token_id - 1] = - probs[token_id * num_of_local_experts * num_ranks_per_node + - local_rank * num_of_local_experts + i]; - } else if (dest_token_id < 0) { - permuted_probs[(-dest_token_id - 1)] = 0; - } - } + + for(int64_t block_start = blockIdx.x * tokens_per_block; block_start < num_dispatched_tokens; block_start += tokens_per_block * gridDim.x) { + int64_t token_id = block_start + extended_warp_id; + // Load the current routing map + for (int i = threadIdx.x; i < num_of_local_experts * tokens_per_block; i += block_size) { + expert_routing_map[i] = (block_start + i / num_of_local_experts < num_dispatched_tokens) + ? row_id_map[block_start * num_of_local_experts + i] + : 0; + } + __syncthreads(); + + if (token_id < num_dispatched_tokens) { + // Permute the tokens + int num_eles_per_float4 = sizeof(float4) / sizeof(DType); + int64_t hidden_size_fp4 = hidden_size / num_eles_per_float4; + float4* tokens_fp4 = reinterpret_cast(tokens); + float4* permuted_tokens_fp4 = reinterpret_cast(permuted_tokens); + for (int64_t i = 0; i < num_of_local_experts; i++) { + int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; + if (dest_token_id > 0) { + for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { + permuted_tokens_fp4[(dest_token_id - 1) * hidden_size_fp4 + j] = + tokens_fp4[token_id * hidden_size_fp4 + j]; + } + } else if (dest_token_id < 0) { + for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { + permuted_tokens_fp4[(-dest_token_id - 1) * hidden_size_fp4 + j] = {0.0f, 0.0f, 0.0f, 0.0f}; + } + } + } + + // If use fp8, permute the scaling factor + if (scaling_factor != nullptr) { + for (int64_t i = 0; i < num_of_local_experts; i++) { + int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; + if (dest_token_id > 0) { + for (int64_t j = extended_laned_id; j < scales_per_token; j += 128) { + permuted_scaling_factor[(dest_token_id - 1) * scales_per_token + j] = + scaling_factor[token_id * scales_per_token + j]; + } + } else if (dest_token_id < 0) { + for (int64_t j = extended_laned_id; j < scales_per_token; j += 128) { + permuted_scaling_factor[(-dest_token_id - 1) * scales_per_token + j] = 0; + } + } + } + } + + // If use probs, permute the probs + if (probs != nullptr) { + for (int64_t i = 0; i < num_of_local_experts; i++) { + int64_t dest_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; + if (dest_token_id > 0) { + permuted_probs[dest_token_id - 1] = + probs[token_id * num_of_local_experts * num_ranks_per_node + + local_rank * num_of_local_experts + i]; + } else if (dest_token_id < 0) { + permuted_probs[(-dest_token_id - 1)] = 0; + } + } + } + } // if (token_id < num_dispatched_tokens) + __syncthreads(); } } template std::tuple, c10::optional> - permute_launcher(DType* tokens_ptr, - ProbType* probs_ptr, - ScalarType* scaling_factor_ptr, - torch::Tensor row_id_map, - int hidden_size, - int scales_per_token, - int local_rank, - int num_ranks_per_node, - int num_of_local_experts, - torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, - int num_permuted_token, - int pad_multiple, - bool use_fp8, - bool with_probs, - torch::TensorOptions token_options, - cudaStream_t stream) { + permute_launcher( PermuteArgs args) { + DType * tokens_ptr = reinterpret_cast(args.tokens_ptr); // Current only support 8-bits and 16-bits tokens assert((std::is_same::value || std::is_same::value)); // Current only support float probs @@ -444,49 +370,48 @@ assert((std::is_same::value)); // For alignment of float4 vectorizatized load if(std::is_same::value) { - assert(hidden_size % 16 == 0); + assert(args.hidden_size % 16 == 0); }else if(std::is_same::value) { - assert(hidden_size % 8 == 0); + assert(args.hidden_size % 8 == 0); } - assert(num_permuted_token >= 0); + assert(args.num_permuted_token >= 0); // Construct the output tensors auto permuted_tokens = - torch::empty({num_permuted_token, hidden_size}, token_options.device(torch::kCUDA)); - - int padded_num_dispatched_tokens = num_dispatched_tokens + pad_multiple; - + torch::empty({args.num_permuted_token, args.hidden_size}, args.token_options.device(torch::kCUDA)); + torch::Tensor permuted_scaling_factor, permuted_probs; - if (use_fp8) { + if (args.use_fp8) { permuted_scaling_factor = - torch::empty({num_permuted_token, scales_per_token}, + torch::empty({args.num_permuted_token, args.scales_per_token}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); } - if (with_probs) { + if (args.with_probs) { permuted_probs = torch::empty( - {num_permuted_token}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - } - - // If the size of the allocated dispatched tokens is 0, return the empty - // tensors - if (padded_num_dispatched_tokens == 0) { - return std::make_tuple(permuted_tokens, permuted_scaling_factor, permuted_probs); + {args.num_permuted_token}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); } // Launch the kernel constexpr int block_size = 512; constexpr int tokens_per_block = block_size / 128; - int grid_size = (padded_num_dispatched_tokens + tokens_per_block - 1) / tokens_per_block; - int shared_mem_size = num_of_local_experts * tokens_per_block * sizeof(int); - permute_kernel<<>>( + int grid_size = args.num_of_blocks_permute_api; + int shared_mem_size = args.num_of_local_experts * tokens_per_block * sizeof(int); + permute_kernel<<>>( reinterpret_cast(tokens_ptr), reinterpret_cast(permuted_tokens.data_ptr()), - use_fp8 ? reinterpret_cast(scaling_factor_ptr) : nullptr, - use_fp8 ? permuted_scaling_factor.data_ptr() : nullptr, - with_probs ? reinterpret_cast(probs_ptr) : nullptr, - with_probs ? permuted_probs.data_ptr() : nullptr, row_id_map.data_ptr(), - num_dispatched_token_tensor.data_ptr(), pad_multiple, num_of_local_experts, hidden_size, - scales_per_token, local_rank, num_ranks_per_node); + args.use_fp8 ? reinterpret_cast(args.scaling_factor_ptr) : nullptr, + args.use_fp8 ? permuted_scaling_factor.data_ptr() : nullptr, + args.with_probs ? reinterpret_cast(args.probs_ptr) : nullptr, + args.with_probs ? permuted_probs.data_ptr() : nullptr, + args.row_id_map.data_ptr(), + args.num_dispatched_token_tensor.data_ptr(), + args.pad_multiple, + args.num_of_local_experts, + args.hidden_size, + args.scales_per_token, + args.local_rank, + args.num_ranks_per_node + ); CUDA_CHECK(cudaGetLastError()); return std::make_tuple(permuted_tokens, permuted_scaling_factor, permuted_probs); @@ -509,117 +434,103 @@ int64_t tokens_per_block = blockDim.x / 128; int64_t extended_laned_id = threadIdx.x % 128; int64_t extended_warp_id = threadIdx.x / 128; - int64_t block_start = blockIdx.x * tokens_per_block; - int64_t token_id = block_start + extended_warp_id; - int num_dispatched_tokens = *num_dispatched_tokens_ptr; - - // Compute the offset for each expert, means the prefix sum of tokens per - // expert extern __shared__ int shmem_in_permute_kernel[]; int* expert_routing_map = shmem_in_permute_kernel; - // Load the current routing map - for (int i = threadIdx.x; i < num_of_local_experts * tokens_per_block; i += block_size) { - expert_routing_map[i] = (block_start + i / num_of_local_experts < num_dispatched_tokens) - ? row_id_map[block_start * num_of_local_experts + i] - : 0; - } - __syncthreads(); - - if (token_id >= num_dispatched_tokens) { // If the token is out of range, return - return; - } - - // Unpermute the tokens - constexpr int num_eles_per_float4 = sizeof(float4) / sizeof(DType); - int64_t hidden_size_fp4 = hidden_size / num_eles_per_float4; - float4* tokens_fp4 = reinterpret_cast(tokens); - float4* permuted_tokens_fp4 = reinterpret_cast(permuted_tokens); - // Use float4 buffer to accumulate the tokens - float4 buffer_fp4; - float accumulator_fp4[num_eles_per_float4]; - DType* buffer_ptr = reinterpret_cast(&buffer_fp4); - // Accumulate the tokens from multi-experts - for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { - // Initialize the accumulator - #pragma unroll - for (int k = 0; k < num_eles_per_float4; k++) - accumulator_fp4[k] = 0.0f; - for (int i = 0; i < num_of_local_experts; i++) { - int64_t source_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; - if (source_token_id > 0) { - buffer_fp4 = permuted_tokens_fp4[(source_token_id - 1) * hidden_size_fp4 + j]; - #pragma unroll - for (int k = 0; k < num_eles_per_float4; k++) { - accumulator_fp4[k] += DType2Float(buffer_ptr[k]); - } - } - } - #pragma unroll - for (int k = 0; k < num_eles_per_float4; k++) { - buffer_ptr[k] = Float2DType(accumulator_fp4[k]); - } - // Store the accumulated tokens to the output tensor - tokens_fp4[token_id * hidden_size_fp4 + j] = buffer_fp4; - } - - // If use probs, unpermute the probs - if (permuted_probs != nullptr) { - for (int64_t j = extended_laned_id; j < num_of_local_experts * num_ranks_per_node; j += 128) { - float value = 0.0f; - if (j / num_of_local_experts == local_rank) { - int64_t source_token_id = - expert_routing_map[extended_warp_id * num_of_local_experts + j % num_of_local_experts]; - if (source_token_id > 0) { - value = static_cast(permuted_probs[source_token_id - 1]); - } - } - probs[token_id * num_of_local_experts * num_ranks_per_node + j] = - static_cast(value); - } - } + int num_dispatched_tokens = *num_dispatched_tokens_ptr; + + + for(int64_t block_start = blockIdx.x * tokens_per_block; block_start < num_dispatched_tokens; block_start += tokens_per_block * gridDim.x) { + int64_t token_id = block_start + extended_warp_id; + // Load the current routing map + for (int i = threadIdx.x; i < num_of_local_experts * tokens_per_block; i += block_size) { + expert_routing_map[i] = (block_start + i / num_of_local_experts < num_dispatched_tokens) + ? row_id_map[block_start * num_of_local_experts + i] + : 0; + } + __syncthreads(); + + if (token_id < num_dispatched_tokens) { + // Unpermute the tokens + constexpr int num_eles_per_float4 = sizeof(float4) / sizeof(DType); + int64_t hidden_size_fp4 = hidden_size / num_eles_per_float4; + float4* tokens_fp4 = reinterpret_cast(tokens); + float4* permuted_tokens_fp4 = reinterpret_cast(permuted_tokens); + // Use float4 buffer to accumulate the tokens + float4 buffer_fp4; + float accumulator_fp4[num_eles_per_float4]; + DType* buffer_ptr = reinterpret_cast(&buffer_fp4); + // Accumulate the tokens from multi-experts + for (int64_t j = extended_laned_id; j < hidden_size_fp4; j += 128) { + // Initialize the accumulator + #pragma unroll + for (int k = 0; k < num_eles_per_float4; k++) + accumulator_fp4[k] = 0.0f; + for (int i = 0; i < num_of_local_experts; i++) { + int64_t source_token_id = expert_routing_map[extended_warp_id * num_of_local_experts + i]; + if (source_token_id > 0) { + buffer_fp4 = permuted_tokens_fp4[(source_token_id - 1) * hidden_size_fp4 + j]; + #pragma unroll + for (int k = 0; k < num_eles_per_float4; k++) { + accumulator_fp4[k] += DType2Float(buffer_ptr[k]); + } + } + } + #pragma unroll + for (int k = 0; k < num_eles_per_float4; k++) { + buffer_ptr[k] = Float2DType(accumulator_fp4[k]); + } + // Store the accumulated tokens to the output tensor + tokens_fp4[token_id * hidden_size_fp4 + j] = buffer_fp4; + } + + // If use probs, unpermute the probs + if (permuted_probs != nullptr) { + for (int64_t j = extended_laned_id; j < num_of_local_experts * num_ranks_per_node; j += 128) { + float value = 0.0f; + if (j / num_of_local_experts == local_rank) { + int64_t source_token_id = + expert_routing_map[extended_warp_id * num_of_local_experts + j % num_of_local_experts]; + if (source_token_id > 0) { + value = static_cast(permuted_probs[source_token_id - 1]); + } + } + probs[token_id * num_of_local_experts * num_ranks_per_node + j] = + static_cast(value); + } + } + } // if (token_id < num_dispatched_tokens) + __syncthreads(); + } } template - void unpermute_launcher(torch::Tensor permuted_tokens, - c10::optional permuted_probs, - DType* tokens_ptr, - ProbType* probs_ptr, - torch::Tensor row_id_map, - int num_of_local_experts, - torch::Tensor num_dispatched_tokens_tensor, - int num_dispatched_tokens, - int pad_multiple, - int hidden_size, - int local_rank, - int num_ranks_per_node, - bool with_probs, - cudaStream_t stream) { - assert(permuted_tokens.dtype() == torch::kBFloat16); - if (with_probs) { - assert(permuted_probs.has_value()); - assert(permuted_probs.value().dtype() == torch::kFloat32); + void unpermute_launcher(UnpermuteArgs args) { + assert(args.permuted_tokens.dtype() == torch::kBFloat16); + if (args.with_probs) { + assert(args.permuted_probs.has_value()); + assert(args.permuted_probs.value().dtype() == torch::kFloat32); } assert((std::is_same::value)); assert((std::is_same::value)); - assert(hidden_size % 2 == 0); + assert(args.hidden_size % 2 == 0); constexpr int block_size = 512; constexpr int tokens_per_block = block_size / 128; - int grid_size = (num_dispatched_tokens + tokens_per_block - 1) / tokens_per_block; - int shared_mem_size = num_of_local_experts * tokens_per_block * sizeof(int); - - // If the size of the dispatched tokens is 0, return - if (num_dispatched_tokens == 0) { - return; - } - - unpermute_kernel<<>>( - reinterpret_cast<__nv_bfloat16*>(permuted_tokens.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(tokens_ptr), - with_probs ? reinterpret_cast(permuted_probs.value().data_ptr()) : nullptr, - with_probs ? reinterpret_cast(probs_ptr) : nullptr, row_id_map.data_ptr(), - num_dispatched_tokens_tensor.data_ptr(), num_of_local_experts, hidden_size, local_rank, - num_ranks_per_node); + int grid_size = args.num_of_blocks_permute_api; + int shared_mem_size = args.num_of_local_experts * tokens_per_block * sizeof(int); + + unpermute_kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(args.permuted_tokens.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(args.tokens_ptr), + args.with_probs ? reinterpret_cast(args.permuted_probs.value().data_ptr()) : nullptr, + args.with_probs ? reinterpret_cast(args.probs_ptr) : nullptr, + args.row_id_map.data_ptr(), + args.num_dispatched_tokens_tensor.data_ptr(), + args.num_of_local_experts, + args.hidden_size, + args.local_rank, + args.num_ranks_per_node + ); CUDA_CHECK(cudaGetLastError()); } diff --git a/csrc/hybrid_ep/extension/permute.cuh b/csrc/hybrid_ep/extension/permute.cuh index 15da27ca..e4b1f27a 100644 --- a/csrc/hybrid_ep/extension/permute.cuh +++ b/csrc/hybrid_ep/extension/permute.cuh @@ -11,12 +11,61 @@ #include #include "utils.cuh" +struct PermuteArgs { + // The address of the input + void* tokens_ptr; + float* probs_ptr; + float* scaling_factor_ptr; + torch::Tensor row_id_map; + + // The shape message of the input + int hidden_size; + int scales_per_token; // Now is hidden_size/128 + torch::Tensor num_dispatched_token_tensor; // We assume it is only valid on GPU + int num_permuted_token; + int num_ranks_per_node; // Probs dimension 0 = num_ranks_per_node * num_of_local_experts + int num_of_local_experts; + int pad_multiple; + + // Misc + int local_rank; + bool use_fp8; + bool with_probs; + int num_of_blocks_permute_api; + torch::TensorOptions token_options; // To record the Dtype of the input tokens from the expert mlp, maybe bf16/fp16/fp8... + cudaStream_t stream; +}; + +struct UnpermuteArgs { + // Input tensors + torch::Tensor permuted_tokens; + c10::optional permuted_probs; + torch::Tensor row_id_map; + + // The address of the output + uint16_t* tokens_ptr; + float* probs_ptr; + + // The shape message of the output + int num_of_local_experts; + torch::Tensor num_dispatched_tokens_tensor; // We assume it is only valid on GPU + int pad_multiple; + int hidden_size; + + // Misc + int local_rank; + int num_ranks_per_node; + bool with_probs; + int num_of_blocks_permute_api; + cudaStream_t stream; +}; + /** * @brief Make the row id map for the permute kernel, padding at the num of * tokens dimension * @param routing_map[in] shape: [num_dispatched_tokens, num_of_local_experts], * type: bool - * @param num_dispatched_tokens[in] + * @param max_num_dispatched_tokens[in] * @param num_of_local_experts[in] * @param pad_multiple[in] * @param stream[in] @@ -26,9 +75,10 @@ std::tuple permute_processing( bool* routing_map, torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, + int max_num_dispatched_tokens, int num_of_local_experts, int pad_multiple, + int num_of_blocks, cudaStream_t stream); /** @@ -50,23 +100,7 @@ */ template std::tuple, c10::optional> - permute_launcher(DType* tokens_ptr, - ProbType* probs_ptr, - ScalarType* scaling_factor_ptr, - torch::Tensor row_id_map, - int hidden_size, - int scales_per_token, - int local_rank, - int num_ranks_per_node, - int num_of_local_experts, - torch::Tensor num_dispatched_token_tensor, - int num_dispatched_tokens, - int num_permuted_token, - int pad_multiple, - bool use_fp8, - bool with_probs, - torch::TensorOptions token_options, - cudaStream_t stream); + permute_launcher(PermuteArgs args); /** * @brief Unpermute the tokens to the original order @@ -82,20 +116,7 @@ * type: int */ template - void unpermute_launcher(torch::Tensor permuted_tokens, - c10::optional permuted_probs, - DType* tokens_ptr, - ProbType* probs_ptr, - torch::Tensor row_id_map, - int num_of_local_experts, - torch::Tensor num_dispatched_tokens_tensor, - int num_dispatched_tokens, - int pad_multiple, - int hidden_size, - int local_rank, - int num_ranks_per_node, - bool with_probs, - cudaStream_t stream); + void unpermute_launcher(UnpermuteArgs args); template inline __device__ float DType2Float(DType value) { @@ -114,4 +135,4 @@ return static_cast(value); } } - \ No newline at end of file + diff --git a/csrc/hybrid_ep/hybrid_ep.cu b/csrc/hybrid_ep/hybrid_ep.cu index ea02c2a0..16958686 100644 --- a/csrc/hybrid_ep/hybrid_ep.cu +++ b/csrc/hybrid_ep/hybrid_ep.cu @@ -9,17 +9,18 @@ HybridEPBuffer::HybridEPBuffer( int node_rank, int group_size, std::string base_path, - std::vector ib_dev_name_list, bool load_cached_kernels, bool use_shared_buffer, - bool enable_fabric + bool use_mnnvl ) : process_group(process_group), buffer_config(config), local_rank(local_rank), node_rank(node_rank), group_size(group_size), use_shared_buffer(use_shared_buffer), executor(local_rank, node_rank, base_path, load_cached_kernels) { - remote_allocator.init(enable_fabric); + remote_allocator.init(/*enable_fabric=*/use_mnnvl); if(group_size > buffer_config.num_of_ranks_per_node) { #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE - rdma_coordinator.init(process_group, node_rank, local_rank, buffer_config, ib_dev_name_list); + rdma_coordinator.init(process_group, node_rank, local_rank, use_mnnvl, buffer_config); #else + fprintf(stderr, "Inter-node communication is not supported. Please rebuild with HYBRID_EP_MULTINODE flag.\n"); + fflush(stderr); assert(false); // inter-node communication is not supported. #endif } @@ -378,6 +379,7 @@ bool HybridEPBuffer::update_buffer(HybridEpConfigInstance config) { // If new config requires bigger buffer, we will release the old buffer and allocate a new one. bool need_reallocate = false; + need_reallocate |= grow_to(buffer_config.max_num_of_tokens_per_rank, config.max_num_of_tokens_per_rank); need_reallocate |= grow_to(buffer_config.hidden_dim, config.hidden_dim); need_reallocate |= grow_to(buffer_config.num_of_experts_per_rank,config.num_of_experts_per_rank); need_reallocate |= grow_to(buffer_config.num_of_ranks_per_node, config.num_of_ranks_per_node); @@ -395,11 +397,14 @@ bool HybridEPBuffer::update_buffer(HybridEpConfigInstance config) { } if(buffer_config.num_of_nodes > 1 && need_reallocate) { - fprintf(stderr, "Reallocate buffer for multi-node case is very slow, please check the buffer configuration to pre-allocate the buffer."); - assert(!need_reallocate); + TORCH_WARN("Reallocating HybridEP buffers in multi-node mode is very slow; " + "adjust buffer_config to pre-allocate sufficient capacity."); } if(need_reallocate) { + #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE + rdma_coordinator.update_config(buffer_config); + #endif release_buffer(); allocate_buffer(); } @@ -452,8 +457,7 @@ HybridEPBuffer::dispatch(HybridEpConfigInstance config, args.attn_to_rdma_map = attn_to_rdma_map; args.num_dispatched_tokens_tensor = num_dispatched_tokens_tensor; args.num_dispatched_tokens = (num_dispatched_tokens.has_value()) ? - num_dispatched_tokens.value() : - num_dispatched_tokens_tensor.value().item(); + num_dispatched_tokens.value() : -1; args.num_of_tokens_per_rank = num_of_tokens_per_rank; args.enable_permute = false; args.stream = at::cuda::getCurrentCUDAStream(); @@ -468,7 +472,7 @@ HybridEPBuffer::dispatch(HybridEpConfigInstance config, }else { throw std::runtime_error("Invalid token data type:" + std::to_string(static_cast(config.token_data_type))); } - auto [dispatched_tokens, dispatched_probs, dispatched_scaling_factor, row_id_map, tokens_per_expert] = executor.dispatch_postprocess(config, dispatch_buffers, args); + auto [dispatched_tokens, dispatched_probs, dispatched_scaling_factor] = executor.dispatch_postprocess(config, dispatch_buffers, args); return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor); } @@ -535,11 +539,10 @@ HybridEPBuffer::dispatch_with_permute(HybridEpConfigInstance config, c10::optional num_dispatched_tokens_tensor, c10::optional local_expert_routing_map, c10::optional row_id_map, - c10::optional num_dispatched_tokens, c10::optional num_permuted_tokens, int64_t num_of_tokens_per_rank, c10::optional pad_multiple, - bool use_host_meta, + bool non_blocking, bool with_probs) { // Check the input tensors @@ -567,20 +570,18 @@ HybridEPBuffer::dispatch_with_permute(HybridEpConfigInstance config, args.attn_to_rdma_map = attn_to_rdma_map; args.local_expert_routing_map = local_expert_routing_map; args.num_dispatched_tokens_tensor = num_dispatched_tokens_tensor; - args.num_dispatched_tokens = (num_dispatched_tokens.has_value()) ? - num_dispatched_tokens.value() : - num_dispatched_tokens_tensor.value().item(); + args.max_num_dispatched_tokens = this->max_num_of_tokens; args.row_id_map = row_id_map; args.num_permuted_tokens = (num_permuted_tokens.has_value()) ? num_permuted_tokens.value() : -1; args.pad_multiple = (pad_multiple.has_value()) ? pad_multiple.value() : 0; - args.use_host_meta = use_host_meta; + args.non_blocking = non_blocking; args.num_of_tokens_per_rank = num_of_tokens_per_rank; args.enable_permute = true; args.stream = at::cuda::getCurrentCUDAStream(); // Run the full dispatch operation config.forward_dispatch_api = with_probs; - executor.dispatch_preprocess(config, dispatch_buffers, args); + auto [result_row_id_map, result_tokens_per_expert] = executor.dispatch_preprocess(config, dispatch_buffers, args); if(config.token_data_type == APP_TOKEN_DATA_TYPE::UINT8) { executor.dispatch_core(config, dispatch_buffers, args); } else if (config.token_data_type == APP_TOKEN_DATA_TYPE::UINT16) { @@ -589,7 +590,9 @@ HybridEPBuffer::dispatch_with_permute(HybridEpConfigInstance config, throw std::runtime_error("Invalid token data type:" + std::to_string(static_cast(config.token_data_type))); } - return executor.dispatch_postprocess(config, dispatch_buffers, args); + auto [dispatched_tokens, dispatched_probs, dispatched_scaling_factor] = executor.dispatch_postprocess(config, dispatch_buffers, args); + + return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor, result_row_id_map, result_tokens_per_expert); } std::tuple @@ -598,7 +601,6 @@ HybridEPBuffer::combine_with_unpermute(HybridEpConfigInstance config, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional row_id_map, - c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, c10::optional pad_multiple, bool with_probs) @@ -634,9 +636,6 @@ HybridEPBuffer::combine_with_unpermute(HybridEpConfigInstance config, args.rdma_to_attn_map = rdma_to_attn_map; args.attn_to_rdma_map = attn_to_rdma_map; args.num_dispatched_tokens_tensor = num_dispatched_tokens_tensor; - args.num_dispatched_tokens = (num_dispatched_tokens.has_value()) ? - num_dispatched_tokens.value() : - num_dispatched_tokens_tensor.value().item(); args.row_id_map = row_id_map; args.pad_multiple = (pad_multiple.has_value()) ? pad_multiple.value() : 0; args.num_of_tokens_per_rank = num_of_tokens_per_rank; diff --git a/csrc/hybrid_ep/hybrid_ep.cuh b/csrc/hybrid_ep/hybrid_ep.cuh index f7fe2c18..e0dd561c 100644 --- a/csrc/hybrid_ep/hybrid_ep.cuh +++ b/csrc/hybrid_ep/hybrid_ep.cuh @@ -19,7 +19,7 @@ class HybridEPBuffer { public: - HybridEPBuffer(pybind11::object process_group, BufferConfig config, int local_rank, int node_rank, int group_size, std::string base_path, std::vector ib_dev_name_list, bool load_cached_kernels, bool use_shared_buffer, bool enable_fabric); + HybridEPBuffer(pybind11::object process_group, BufferConfig config, int local_rank, int node_rank, int group_size, std::string base_path, bool load_cached_kernels, bool use_shared_buffer, bool use_mnnvl); ~HybridEPBuffer(); bool update_buffer(HybridEpConfigInstance config); // True means the buffer is reallocated. @@ -56,11 +56,10 @@ public: c10::optional num_dispatched_tokens_tensor, c10::optional local_expert_routing_map, c10::optional row_id_map, - c10::optional num_dispatched_tokens, c10::optional num_permuted_tokens, int64_t num_of_tokens_per_rank, c10::optional pad_multiple, - bool use_host_meta, + bool non_blocking, bool with_probs); std::tuple @@ -70,7 +69,6 @@ public: torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional row_id_map, - c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, c10::optional pad_multiple, bool with_probs); diff --git a/csrc/hybrid_ep/internode.cu b/csrc/hybrid_ep/internode.cu index a31f420a..a9a5d6c3 100644 --- a/csrc/hybrid_ep/internode.cu +++ b/csrc/hybrid_ep/internode.cu @@ -143,20 +143,28 @@ int create_and_place_qps(struct gverbs_context *g_ctx, return status; } -static int setup_qp_attr_for_modify(struct doca_verbs_qp_attr *qp_attr, - struct remote_info *rem_dest, - struct ibv_context *ib_context) { +static int setup_qp_attr_for_modify(struct ibv_port_attr *port_attr, struct doca_verbs_qp_attr *qp_attr, + struct remote_info *l_info, struct remote_info *r_info, + struct ibv_context *ib_context) { int status = 0; - status = doca_verbs_qp_attr_set_dest_qp_num(qp_attr, rem_dest->qpn); + status = doca_verbs_qp_attr_set_dest_qp_num(qp_attr, r_info->qpn); assert(status == 0); struct doca_verbs_ah *ah = nullptr; status = doca_verbs_ah_create(ib_context, &ah); assert(status == 0); - status = doca_verbs_ah_set_dlid(ah, rem_dest->lid); + if (port_attr->link_layer == IBV_LINK_LAYER_INFINIBAND) { + status = doca_verbs_ah_set_addr_type(ah, DOCA_VERBS_ADDR_TYPE_IB_NO_GRH); + } else { + status = doca_verbs_ah_set_addr_type(ah, DOCA_VERBS_ADDR_TYPE_IPv4); + } + assert(status == 0); + status = doca_verbs_ah_set_dlid(ah, r_info->lid); + assert(status == 0); + status = doca_verbs_ah_set_gid(ah, *((struct doca_verbs_gid *)(&r_info->gid))); assert(status == 0); status = doca_verbs_ah_set_sl(ah, 0); assert(status == 0); - status = doca_verbs_ah_set_sgid_index(ah, rem_dest->gid_index); + status = doca_verbs_ah_set_sgid_index(ah, l_info->gid_index); assert(status == 0); status = doca_verbs_qp_attr_set_ah_attr(qp_attr, ah); assert(status == 0); @@ -166,8 +174,7 @@ static int setup_qp_attr_for_modify(struct doca_verbs_qp_attr *qp_attr, assert(status == 0); status = doca_verbs_qp_attr_set_sq_psn(qp_attr, 0); assert(status == 0); - status = - doca_verbs_qp_attr_set_path_mtu(qp_attr, DOCA_VERBS_MTU_SIZE_1K_BYTES); + status = doca_verbs_qp_attr_set_path_mtu(qp_attr, DOCA_VERBS_MTU_SIZE_1K_BYTES); assert(status == 0); status = doca_verbs_qp_attr_set_min_rnr_timer(qp_attr, 1); assert(status == 0); @@ -180,6 +187,7 @@ static int setup_qp_attr_for_modify(struct doca_verbs_qp_attr *qp_attr, return 0; } + int doca_gpunetio_test_change_qp_state(struct doca_gpu_verbs_qp_hl *qp, struct doca_verbs_qp_attr *qp_attr, int attr_mask) { @@ -225,60 +233,75 @@ int doca_gpunetio_test_change_qp_state(struct doca_gpu_verbs_qp_hl *qp, return 0; } -static int setup_qp_attr_and_set_qp(struct gverbs_context *g_ctx, - struct ibv_context *ib_context, - struct remote_info *rem_dest, - struct doca_verbs_qp_attr *qp_attr, - int num_of_blocks, int num_of_nodes, - int node_rank, uint32_t qp_cnt) { - int attr_mask = DOCA_VERBS_QP_ATTR_NEXT_STATE | - DOCA_VERBS_QP_ATTR_ALLOW_REMOTE_WRITE | - DOCA_VERBS_QP_ATTR_ALLOW_REMOTE_READ | - DOCA_VERBS_QP_ATTR_PORT_NUM | DOCA_VERBS_QP_ATTR_PKEY_INDEX; +static int setup_qp_attr_and_set_qp(struct gverbs_context *g_ctx, struct ibv_context *ib_context, struct ibv_port_attr *port_attr, + struct remote_info *rem_dest, struct doca_verbs_qp_attr *qp_attr, + int num_of_blocks, int num_of_nodes, int node_rank, uint32_t qp_cnt) { + int attr_mask = DOCA_VERBS_QP_ATTR_NEXT_STATE | DOCA_VERBS_QP_ATTR_ALLOW_REMOTE_WRITE | + DOCA_VERBS_QP_ATTR_ALLOW_REMOTE_READ | DOCA_VERBS_QP_ATTR_PORT_NUM | + DOCA_VERBS_QP_ATTR_PKEY_INDEX; for (int qp_idx = 0; qp_idx < num_of_blocks; ++qp_idx) { for (int peer_idx = 0; peer_idx < num_of_nodes - 1; ++peer_idx) { int actual_node_idx = peer_idx < node_rank ? peer_idx : (peer_idx + 1); - int actual_idx_in_node = - peer_idx < node_rank ? (node_rank - 1) : node_rank; - int my_idx = peer_idx + qp_idx * (num_of_nodes - 1); - int rem_idx = actual_node_idx * qp_cnt + qp_idx * (num_of_nodes - 1) + - actual_idx_in_node; + int actual_idx_in_node = peer_idx < node_rank ? (node_rank - 1) : node_rank; + int curr_qp_idx = peer_idx + qp_idx * (num_of_nodes - 1); + int local_idx = curr_qp_idx + node_rank * qp_cnt; + int rem_idx = actual_node_idx * qp_cnt + qp_idx * (num_of_nodes - 1) + actual_idx_in_node; + struct remote_info *l_info = &rem_dest[local_idx]; struct remote_info *r_info = &rem_dest[rem_idx]; - struct doca_gpu_verbs_qp_hl *qp = g_ctx->qp_hls[my_idx]; - setup_qp_attr_for_modify(qp_attr, r_info, ib_context); + struct doca_gpu_verbs_qp_hl *qp = g_ctx->qp_hls[curr_qp_idx]; + setup_qp_attr_for_modify(port_attr, qp_attr, l_info, r_info, ib_context); doca_gpunetio_test_change_qp_state(qp, qp_attr, attr_mask); } } return 0; } +void RDMACoordinator::update_config(BufferConfig config) { + this->buffer_config = config; +} + void RDMACoordinator::init( pybind11::object process_group, int node_rank, int local_rank, - BufferConfig config, - std::vector ib_dev_name_list + bool use_mnnvl, + BufferConfig config ) { this->process_group = process_group; this->node_rank = node_rank; this->local_rank = local_rank; this->buffer_config = config; - this->ib_dev_name_list = ib_dev_name_list; - assert(buffer_config.num_of_nodes > 1); + + std::vector gpu_idx_vec; + // The node in config means the nvlink domain + // The local device index is the index of the device in the real device list within the physical node. + int num_of_local_devices = buffer_config.num_of_ranks_per_node; + if (use_mnnvl) { + num_of_local_devices = std::min(num_of_local_devices, 4); + } + int local_device_idx = local_rank % num_of_local_devices; + for (int i = 0; i < num_of_local_devices; ++i) { + gpu_idx_vec.push_back(i); + } // Get name of ibv device. - const char *ib_devname = ib_dev_name_list[local_rank].c_str(); + const char *net_name; + hybrid_ep::get_nic(gpu_idx_vec, local_device_idx, &net_name); // Find ib device and get ibv_context. - struct ibv_device *ib_dev = ctx_find_dev(ib_devname); + struct ibv_device *ib_dev = ctx_find_dev(net_name); ib_context = ibv_open_device(ib_dev);; - ibv_query_port(ib_context, IB_PORT, &port_attr); auto transport_type = ib_context->device->transport_type; assert(transport_type == IBV_TRANSPORT_IB); + ibv_query_port(ib_context, IB_PORT, &port_attr); + uint8_t link_layer = port_attr.link_layer; + assert(link_layer == IBV_LINK_LAYER_INFINIBAND || link_layer == IBV_LINK_LAYER_ETHERNET); + hybrid_ep::ncclIbGetGidIndex(ib_context, IB_PORT, &port_attr, &gid_index); + // Alloc protect domain. ib_pd = ibv_alloc_pd(ib_context); gpu_handler = (struct doca_gpu *)calloc(1, sizeof(struct doca_gpu)); - get_gpu_handler(gpu_handler, ib_context, local_rank); + get_gpu_handler(gpu_handler, ib_context, local_device_idx); mr_access_flag = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_RELAXED_ORDERING; @@ -355,9 +378,7 @@ void RDMACoordinator::allocate_dispatch_rdma_buffers(DispatchBuffers &dispatch_b // Set dispatch queue pair attributes. int num_of_dispatch_qps = (buffer_config.num_of_nodes - 1) * buffer_config.num_of_blocks_dispatch_api; memset(&dispatch_gverbs_ctx, 0, sizeof(gverbs_context)); - if (GID_INDEX != -1) { - ibv_query_gid(ib_context, IB_PORT, GID_INDEX, &dispatch_gverbs_ctx.gid); - } + ibv_query_gid(ib_context, IB_PORT, gid_index, &dispatch_gverbs_ctx.gid); dispatch_gverbs_ctx.qp_init_attr = (struct doca_gpu_verbs_qp_init_attr_hl *)calloc(1, sizeof(struct doca_gpu_verbs_qp_init_attr_hl)); setup_qp_init_attr(dispatch_gverbs_ctx.qp_init_attr, gpu_handler, ib_pd, 3 * buffer_config.max_num_of_tokens_per_rank + 1); dispatch_gverbs_ctx.qp_hls = (struct doca_gpu_verbs_qp_hl **)calloc(sizeof(struct doca_gpu_verbs_qp_hl *), num_of_dispatch_qps); @@ -385,7 +406,7 @@ void RDMACoordinator::allocate_dispatch_rdma_buffers(DispatchBuffers &dispatch_b struct remote_info *curr_info = my_dispatch_info + idx; curr_info->lid = port_attr.lid; curr_info->qpn = doca_verbs_qp_get_qpn(dispatch_gverbs_ctx.qp_hls[idx]->qp); - curr_info->gid_index = GID_INDEX; + curr_info->gid_index = gid_index; memset(&curr_info->gid, 0, sizeof(curr_info->gid));; memcpy(curr_info->gid.raw, dispatch_gverbs_ctx.gid.raw, 16); curr_info->token_rkey = dispatch_rdma_inter_node_group_token_mr->rkey; @@ -411,7 +432,8 @@ void RDMACoordinator::allocate_dispatch_rdma_buffers(DispatchBuffers &dispatch_b exchange_remote_rdma_info(dispatch_remote_info_vec, my_dispatch_info, num_of_dispatch_qps); // Init queue pairs. - setup_qp_attr_and_set_qp(&dispatch_gverbs_ctx, ib_context, dispatch_remote_info_vec, dispatch_gverbs_ctx.qp_attr, + setup_qp_attr_and_set_qp(&dispatch_gverbs_ctx, ib_context, &port_attr, + dispatch_remote_info_vec, dispatch_gverbs_ctx.qp_attr, buffer_config.num_of_blocks_dispatch_api, buffer_config.num_of_nodes, node_rank, num_of_dispatch_qps); // Move queue pairs to GPU. doca_gpu_dev_verbs_qp **h_qps_gpu = (doca_gpu_dev_verbs_qp**)calloc(sizeof(*h_qps_gpu), num_of_dispatch_qps); @@ -512,9 +534,7 @@ void RDMACoordinator::allocate_combine_rdma_buffers(CombineBuffers &combine_buff // Set combine queue pair attributes. int num_of_combine_qps = (buffer_config.num_of_nodes - 1) * buffer_config.num_of_blocks_combine_api; memset(&combine_gverbs_ctx, 0, sizeof(gverbs_context)); - if (GID_INDEX != -1) { - ibv_query_gid(ib_context, IB_PORT, GID_INDEX, &combine_gverbs_ctx.gid); - } + ibv_query_gid(ib_context, IB_PORT, gid_index, &combine_gverbs_ctx.gid); combine_gverbs_ctx.qp_init_attr = (struct doca_gpu_verbs_qp_init_attr_hl *)calloc(1, sizeof(struct doca_gpu_verbs_qp_init_attr_hl)); setup_qp_init_attr(combine_gverbs_ctx.qp_init_attr, gpu_handler, ib_pd, 2 * buffer_config.max_num_of_tokens_per_rank + 1); combine_gverbs_ctx.qp_hls = (struct doca_gpu_verbs_qp_hl **)calloc(sizeof(struct doca_gpu_verbs_qp_hl *), num_of_combine_qps); @@ -540,7 +560,7 @@ void RDMACoordinator::allocate_combine_rdma_buffers(CombineBuffers &combine_buff struct remote_info *curr_info = my_combine_info + idx; curr_info->lid = port_attr.lid; curr_info->qpn = doca_verbs_qp_get_qpn(combine_gverbs_ctx.qp_hls[idx]->qp); - curr_info->gid_index = GID_INDEX; + curr_info->gid_index = gid_index; memset(&curr_info->gid, 0, sizeof(curr_info->gid));; memcpy(curr_info->gid.raw, combine_gverbs_ctx.gid.raw, 16); curr_info->token_rkey = combine_rdma_inter_node_group_token_mr->rkey; @@ -558,7 +578,8 @@ void RDMACoordinator::allocate_combine_rdma_buffers(CombineBuffers &combine_buff exchange_remote_rdma_info(combine_remote_info_vec, my_combine_info, num_of_combine_qps); // Init queue pairs. - setup_qp_attr_and_set_qp(&combine_gverbs_ctx, ib_context, combine_remote_info_vec, combine_gverbs_ctx.qp_attr, + setup_qp_attr_and_set_qp(&combine_gverbs_ctx, ib_context, &port_attr, + combine_remote_info_vec, combine_gverbs_ctx.qp_attr, buffer_config.num_of_blocks_combine_api, buffer_config.num_of_nodes, node_rank, num_of_combine_qps); // Move queue pairs to GPU. doca_gpu_dev_verbs_qp **h_qps_gpu = (doca_gpu_dev_verbs_qp**)calloc(sizeof(*h_qps_gpu), num_of_combine_qps); diff --git a/csrc/hybrid_ep/internode.cuh b/csrc/hybrid_ep/internode.cuh index 1965a087..16d90b7a 100644 --- a/csrc/hybrid_ep/internode.cuh +++ b/csrc/hybrid_ep/internode.cuh @@ -11,6 +11,8 @@ #include #include "backend/hybrid_ep_backend.cuh" #include "backend/utils.cuh" +#include "backend/topo_detection.cuh" +#include "backend/ibvcore.h" #include "config.cuh" #define RC (0) @@ -39,7 +41,6 @@ constexpr int32_t DEF_HOP_LIMIT = 64; constexpr int32_t DEF_RX_RDMA = 128; constexpr int32_t DEF_TX_BW = 512; constexpr int32_t EQ_NUM = 0; -constexpr int32_t GID_INDEX = 0; constexpr int32_t GVERBS_WQ_BUF_LOC = DOCA_GPU_MEM_TYPE_GPU; constexpr int32_t GVERBS_CQ_BUF_LOC = DOCA_GPU_MEM_TYPE_GPU; constexpr int32_t GVERBS_USE_ASYNC_STIRE_RELEASE = 0; @@ -120,14 +121,15 @@ void setup_qp_init_attr(struct doca_gpu_verbs_qp_init_attr_hl *qp_init_attr, int create_and_place_qps(struct gverbs_context *g_ctx, struct doca_gpu_verbs_qp_init_attr_hl *qp_init_attr, int num_qps); -static int setup_qp_attr_for_modify(struct doca_verbs_qp_attr *qp_attr, - struct remote_info *rem_dest, +static int setup_qp_attr_for_modify(struct ibv_port_attr *port_attr, struct doca_verbs_qp_attr *qp_attr, + struct remote_info *l_info, struct remote_info *r_info, struct ibv_context *ib_context); int doca_gpunetio_test_change_qp_state(struct doca_gpu_verbs_qp_hl *qp, struct doca_verbs_qp_attr *qp_attr, int attr_mask); static int setup_qp_attr_and_set_qp(struct gverbs_context *g_ctx, struct ibv_context *ib_context, + struct ibv_port_attr *port_attr, struct remote_info *rem_dest, struct doca_verbs_qp_attr *qp_attr, int num_of_blocks, int num_of_nodes, @@ -137,13 +139,14 @@ class RDMACoordinator { public: RDMACoordinator() = default; ~RDMACoordinator(); - void init(pybind11::object process_group, int node_rank, int local_rank, BufferConfig config, std::vector ib_dev_name_list); + void init(pybind11::object process_group, int node_rank, int local_rank, bool use_mnnvl, BufferConfig config); + void update_config(BufferConfig config); void destroy(); void allocate_dispatch_rdma_buffers(DispatchBuffers &dispatch_buffers); void allocate_combine_rdma_buffers(CombineBuffers &combine_buffers); private: - std::vector ib_dev_name_list; + int gid_index = 0; int node_rank = -1; int local_rank = -1; BufferConfig buffer_config; diff --git a/csrc/hybrid_ep/jit/compiler.cu b/csrc/hybrid_ep/jit/compiler.cu index d2da156f..0b458db6 100644 --- a/csrc/hybrid_ep/jit/compiler.cu +++ b/csrc/hybrid_ep/jit/compiler.cu @@ -11,19 +11,31 @@ inline std::string get_env(std::string name) { return std::string(env); } +std::string get_jit_dir() { + std::string home_dir = get_env("HOME"); + if (home_dir.empty()) { + home_dir = "/tmp"; // Fallback to /tmp if HOME is not set + } + return home_dir + "/.deepep/hybrid_ep/jit"; +} + NVCCCompiler::NVCCCompiler(std::string base_path): base_path(base_path) { + jit_dir = get_jit_dir(); + nvcc_path = get_env("CUDA_HOME") + "/bin/nvcc"; // Init the flags to compiler std::string sm_arch_flags = convert_to_nvcc_arch_flags(SM_ARCH); - flags = "-std=c++17 " + sm_arch_flags + + std::string flags = "-std=c++17 " + sm_arch_flags + " -O3 --expt-relaxed-constexpr " " -Xcompiler -fPIC -shared "; // Add the include path of the hybrid-ep library - include = " -I" + base_path + "/backend" + std::string include = " -I" + base_path + "/backend" + " -I" + get_env("CUDA_HOME") + "/include "; // Add the library path of the hybrid-ep library - library = "-L" + get_env("CUDA_HOME") + "/lib64 -lcudart "; + std::string library = "-L" + get_env("CUDA_HOME") + "/lib64 -lcudart "; + + intra_node_flags = flags + " " + include + " " + library; #ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE // Add the dependency of the inter-node jit @@ -51,13 +63,12 @@ NVCCCompiler::NVCCCompiler(std::string base_path): base_path(base_path) { + doca_obj_path + "/doca_gpunetio_log.o "; #endif - flags = flags + " " + include + " " + library; + inter_node_flags = flags + " " + include + " " + library; } -std::string NVCCCompiler::build(std::string code, std::string signature, int local_rank, int node_rank) { +std::string NVCCCompiler::build(std::string code, std::string signature, int local_rank, int node_rank, int num_of_nodes) { // Create the source directory - std::string jit_dir = base_path + "/build/jit"; std::filesystem::create_directories(jit_dir); // Get a unique signature for each run @@ -75,14 +86,19 @@ std::string NVCCCompiler::build(std::string code, std::string signature, int loc out.write(code.data(), code.size()); out.close(); - // Compile the code std::string output_path = jit_dir + "/" + extended_signature + ".so"; // Remove the output .so file if it exists remove(output_path.c_str()); - std::string compile_command = - nvcc_path + " " + flags + " " + source_path + " " + objs + " -o " + output_path; - + // Choose the flags based on the number of nodes + std::string compile_command; + if(num_of_nodes > 1) { + compile_command = nvcc_path + " " + inter_node_flags + " " + source_path + " " + objs + " -o " + output_path; + }else { + compile_command = nvcc_path + " " + intra_node_flags + " " + source_path + " -o " + output_path; + } + + // Run the compile command auto ret = std::system(compile_command.c_str()); if (ret != 0) { throw std::runtime_error("Failed to compile the code, compile command: " + compile_command); @@ -113,7 +129,7 @@ std::any NVCCCompiler::get_instance(std::string library_path, std::string kernel } // Unique the compiled lib from different rank - std::string unique_library_path = base_path + "/build/jit/" + kernel_key + ".so"; + std::string unique_library_path = jit_dir + "/" + kernel_key + ".so"; std::string unique_command = "mv " + library_path + " " + unique_library_path; if(library_path != unique_library_path) { auto ret = std::system(unique_command.c_str()); @@ -160,7 +176,7 @@ std::string NVCCCompiler::get_dispatch_code(HybridEpConfigInstance config) { std::to_string(config.hidden_dim) + ", " + std::to_string(config.max_num_of_tokens_per_rank) + ", " + std::to_string(config.num_of_ranks_per_node) + ", " + std::to_string(config.num_of_nodes) + ", " + std::to_string(config.num_of_experts_per_rank) + ">::dispatch<" + token_type + ", " + - std::to_string(config.num_of_stages_dispatch_api) + ", " + std::to_string(config.num_of_tokens_per_chunk_dispatch_api) + ", " + + std::to_string(config.num_of_stages_dispatch_api) + ", " + std::to_string(config.num_of_in_flight_s2g_dispatch_api) + ", " + std::to_string(config.num_of_tokens_per_chunk_dispatch_api) + ", " + std::to_string(config.num_of_blocks_dispatch_api) + ", " + (config.forward_dispatch_api ? "true" : "false") + ", " + (config.device_side_sync_dispatch_api ? "true" : "false") + R"(>; return func_ptr; @@ -193,12 +209,12 @@ std::string NVCCCompiler::get_combine_code(HybridEpConfigInstance config) { } KernelCache::KernelCache(int node_rank, int local_rank, std::string base_path, bool load_cached_kernels): -node_rank(node_rank), local_rank(local_rank), base_path(base_path), nvcc_compiler(base_path) { +node_rank(node_rank), local_rank(local_rank), nvcc_compiler(base_path) { // Load all cached kernels from the cache directory - std::string cache_dir = base_path + "/build/jit"; - std::filesystem::create_directories(cache_dir); + jit_dir = get_jit_dir(); + std::filesystem::create_directories(jit_dir); if(load_cached_kernels) { - for (const auto& entry : std::filesystem::directory_iterator(cache_dir)) { + for (const auto& entry : std::filesystem::directory_iterator(jit_dir)) { if (entry.path().extension() == ".so") { std::string kernel_key = entry.path().stem().string(); kernel_cache[kernel_key] = nvcc_compiler.get_instance(entry.path().string(), kernel_key); @@ -235,7 +251,7 @@ void KernelCache::run_proprecess_kernel( auto it = kernel_cache.find(preprocess_kernel_key); if (it == kernel_cache.end()) { auto preprocessing_code = nvcc_compiler.get_metadata_preprocessing_code(config); - auto preprocessing_path = nvcc_compiler.build(preprocessing_code, preprocess_kernel_key, local_rank, node_rank); + auto preprocessing_path = nvcc_compiler.build(preprocessing_code, preprocess_kernel_key, local_rank, node_rank, config.num_of_nodes); kernel_cache[preprocess_kernel_key] = nvcc_compiler.get_instance(preprocessing_path, preprocess_kernel_key); } auto preprocessing_instance = kernel_cache[preprocess_kernel_key]; @@ -278,6 +294,7 @@ void KernelCache::run_dispatch_kernel( config.num_of_nodes, type_to_string(config.token_data_type), config.num_of_stages_dispatch_api, + config.num_of_in_flight_s2g_dispatch_api, config.num_of_tokens_per_chunk_dispatch_api, config.num_of_blocks_dispatch_api, config.forward_dispatch_api, @@ -288,7 +305,7 @@ void KernelCache::run_dispatch_kernel( if (it == kernel_cache.end()) { // JIT Compile the kernel auto dispatch_code = nvcc_compiler.get_dispatch_code(config); - auto dispatch_path = nvcc_compiler.build(dispatch_code, dispatch_kernel_key, local_rank, node_rank); + auto dispatch_path = nvcc_compiler.build(dispatch_code, dispatch_kernel_key, local_rank, node_rank, config.num_of_nodes); kernel_cache[dispatch_kernel_key] = nvcc_compiler.get_instance(dispatch_path, dispatch_kernel_key); } auto dispatch_instance = kernel_cache[dispatch_kernel_key]; @@ -328,7 +345,7 @@ void KernelCache::run_combine_kernel( if (it == kernel_cache.end()) { // JIT Compile the kernel auto combine_code = nvcc_compiler.get_combine_code(config); - auto combine_path = nvcc_compiler.build(combine_code, combine_kernel_key, local_rank, node_rank); + auto combine_path = nvcc_compiler.build(combine_code, combine_kernel_key, local_rank, node_rank, config.num_of_nodes); kernel_cache[combine_kernel_key] = nvcc_compiler.get_instance(combine_path, combine_kernel_key); } auto combine_instance = kernel_cache[combine_kernel_key]; diff --git a/csrc/hybrid_ep/jit/compiler.cuh b/csrc/hybrid_ep/jit/compiler.cuh index fed2fbc4..ce031064 100644 --- a/csrc/hybrid_ep/jit/compiler.cuh +++ b/csrc/hybrid_ep/jit/compiler.cuh @@ -37,9 +37,11 @@ public: * @param signature The signature of the code, which is used to name the .so * file * @param local_rank The local rank of the current process + * @param node_rank The node rank of the current process + * @param num_of_nodes The number of nodes in the communication * @return std::string The path of the compiled .so file */ - std::string build(std::string code, std::string signature, int local_rank, int node_rank); + std::string build(std::string code, std::string signature, int local_rank, int node_rank, int num_of_nodes); /** * @brief Get the compiled function pointer from the compiled .so file @@ -52,13 +54,12 @@ public: private: - std::string base_path; // The path of the installed package - std::string flags; // The flags required by nvcc compiler, which contains the - // base flags(-O3, -arch...), include files, library files - std::string nvcc_path; // The path of the nvcc compiler - std::string include; - std::string library; - std::string objs = ""; + std::string base_path; // The path of the installed package + std::string jit_dir; // The path of the jit library + std::string intra_node_flags; // The flags required by nvcc compiler in the intra-node case + std::string inter_node_flags; // The flags required by nvcc compiler in the inter-node case + std::string nvcc_path; // The path of the nvcc compiler + std::string objs; // The objects to be compiled, only used in the inter-node case }; class KernelCache{ @@ -96,6 +97,7 @@ public: private: NVCCCompiler nvcc_compiler; std::unordered_map kernel_cache; + std::string jit_dir; // The path of the jit directory std::string base_path; // The path of the installed package int local_rank; // Used to generate the unique signature for each rank int node_rank; // Used to generate the unique signature for each node diff --git a/csrc/hybrid_ep/pybind_hybrid_ep.cu b/csrc/hybrid_ep/pybind_hybrid_ep.cu index 5e165b2f..7c7bdeda 100644 --- a/csrc/hybrid_ep/pybind_hybrid_ep.cu +++ b/csrc/hybrid_ep/pybind_hybrid_ep.cu @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +#include #include #include #include @@ -13,7 +14,7 @@ namespace py = pybind11; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "HybridEP, efficiently enable the expert-parallel communication in " "the Hopper+ architectures"; - + pybind11::enum_(m, "APP_TOKEN_DATA_TYPE") .value("UINT16", APP_TOKEN_DATA_TYPE::UINT16) .value("UINT8", APP_TOKEN_DATA_TYPE::UINT8) @@ -33,8 +34,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("num_of_blocks_preprocessing_api", &BufferConfig::num_of_blocks_preprocessing_api) .def_readwrite("num_of_blocks_dispatch_api", &BufferConfig::num_of_blocks_dispatch_api) .def_readwrite("num_of_blocks_combine_api", &BufferConfig::num_of_blocks_combine_api) + .def_readwrite("num_of_blocks_permute_api", &BufferConfig::num_of_blocks_permute_api) .def_readwrite("num_of_tokens_per_chunk_dispatch_api", &BufferConfig::num_of_tokens_per_chunk_dispatch_api) .def_readwrite("num_of_tokens_per_chunk_combine_api", &BufferConfig::num_of_tokens_per_chunk_combine_api) + .def("is_valid", &BufferConfig::is_valid) .def("__repr__", [](const BufferConfig &config) { return ""; @@ -68,10 +72,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &HybridEpConfigInstance::num_of_threads_per_block_preprocessing_api) .def_readwrite("num_of_blocks_preprocessing_api", &HybridEpConfigInstance::num_of_blocks_preprocessing_api) + .def_readwrite("num_of_blocks_permute_api", + &HybridEpConfigInstance::num_of_blocks_permute_api) // Dispatch API Config .def_readwrite("token_data_type", &HybridEpConfigInstance::token_data_type) .def_readwrite("num_of_stages_dispatch_api", &HybridEpConfigInstance::num_of_stages_dispatch_api) + .def_readwrite("num_of_in_flight_s2g_dispatch_api", + &HybridEpConfigInstance::num_of_in_flight_s2g_dispatch_api) .def_readwrite("num_of_tokens_per_chunk_dispatch_api", &HybridEpConfigInstance::num_of_tokens_per_chunk_dispatch_api) .def_readwrite("num_of_blocks_dispatch_api", @@ -98,6 +106,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &HybridEpConfigInstance::backward_combine_api) .def_readwrite("device_side_sync_combine_api", &HybridEpConfigInstance::device_side_sync_combine_api) + .def("is_valid", &HybridEpConfigInstance::is_valid) .def("__repr__", [](const HybridEpConfigInstance &config) { return ", bool, bool, bool>(), + .def(py::init(), py::arg("process_group"), py::arg("config"), py::arg("local_rank"), py::arg("node_rank"), py::arg("group_size"), py::arg("base_path"), - py::arg("ib_dev_name_list") = std::vector{}, py::arg("load_cached_kernels") = false, py::arg("use_shared_buffer") = true, - py::arg("enable_fabric") = false) + py::arg("use_mnnvl") = false) .def("update_buffer", &HybridEPBuffer::update_buffer, py::arg("config")) .def("metadata_preprocessing", &HybridEPBuffer::metadata_preprocessing, py::kw_only(), py::arg("config"), py::arg("routing_map"), py::arg("num_of_tokens_per_rank")) @@ -141,16 +149,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("scaling_factor") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), py::arg("num_dispatched_tokens_tensor"), - py::arg("local_expert_routing_map"), py::arg("row_id_map"), py::arg("num_dispatched_tokens") = std::nullopt, + py::arg("local_expert_routing_map"), py::arg("row_id_map"), py::arg("num_permuted_tokens") = std::nullopt, - py::arg("num_of_tokens_per_rank"), py::arg("pad_multiple") = std::nullopt, py::arg("use_host_meta") = false, + py::arg("num_of_tokens_per_rank"), py::arg("pad_multiple") = std::nullopt, py::arg("non_blocking") = false, py::arg("with_probs") = false) .def("combine_with_unpermute", &HybridEPBuffer::combine_with_unpermute, py::kw_only(), py::arg("config"), py::arg("hidden"), py::arg("probs") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), py::arg("num_dispatched_tokens_tensor"), - py::arg("row_id_map"), py::arg("num_dispatched_tokens") = std::nullopt, + py::arg("row_id_map"), py::arg("num_of_tokens_per_rank"), py::arg("pad_multiple") = std::nullopt, py::arg("with_probs") = false); diff --git a/deep_ep/hybrid_ep_buffer.py b/deep_ep/hybrid_ep_buffer.py index 1ad68678..0c3615f4 100644 --- a/deep_ep/hybrid_ep_buffer.py +++ b/deep_ep/hybrid_ep_buffer.py @@ -44,8 +44,9 @@ def __init__( num_sms_dispatch_api: int = None, num_sms_combine_api: int = None, num_sms_preprocessing_api: int = None, - use_mnnvl: bool = None, - ib_dev_name_list: list[str] = [], + # Rank-based setting + num_of_hybrid_ep_ranks_per_nvlink_domain: int = None, + use_mnnvl: bool = None ): self.group = group self.rank = self.group.rank() @@ -54,63 +55,61 @@ def __init__( self.group_size > 1 ), f"The hybrid-ep kernel should be used with at least 2 ranks, but got {self.group_size}." - # Compute the number of the involved ranks in the nvlink domain. - global_ranks = torch.distributed.get_process_group_ranks(self.group) - rank_stride = global_ranks[1] - global_ranks[0] # Number of ranks in the first nvlink domain. if use_mnnvl is None: use_mnnvl = os.getenv("USE_MNNVL", "0").strip().lower() in {"1", "true", "t", "yes", "y", "on"} - if int(os.getenv("NVLINK_DOMAIN_SIZE", "8")) > 8: # For compatibility + if num_of_hybrid_ep_ranks_per_nvlink_domain is None: + num_of_hybrid_ep_ranks_per_nvlink_domain = int(os.getenv("NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN", "8")) + if num_of_hybrid_ep_ranks_per_nvlink_domain > 8: use_mnnvl = True - self.nvlink_domain_size = 72 if use_mnnvl else 8 + assert ( - rank_stride <= self.nvlink_domain_size - ), "The rank stride should be less than or equal to the nvlink domain size." - num_of_ranks_per_node = min(self.nvlink_domain_size // rank_stride, self.group_size) - - assert ( - self.group_size % num_of_ranks_per_node == 0 + self.group_size % num_of_hybrid_ep_ranks_per_nvlink_domain == 0 ), "The number of ranks should be divisible by the number of ranks per node." self.rank = self.group.rank() - self.num_of_ranks_per_node = num_of_ranks_per_node + self.num_of_hybrid_ep_ranks_per_nvlink_domain = num_of_hybrid_ep_ranks_per_nvlink_domain # Local rank: the active rank in the nvlink domain. - self.local_rank = self.rank % self.num_of_ranks_per_node + self.local_rank = self.rank % self.num_of_hybrid_ep_ranks_per_nvlink_domain # Node rank: the active rank between the nvlink domains. - self.node_rank = self.rank // self.num_of_ranks_per_node + self.node_rank = self.rank // self.num_of_hybrid_ep_ranks_per_nvlink_domain # The number of nodes. - self.num_of_nodes = self.group_size // self.num_of_ranks_per_node + self.num_of_nodes = self.group_size // self.num_of_hybrid_ep_ranks_per_nvlink_domain self.use_fp8 = use_fp8 props = torch.cuda.get_device_properties(torch.cuda.current_device()) sm_count = props.multi_processor_count if num_sms_preprocessing_api is None: - num_sms_preprocessing_api = 128 + num_sms_preprocessing_api = 108 + num_blocks_permute_api = sm_count * 16 # Inter-node case should use less SMs for the dispatch and combine APIs. if num_sms_dispatch_api is None: - num_sms_dispatch_api = 32 if self.num_of_nodes == 1 else 16 + num_sms_dispatch_api = 32 if self.num_of_nodes == 1 else 8 if num_sms_combine_api is None: - num_sms_combine_api = 32 if self.num_of_nodes == 1 else 16 + num_sms_combine_api = 32 if self.num_of_nodes == 1 else 8 assert ( sm_count >= num_sms_preprocessing_api and sm_count >= num_sms_dispatch_api and sm_count >= num_sms_combine_api ), "check the sms occupancy setting" + # Used SMs for preprocessing of dispatch and permute. self.num_sms_preprocessing_api = num_sms_preprocessing_api self.num_sms_dispatch_api = num_sms_dispatch_api self.num_sms_combine_api = num_sms_combine_api - + self.num_blocks_permute_api = num_blocks_permute_api + # Initialize the BufferConfig for the hybrid-ep buffer allocation. self.config = hybrid_ep_cpp.BufferConfig() self.config.hidden_dim = hidden_dim - self.config.max_num_of_tokens_per_rank = max_num_of_tokens_per_rank + self.config.max_num_of_tokens_per_rank = max(max_num_of_tokens_per_rank, 512) self.config.num_of_experts_per_rank = num_local_experts - self.config.num_of_ranks_per_node = self.num_of_ranks_per_node + self.config.num_of_ranks_per_node = self.num_of_hybrid_ep_ranks_per_nvlink_domain self.config.num_of_nodes = self.num_of_nodes self.config.num_of_blocks_dispatch_api = self.num_sms_dispatch_api self.config.num_of_blocks_combine_api = self.num_sms_combine_api # The SMs of preprocessing, chunk size of dispatch and combine will affact the size of intermediate buffers. self.config.num_of_blocks_preprocessing_api = self.num_sms_preprocessing_api + self.config.num_of_blocks_permute_api = self.num_blocks_permute_api # The fp8/bf16/fp16 data is communicated in the uint8/uint16 format. self.config.token_data_type = ( hybrid_ep_cpp.UINT8 if self.use_fp8 else hybrid_ep_cpp.UINT16 @@ -122,6 +121,8 @@ def __init__( os.getenv("NUM_OF_TOKENS_PER_CHUNK_COMBINE_API", "128") ) + assert self.config.is_valid(), "The buffer config is not valid." + # Create C++ buffer - this will allocate all buffers during construction self.runtime = hybrid_ep_cpp.HybridEPBuffer( self.group, @@ -130,10 +131,9 @@ def __init__( self.node_rank, self.group_size, os.path.dirname(os.path.abspath(__file__)), - ib_dev_name_list, load_cached_kernels = False, use_shared_buffer = True, - enable_fabric = use_mnnvl, # If use_mnnvl is True, the fabric memory handle will be used. + use_mnnvl = use_mnnvl, # If use_mnnvl is True, the fabric memory handle will be used. ) def empty_jit_cache(self): @@ -162,22 +162,20 @@ def update_template_config( config.hidden_dim = ( hidden_dim if hidden_dim is not None else self.config.hidden_dim ) - config.max_num_of_tokens_per_rank = ( - max_num_of_tokens_per_rank - if max_num_of_tokens_per_rank is not None - else self.config.max_num_of_tokens_per_rank - ) - if self.num_of_nodes > 1: - assert self.config.max_num_of_tokens_per_rank == max_num_of_tokens_per_rank, "Dynamic sequence length is not supported in the multi-node case." - config.max_num_of_tokens_per_rank = max( - config.max_num_of_tokens_per_rank, self.config.max_num_of_tokens_per_rank - ) + if max_num_of_tokens_per_rank is None: + max_num_of_tokens_per_rank = self.config.max_num_of_tokens_per_rank + else: + config.max_num_of_tokens_per_rank = max( + max_num_of_tokens_per_rank, self.config.max_num_of_tokens_per_rank + ) + self.config.max_num_of_tokens_per_rank = config.max_num_of_tokens_per_rank + config.num_of_experts_per_rank = ( num_local_experts if num_local_experts is not None else self.config.num_of_experts_per_rank ) - config.num_of_ranks_per_node = self.num_of_ranks_per_node + config.num_of_ranks_per_node = self.num_of_hybrid_ep_ranks_per_nvlink_domain config.num_of_nodes = self.num_of_nodes # Metadata-preprocessing API Config @@ -185,6 +183,7 @@ def update_template_config( config.num_of_threads_per_block_preprocessing_api = int( os.getenv("NUM_OF_THREADS_PER_BLOCK_PREPROCESSING_API", "512") ) + config.num_of_blocks_permute_api = self.num_blocks_permute_api # Dispatch API Config if use_fp8 is None: @@ -198,6 +197,9 @@ def update_template_config( config.num_of_stages_dispatch_api = int( os.getenv("NUM_OF_STAGES_DISPATCH_API", "10") ) + config.num_of_in_flight_s2g_dispatch_api = int( + os.getenv("NUM_OF_IN_FLIGHT_S2G_DISPATCH_API", "8") + ) config.num_of_tokens_per_chunk_dispatch_api = int( os.getenv("NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API", "128") ) @@ -227,8 +229,10 @@ def update_template_config( os.getenv("NUM_OF_ADDITIONAL_IN_FLIGHT_S2G_COMBINE_API", "2") ) + assert config.is_valid(), "The config is not valid." + # Use the runtime kernel config to update the buffer. - reallocated = self.runtime.update_buffer(config) + self.runtime.update_buffer(config) return config def dispatch( @@ -321,6 +325,8 @@ def dispatch( ) = handle if num_dispatched_tokens is None: + # Synchronize the stream to make sure the data in the pinned_memory_buffer: num_dispatched_tokens_tensor is ready. + torch.cuda.current_stream().synchronize() num_dispatched_tokens = num_dispatched_tokens_tensor.item() dispatched_token, dispatched_probs, dispatched_scaling_factor = ( @@ -389,7 +395,6 @@ def dispatch_with_permute( probs: torch.Tensor = None, scaling_factor: torch.Tensor = None, # Used in the sync-free permute - num_dispatched_tokens: int = None, num_permuted_tokens: int = None, # If we use permute kernel, the output tensor will be permuted. the result can be directly used in the gemm. pad_multiple: int = None, @@ -406,8 +411,13 @@ def dispatch_with_permute( # # Cache for template config # 7. template_config: HybridEpConfigInstance handle: tuple = None, - # If enable this, the produced num_dispatched_tokens will be put on the CPU pinned memory, and the tokens_per_expert will be put on the CPU, which may reduce the times of the sync - use_host_meta: bool = True, + # There are 2 tensors are put on the CPU pinned memory + # 1. num_dispatched_tokens in handle + # 2. tokens_per_expert + # If non_blocking is True, no stream synchronization will be used, so we can not promise the data in pinned + # memory is ready for using in CPU. The CPU value of num_permuted_tokens required for this mode + # Otherwise, the stream synchronization will be used to wait for the data in pinned memory. + non_blocking: bool = False, ): """ Dispatch the data to the experts with permute. @@ -426,6 +436,8 @@ def dispatch_with_permute( routing_map, probs = indices_to_map( topk_idx, topk_weights, num_of_tokens_per_rank, num_of_experts ) + if non_blocking: + assert num_permuted_tokens >= 0, "The num_permuted_tokens is required for non-blocking mode." # If the handle is not provided, we need to generate the handle in the first invocation of the dispatch kernel. if handle is None: @@ -462,15 +474,6 @@ def dispatch_with_permute( routing_map=global_routing_map, num_of_tokens_per_rank=num_of_tokens_per_rank, ) - if use_host_meta: - # Put the num_dispatched_tokens_tensor on the CPU pinned memory, because this tensor also will be used in the GPU kernel - num_dispatched_tokens_tensor_pinned = torch.empty( - num_dispatched_tokens_tensor.shape, - device="cpu", - dtype=num_dispatched_tokens_tensor.dtype, - pin_memory=True, - ) - num_dispatched_tokens_tensor_pinned.copy_(num_dispatched_tokens_tensor, False) else: ( sparse_to_dense_map, @@ -501,11 +504,10 @@ def dispatch_with_permute( num_dispatched_tokens_tensor=num_dispatched_tokens_tensor, local_expert_routing_map=local_expert_routing_map, row_id_map=row_id_map, - num_dispatched_tokens=num_dispatched_tokens, num_permuted_tokens=num_permuted_tokens, num_of_tokens_per_rank=num_of_tokens_per_rank, pad_multiple=pad_multiple, - use_host_meta=use_host_meta, + non_blocking=non_blocking, with_probs=probs is not None, ) @@ -519,6 +521,7 @@ def dispatch_with_permute( num_of_tokens_per_rank, config, ) + return ( dispatched_token, dispatched_probs, @@ -533,7 +536,6 @@ def combine_with_unpermute( # Input tensors hidden: torch.Tensor, probs: torch.Tensor = None, - num_dispatched_tokens: int = None, handle: tuple = None, pad_multiple: int = None, ): @@ -565,7 +567,6 @@ def combine_with_unpermute( attn_to_rdma_map=attn_to_rdma_map, num_dispatched_tokens_tensor=num_dispatched_tokens_tensor, row_id_map=row_id_map, - num_dispatched_tokens=num_dispatched_tokens, num_of_tokens_per_rank=num_of_tokens_per_rank, pad_multiple=pad_multiple, with_probs=probs is not None, diff --git a/setup.py b/setup.py index 13adae3b..687a1278 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,17 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension +def collect_package_files(package: str, relative_dir: str): + base_path = Path(package) / relative_dir + if not base_path.exists(): + return [] + return [ + str(path.relative_to(package)) + for path in base_path.rglob('*') + if path.is_file() + ] + + # Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X` def get_nvshmem_host_lib_name(base_dir): path = Path(base_dir).joinpath('lib') @@ -66,6 +77,7 @@ def get_extension_hybrid_ep_cpp(): libraries = ["cuda", "nvtx3interop"] extra_objects = [] runtime_library_dirs = [] + extra_link_args = [] # Add dependency for jit compile_args["nvcc"].append(f'-DSM_ARCH="{os.environ["TORCH_CUDA_ARCH_LIST"]}"') @@ -82,6 +94,7 @@ def get_extension_hybrid_ep_cpp(): nccl_dir = os.path.join(current_dir, "third-party/nccl") compile_args["nvcc"].append("-DHYBRID_EP_BUILD_MULTINODE_ENABLE") compile_args["nvcc"].append(f"-DRDMA_CORE_HOME=\"{rdma_core_dir}\"") + extra_link_args.append(f"-l:libnvidia-ml.so.1") subprocess.run(["git", "submodule", "update", "--init", "--recursive"], cwd=current_dir) # Generate the inter-node dependency to the python package for JIT compilation @@ -128,6 +141,7 @@ def get_extension_hybrid_ep_cpp(): print(f' > Includes: {include_dirs}') print(f' > Libraries: {libraries}') print(f' > Library dirs: {library_dirs}') + print(f' > Extra link args: {extra_link_args}') print(f' > Compilation flags: {compile_args}') print(f' > Extra objects: {extra_objects}') print(f' > Runtime library dirs: {runtime_library_dirs}') @@ -143,6 +157,7 @@ def get_extension_hybrid_ep_cpp(): extra_compile_args=compile_args, extra_objects=extra_objects, runtime_library_dirs=runtime_library_dirs, + extra_link_args=extra_link_args, ) return extension_hybrid_ep_cpp @@ -271,7 +286,7 @@ def get_extension_deep_ep_cpp(): 'build_ext': BuildExtension }, package_data={ - 'deep_ep': ['backend/*'], + 'deep_ep': collect_package_files('deep_ep', 'backend'), }, include_package_data=True ) diff --git a/tests/test_hybrid_ep.py b/tests/test_hybrid_ep.py index 2415f41f..64b954eb 100644 --- a/tests/test_hybrid_ep.py +++ b/tests/test_hybrid_ep.py @@ -21,6 +21,7 @@ NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES ITERATIONS = int(os.environ.get("ITERATIONS", 100)) SEED = int(os.environ.get("SEED", 42)) +USE_MNNVL = os.environ.get("USE_MNNVL", "0").strip().lower() in {"1", "true", "t", "yes", "y", "on"} torch.manual_seed(SEED) torch.cuda.manual_seed(SEED) torch.cuda.manual_seed_all(SEED) @@ -253,7 +254,16 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process dispatch_bf16_rdma_send_bytes = num_rdma_send * HIDDEN_DIM * 2 combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes - dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS} + ''' + Benchmark of the dispatch and combine torch API without permute + ''' + + dispatched_hidden, dispatched_probs, _, handle= ( + buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS) + ) + dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) + + dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS, 'handle': handle} t = bench(lambda: buffer.dispatch(**dispatch_args))[0] nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes if NUM_OF_NODES > 1: @@ -264,10 +274,6 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process print_in_order(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): ' f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB') - dispatched_hidden, dispatched_probs, _, handle= ( - buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS) - ) - dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} t = bench(lambda: buffer.combine(**combine_args))[0] print_in_order(f'[rank {rank}] HybridEP combine torch API: ' @@ -279,7 +285,13 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process ''' Benchmark of the dispatch and combine with permute extension ''' - dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE} + dispatched_hidden_with_permute, dispatched_probs_with_permute, _, tokens_per_expert, handle_with_permute= ( + buffer.dispatch_with_permute(hidden=hidden, scaling_factor=scaling_factor, routing_map=routing_map, probs=probs, pad_multiple=PAD_MULTIPLE) + ) + num_permuted_tokens = tokens_per_expert.sum().item() + dispatched_hidden_bf16_with_permute = dispatched_hidden_with_permute.to(torch.bfloat16) + + dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE, 'handle': handle_with_permute, 'num_permuted_tokens': num_permuted_tokens} t = bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))[0] nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): ' @@ -288,11 +300,7 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): ' f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB') - dispatched_hidden, dispatched_probs, _, _, handle= ( - buffer.dispatch_with_permute(hidden=hidden, scaling_factor=scaling_factor, routing_map=routing_map, probs=probs, pad_multiple=PAD_MULTIPLE) - ) - dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) - combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle, 'pad_multiple': PAD_MULTIPLE} + combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE} t = bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args))[0] print_in_order(f'[rank {rank}] HybridEP combine+unpermute torch API: ' f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB') @@ -325,27 +333,38 @@ def test_func(): with torch.cuda.nvtx.range(f"hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"): if rank == 0: print(f"profile hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True) - dispatch_args = {'tensor': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS} + dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS} bench(lambda: buffer.dispatch(**dispatch_args)) with torch.cuda.nvtx.range("hybrid-ep combine"): if rank == 0: print(f"profile hybrid-ep combine", flush=True) - combine_args = {'tensor': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} + combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} bench(lambda: buffer.combine(**combine_args)) + with torch.cuda.nvtx.range(f"hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"): + if rank == 0: + print(f"profile hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True) + dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE} + bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args)) + with torch.cuda.nvtx.range("hybrid-ep combine+unpermute"): + if rank == 0: + print(f"profile hybrid-ep combine+unpermute", flush=True) + combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE} + bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args)) time.sleep(1) torch.cuda.profiler.stop() def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): _, _, group = init_dist(local_rank, num_local_ranks) - for use_fp8 in [True, False]: + for use_fp8 in [False, True]: buffer = deep_ep.HybridEPBuffer( group=group, hidden_dim=HIDDEN_DIM, max_num_of_tokens_per_rank=MAX_NUM_OF_TOKENS_PER_RANK, num_local_experts=NUM_LOCAL_EXPERTS, - use_fp8=use_fp8, - ib_dev_name_list=args.ib_dev_name_list, + num_of_hybrid_ep_ranks_per_nvlink_domain=NUM_OF_RANKS_PER_NODE, + use_mnnvl=USE_MNNVL, + use_fp8=use_fp8 ) ref = TorchRef( @@ -365,7 +384,5 @@ def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): help='Number of processes to spawn (default: 4)') parser.add_argument('--nsys-profile', action='store_true', default=False, help='benchmark with nsys profile or not (default: False)') - parser.add_argument('--ib-dev-name-list', nargs='+', type=str, default=[], - help='IB device name list (default: [])') args = parser.parse_args() torch.multiprocessing.spawn(test_main, args=(args.num_processes, args), nprocs=args.num_processes)