diff --git a/csrc/config.hpp b/csrc/config.hpp index 5c911989..15dfbacd 100644 --- a/csrc/config.hpp +++ b/csrc/config.hpp @@ -6,7 +6,7 @@ namespace deep_ep { template -dtype_t ceil_div(dtype_t a, dtype_t b) { +constexpr dtype_t ceil_div(dtype_t a, dtype_t b) { return (a + b - 1) / b; } @@ -89,6 +89,11 @@ struct Config { num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + + // NOTE Please keep in sync: Config.get_nvl_buffer_size_hint, LowLatencyLayout.constructor, internode_ll_v2 + // NOTE add a large number to be safe + num_bytes += 1048576; + num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; #else @@ -102,7 +107,9 @@ struct LowLatencyBuffer { void* dispatch_rdma_send_buffer = nullptr; void* dispatch_rdma_recv_data_buffer = nullptr; - int* dispatch_rdma_recv_count_buffer = nullptr; + // NOTE rename + // int* dispatch_rdma_recv_count_buffer = nullptr; + int* dispatch_rdma_general_signal_buffer = nullptr; void* combine_rdma_send_buffer = nullptr; void* combine_rdma_recv_data_buffer = nullptr; @@ -112,8 +119,8 @@ struct LowLatencyBuffer { size_t num_bytes_per_combine_msg = 0; std::pair clean_meta() { - EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); - return {dispatch_rdma_recv_count_buffer, num_clean_int}; + EP_HOST_ASSERT(dispatch_rdma_general_signal_buffer == combine_rdma_recv_flag_buffer); + return {dispatch_rdma_general_signal_buffer, num_clean_int}; } }; @@ -129,6 +136,9 @@ struct LowLatencyLayout { LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { const int num_scales = hidden / 128; + EP_HOST_ASSERT(num_experts % num_ranks == 0); + const int num_local_experts = num_experts / num_ranks; + // Dispatch and combine layout: // - 2 symmetric odd/even send buffer // - 2 symmetric odd/even receive buffers @@ -157,9 +167,13 @@ struct LowLatencyLayout { total_bytes += recv_buffer_bytes * 2; // Symmetric signaling buffers - size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); - size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; - size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); + // NOTE can only increase instead of decrease to be compatible with v1 + // NOTE be careful about alignment + // size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); + // NOTE Please keep in sync: Config.get_nvl_buffer_size_hint, LowLatencyLayout.constructor, internode_ll_v2 + size_t dispatch_general_signal_buffer_bytes = num_experts * sizeof(int64_t) + num_local_experts * sizeof(int); + size_t combine_recv_flag_buffer_bytes = dispatch_general_signal_buffer_bytes; + size_t signaling_buffer_bytes = std::max(dispatch_general_signal_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes_aligned = align(signaling_buffer_bytes, 128); total_bytes += signaling_buffer_bytes_aligned * 2; diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 0789cd58..fd8561cf 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -9,9 +9,143 @@ #include "deep_ep.hpp" #include "kernels/api.cuh" #include "kernels/configs.cuh" +#include "kernels/internode_ll_v2_inc.cuh" + +constexpr int HIDDEN_DIM = 7168; + +namespace 
shared_memory { +void cu_mem_set_access_all(void* ptr, size_t size) { + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + CUmemAccessDesc access_desc[device_count]; + for (int idx = 0; idx < device_count; ++idx) { + access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc[idx].location.id = idx; + access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } + + CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count)); +} + +void cu_mem_free(void* ptr) { + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemRelease(handle)); +} + +size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) { + size_t size = (size_raw + granularity - 1) & ~(granularity - 1); + if (size == 0) size = granularity; + return size; +} + +bool support_fabric() { + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + for (int device = 0; device < device_count; ++device) { + int support = 0; + CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device)); + if (!support) { + return false; + } + } + + return true; +} + +SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {} + +void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) { + if (enable_fabric) { + CUdevice device; + CU_CHECK(cuCtxGetDevice(&device)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + prop.location.id = device; + + size_t granularity = 0; + CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t size = get_size_align_to_granularity(size_raw, granularity); + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemCreate(&handle, size, &prop, 0)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, granularity, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } else { + CUDA_CHECK(cudaMalloc(ptr, size_raw)); + } +} + +void SharedMemoryAllocator::free(void* ptr) { + if (enable_fabric) { + cu_mem_free(ptr); + } else { + CUDA_CHECK(cudaFree(ptr)); + } +} + +void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) { + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + mem_handle->size = size; + + if (enable_fabric) { + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + } else { + CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr)); + } +} + +void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) { + if (enable_fabric) { + size_t size = mem_handle->size; + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } else { + 
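+        // Fallback when fabric handles are unsupported: map the peer allocation
+        // through the legacy CUDA IPC path instead of the cuMem* fabric handles above.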
CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess)); + } +} + +void SharedMemoryAllocator::close_mem_handle(void* ptr) { + if (enable_fabric) { + cu_mem_free(ptr); + } else { + CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); + } +} +} namespace deep_ep { +cudaError_t cudaMallocAndZero(void** devPtr, size_t size) { + cudaError_t err = cudaMalloc(devPtr, size); + if (err != cudaSuccess) return err; + return cudaMemset(*devPtr, 0, size); +} + Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy): rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), @@ -46,8 +180,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ if (num_nvl_bytes > 0) { // Local IPC: alloc local memory and set local IPC handles - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); - CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); + shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes); + shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]); buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); // Set barrier signals @@ -115,7 +249,8 @@ int Buffer::get_local_device_id() const { } pybind11::bytearray Buffer::get_local_ipc_handle() const { - return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE}; + const shared_memory::MemHandle& handle = ipc_handles[nvl_rank]; + return {reinterpret_cast(&handle), sizeof(handle)}; } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { @@ -154,11 +289,11 @@ void Buffer::destroy() { // Close remote IPC if (is_available()) { for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank) - CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); + shared_memory_allocator.close_mem_handle(buffer_ptrs[i]); } // Free local buffer and error flag - CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + shared_memory_allocator.free(buffer_ptrs[nvl_rank]); } // Free NVSHMEM @@ -194,13 +329,13 @@ void Buffer::sync(const std::vector &device_ids, for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) { EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); auto handle_str = std::string(all_gathered_handles[offset + i].value()); - EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); + EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE); if (offset + i != rank) { - std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); - CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); + std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE); + shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]); barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { - EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); + EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0); } } @@ -1087,24 +1222,41 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int #endif } +// TODO `x` (in the new approach), `zeroed_tensor`, `dst_signals` etc 
will be modified. shall we represent in m.def? std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> -Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, +Buffer::low_latency_dispatch(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx, const std::optional& cumulative_local_expert_recv_stats, const std::optional& dispatch_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool round_scale, bool use_ue8m0, - bool async, bool return_recv_hook) { + bool async, bool return_recv_hook, + const std::optional& zeroed_tensor_a, + const std::optional& zeroed_tensor_b, + const std::optional& zeroed_buffer_for_atomic_counter_per_expert, + bool use_nvfp4, + const std::optional& dst_signals, + const std::optional& count_per_expert, const std::optional& token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + const std::optional& debug_tensor) { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks // By default using `ptp128c` FP8 cast - EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); - EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); - EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); - EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); - EP_HOST_ASSERT(num_experts % num_ranks == 0); + + if (enable_v2) { + // NOTE `x` is packed now + using Consts = internode_ll::DispatchConstsTemplate; + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kUInt8); + EP_HOST_ASSERT(x.size(1) == Consts::num_bytes_per_msg); + + EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); + } else { + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); + } // Diagnosis tensors if (cumulative_local_expert_recv_stats.has_value()) { @@ -1118,10 +1270,36 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks); } - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + // auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_tokens = static_cast(x.size(0)); + auto hidden = enable_v2 + ? 
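+          // In v2, `x` arrives pre-packed as raw per-message bytes, so the hidden size
+          // cannot be recovered from `x.size(1)`; it is pinned to the compile-time
+          // HIDDEN_DIM (7168) declared at the top of this file.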
HIDDEN_DIM + : static_cast(x.size(1)); + auto num_topk = static_cast(topk_idx.size(1)); auto num_local_experts = num_experts / num_ranks; + // NOTE ADD + if (count_per_expert.has_value()) { + EP_HOST_ASSERT(count_per_expert->is_contiguous()); + EP_HOST_ASSERT(count_per_expert->dim() == 1); + EP_HOST_ASSERT(count_per_expert->size(0) == num_experts); + EP_HOST_ASSERT(count_per_expert->dtype() == torch::kUInt32); + } +// if (token_ids_of_expert.has_value()) { +// EP_HOST_ASSERT(token_ids_of_expert->is_contiguous()); +// EP_HOST_ASSERT(token_ids_of_expert->dim() == 2); +// EP_HOST_ASSERT(token_ids_of_expert->size(0) == num_experts); +// // EP_HOST_ASSERT(token_ids_of_expert->size(1) == ...whatever...); +// EP_HOST_ASSERT(token_ids_of_expert->dtype() == torch::kInt32); +// } + if (token_idx_and_dst_expert_and_dst_slot_idx_flat_list.has_value()) { + EP_HOST_ASSERT(token_idx_and_dst_expert_and_dst_slot_idx_flat_list->is_contiguous()); + EP_HOST_ASSERT(token_idx_and_dst_expert_and_dst_slot_idx_flat_list->dim() == 1); + EP_HOST_ASSERT(token_idx_and_dst_expert_and_dst_slot_idx_flat_list->size(0) == num_tokens * num_topk); + EP_HOST_ASSERT(token_idx_and_dst_expert_and_dst_slot_idx_flat_list->dtype() == torch::kInt64); + } + // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); @@ -1137,11 +1315,48 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i stream_wait(launch_stream, compute_stream); // Allocate packed tensors - auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); + constexpr int NUM_ELEMS_PER_PACK = 8; + // TODO do not allocate this in v2 + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden}, + x.options().dtype(use_nvfp4 ? torch::kInt32 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16))); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); - auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + // NOTE let users do the zeroing + EP_HOST_ASSERT(enable_v2 == zeroed_tensor_a.has_value()); + auto packed_recv_count = zeroed_tensor_a.has_value() + ? zeroed_tensor_a.value() + : torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + EP_HOST_ASSERT(packed_recv_count.is_contiguous()); + EP_HOST_ASSERT(packed_recv_count.dim() == 1); + EP_HOST_ASSERT(packed_recv_count.size(0) == num_local_experts); + EP_HOST_ASSERT(packed_recv_count.dtype() == torch::kInt32); + EP_HOST_ASSERT(packed_recv_count.device().is_cuda()); + EP_HOST_ASSERT(packed_recv_count.stride(0) == 1); + EP_HOST_ASSERT(((int64_t)packed_recv_count.data_ptr()) % 16 == 0); // alignment + + // (num_experts,). used in curr gpu. 
for i-th dst rank, what is the start offset in the remote buffer + const std::optional& remote_start_offset_buffer = zeroed_tensor_b; + EP_HOST_ASSERT(enable_v2 == remote_start_offset_buffer.has_value()); + if (enable_v2) { + EP_HOST_ASSERT(remote_start_offset_buffer->is_contiguous()); + EP_HOST_ASSERT(remote_start_offset_buffer->dim() == 1); + EP_HOST_ASSERT(remote_start_offset_buffer->size(0) == num_experts); + EP_HOST_ASSERT(remote_start_offset_buffer->dtype() == torch::kInt32); + EP_HOST_ASSERT(remote_start_offset_buffer->device().is_cuda()); + EP_HOST_ASSERT(remote_start_offset_buffer->stride(0) == 1); + EP_HOST_ASSERT(((int64_t)remote_start_offset_buffer->data_ptr()) % 16 == 0); // alignment + } + + if (enable_v2) { + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->is_contiguous()); + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->dim() == 1); + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->size(0) == num_experts); + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->dtype() == torch::kInt32); + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->device().is_cuda()); + EP_HOST_ASSERT(zeroed_buffer_for_atomic_counter_per_expert->stride(0) == 1); + EP_HOST_ASSERT(((int64_t)zeroed_buffer_for_atomic_counter_per_expert->data_ptr()) % 16 == 0); // alignment + } // Allocate column-majored scales auto packed_recv_x_scales = std::optional(); @@ -1161,17 +1376,37 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i } packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + } else if (use_nvfp4) { + constexpr int kNumPerChannels = 16; + constexpr int NUM_SF_ELEMS_PER_PACK = 4; + constexpr int mTileSize_dim_0 = 32; + constexpr int mTileSize_dim_1 = 4; + constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1; + + auto l = num_local_experts; + auto m = num_ranks * num_max_dispatch_tokens_per_rank; + auto rm = (m + 127) / 128; + auto rk = hidden / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK); + // The physical layout is (l, rm, rk, 32, 4, 4). + packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4}, + torch::dtype(torch::kInt8).device(torch::kCUDA)); + // After permute, the logical shape is (32, 4, rm, 4, rk, l) + packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0}); + + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); +// packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); +// packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr(); } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { - internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, + internode_ll::dispatch(enable_v2, packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), packed_recv_count.data_ptr(), cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr() : nullptr, dispatch_wait_recv_cost_stats.has_value() ? 
dispatch_wait_recv_cost_stats->data_ptr() : nullptr, - buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_general_signal_buffer, buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), next_clean_meta.first, next_clean_meta.second, @@ -1179,7 +1414,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i num_topk, num_experts, rank, num_ranks, use_fp8, round_scale, use_ue8m0, workspace, num_device_sms, - launch_stream, phases); + launch_stream, phases, + use_nvfp4, + dst_signals.has_value() ? dst_signals->data_ptr() : nullptr, + count_per_expert.has_value() ? count_per_expert->data_ptr() : nullptr, + token_idx_and_dst_expert_and_dst_slot_idx_flat_list.has_value() ? token_idx_and_dst_expert_and_dst_slot_idx_flat_list->data_ptr() : nullptr, +// token_ids_of_expert.has_value() ? token_ids_of_expert->data_ptr() : nullptr, +// token_ids_of_expert.has_value() ? token_ids_of_expert->stride(0) : 0, + remote_start_offset_buffer.has_value() ? remote_start_offset_buffer->data_ptr() : nullptr, + zeroed_buffer_for_atomic_counter_per_expert.has_value() ? zeroed_buffer_for_atomic_counter_per_expert->data_ptr() : nullptr, + debug_tensor.has_value() ? debug_tensor->data_ptr() : nullptr); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); @@ -1198,8 +1442,36 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + using Consts = internode_ll::DispatchConstsTemplate; + const auto dim0 = num_local_experts; + const auto dim1 = num_ranks * num_max_dispatch_tokens_per_rank; + const auto dim2 = Consts::num_bytes_per_msg; + const auto returned_x = enable_v2 + // https://stackoverflow.com/questions/58631466/create-a-torchtensor-from-c-c-array-without-using-from-blob + // https://docs.pytorch.org/cppdocs/api/function_namespacetorch_1ac009244049812a3efdf4605d19c5e79b.html + ? 
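+          // v2 returns a zero-copy uint8 view over the RDMA receive buffer rather than
+          // `packed_recv_x`: each received message is laid out as a leading 16-byte
+          // (int4) header followed by the packed payload, and the slice below keeps
+          // only the `hidden / 2` payload bytes of every slot.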
torch::from_blob( + buffer.dispatch_rdma_recv_data_buffer, + // ref: LowLatencyLayout constructor `dispatch_recv_data_buffer_bytes` + {dim0, dim1, dim2}, + {(int)(dim1 * dim2), dim2, 1}, + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA) + ).index({ + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(sizeof(int4), sizeof(int4) + hidden / 2) + }) + : packed_recv_x; + if (enable_v2) { + // ref: packed_recv_x's shape etc + EP_HOST_ASSERT(returned_x.dim() == 3); + EP_HOST_ASSERT(returned_x.size(0) == num_local_experts); + EP_HOST_ASSERT(returned_x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(returned_x.size(2) == hidden / 2); + EP_HOST_ASSERT(returned_x.dtype() == torch::kUInt8); + } + // Return values - return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; + return {returned_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); return {}; @@ -1207,12 +1479,13 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i } std::tuple, std::optional>> -Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, +Buffer::low_latency_combine(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out) { + const std::optional& out, + const std::optional& src_signals, uint32_t src_signal_expect_value) { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); @@ -1271,7 +1544,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { - internode_ll::combine(combined_x.data_ptr(), + internode_ll::combine(enable_v2, combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), @@ -1282,7 +1555,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id num_topk, num_experts, rank, num_ranks, use_logfmt, workspace, num_device_sms, - launch_stream, phases, zero_copy); + launch_stream, phases, zero_copy, + src_signals.has_value() ? src_signals->data_ptr() : nullptr, src_signal_expect_value); }; launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index aa62ccb0..d8ba1414 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -20,6 +20,33 @@ #define TORCH_EXTENSION_NAME deep_ep_cpp #endif +namespace shared_memory { + +union MemHandleInner { + cudaIpcMemHandle_t cuda_ipc_mem_handle; + CUmemFabricHandle cu_mem_fabric_handle; +}; + +struct MemHandle { + MemHandleInner inner; + size_t size; +}; + +constexpr size_t HANDLE_SIZE = sizeof(MemHandle); + +class SharedMemoryAllocator { +public: + SharedMemoryAllocator(); + void malloc(void** ptr, size_t size); + void free(void* ptr); + void get_mem_handle(MemHandle* mem_handle, void* ptr); + void open_mem_handle(void** ptr, MemHandle* mem_handle); + void close_mem_handle(void* ptr); +private: + bool enable_fabric; +}; +} + namespace deep_ep { struct Buffer { @@ -44,7 +71,7 @@ struct Buffer { int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; - cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; + shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication at::cuda::CUDAStream comm_stream; @@ -76,6 +103,8 @@ struct Buffer { volatile int* moe_recv_rdma_counter = nullptr; int* moe_recv_rdma_counter_mapped = nullptr; + shared_memory::SharedMemoryAllocator shared_memory_allocator; + public: Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy); @@ -144,20 +173,28 @@ struct Buffer { void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> - low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, + low_latency_dispatch(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx, const std::optional& cumulative_local_expert_recv_stats, const std::optional& dispatch_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool round_scale, bool use_ue8m0, - bool async, bool return_recv_hook); + bool async, bool return_recv_hook, + const std::optional& zeroed_tensor_a, + const std::optional& zeroed_tensor_b, + const std::optional& zeroed_buffer_for_atomic_counter_per_expert, + bool use_nvfp4, + const std::optional& dst_signals, + const std::optional& count_per_expert, const std::optional& token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + const std::optional& debug_tensor); std::tuple, std::optional>> - low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + low_latency_combine(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out = std::nullopt); + const std::optional& out = std::nullopt, + const std::optional& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index d34775fd..03f20ca7 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -139,21 +139,26 @@ void 
clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1, cudaStream_t stream); -void dispatch(void* packed_recv_x, void* packed_recv_x_scales, +void dispatch(bool enable_v2, void* packed_recv_x, void* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, + void* x, const int64_t* topk_idx, // NOTE rm `const` of x int* next_clean, int num_next_clean_int, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, bool round_scale, bool use_ue8m0, void* workspace, int num_device_sms, - cudaStream_t stream, int phases); - -void combine(void* combined_x, + cudaStream_t stream, int phases, + bool use_nvfp4, uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int* remote_start_offset_buffer, + int* zeroed_buffer_for_atomic_counter_per_expert, + int* debug_tensor); + +void combine(bool enable_v2, void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, @@ -163,7 +168,8 @@ void combine(void* combined_x, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, int num_device_sms, - cudaStream_t stream, int phases, bool zero_copy); + cudaStream_t stream, int phases, bool zero_copy, + uint32_t* src_signals, uint32_t src_signal_expect_value); } // namespace internode_ll diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh index 7db0ddb7..b98cc2f2 100644 --- a/csrc/kernels/exception.cuh +++ b/csrc/kernels/exception.cuh @@ -31,6 +31,18 @@ do { \ } while (0) #endif +#ifndef CU_CHECK +#define CU_CHECK(cmd) \ +do { \ + CUresult e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + const char *error_str = NULL; \ + cuGetErrorString(e, &error_str); \ + throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ + } \ +} while (0) +#endif + #ifndef EP_HOST_ASSERT #define EP_HOST_ASSERT(cond) \ do { \ @@ -49,3 +61,6 @@ do { \ } \ } while (0) #endif + +#define EP_DEBUG_DEVICE_ASSERT(cond) EP_DEVICE_ASSERT(cond) +// #define EP_DEBUG_DEVICE_ASSERT(cond) do {} while (0) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 391a4b3d..e45d990b 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -3,6 +3,10 @@ #include "launch.cuh" #include "ibgda_device.cuh" +// temporary hack to put it into cuh +#include "internode_ll_common.cuh" +#include "internode_ll_v2.cuh" + namespace deep_ep { namespace internode_ll { @@ -334,19 +338,48 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, } } -void dispatch(void* packed_recv_x, void* packed_recv_x_scales, +void dispatch(bool enable_v2, void* packed_recv_x, void* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, + void* x, const int64_t* topk_idx, // NOTE rm `const` of x int* next_clean, int num_next_clean_int, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int 
num_experts, int rank, int num_ranks, bool use_fp8, bool round_scale, bool use_ue8m0, void* workspace, int num_device_sms, - cudaStream_t stream, int phases) { + cudaStream_t stream, int phases, + bool use_nvfp4, uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int* remote_start_offset_buffer, + int* zeroed_buffer_for_atomic_counter_per_expert, + int* debug_tensor) { + if (enable_v2) { + return dispatch_v2( + packed_recv_x, packed_recv_x_scales, + packed_recv_src_info, packed_recv_layout_range, + packed_recv_count, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, + rdma_recv_x, rdma_recv_count, + // rdma_x, // NOTE removed + x, topk_idx, // NOTE rm `const` of x + next_clean, num_next_clean_int, + num_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + use_fp8, round_scale, use_ue8m0, + workspace, num_device_sms, + stream, phases, + use_nvfp4, dst_signals, + count_per_expert, token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + remote_start_offset_buffer, + zeroed_buffer_for_atomic_counter_per_expert, + debug_tensor + ); + } + constexpr int kNumMaxTopK = 9; const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; @@ -392,165 +425,6 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \ #undef DISPATCH_LAUNCH_CASE } -template -__forceinline__ __device__ int logfmt_encode(void* buffer, nv_bfloat162 *shared_amaxmin, const int& lane_id) { - constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); - constexpr float kLogThreshold = 0; - constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))` - constexpr int kNumBits = 10; - constexpr int kNumValues = 1 << (kNumBits - 1); - - int4 int4_values[kNumSendUnrolls]; - const auto& uint32_values = reinterpret_cast(int4_values); - const auto& bf162_values = reinterpret_cast(int4_values); - - // Calculate lane offset - const auto& ld_buffer = reinterpret_cast(static_cast(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4))); - const auto& st_buffer = reinterpret_cast(static_cast(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16)); - - // Local log amax - auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16); - auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16); - uint32_t local_signs = 0; - #pragma unroll - for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) { - // TODO: eliminate bank conflicts - uint32_values[k] = ld_buffer[k]; - local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2); - local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1); - uint32_values[k] &= 0x7fff7fff; - - bf162_amax = __hmax2(bf162_amax, bf162_values[k]); - bf162_amin = __hmin2(bf162_amin, bf162_values[k]); - } - - // Reduce per 128 channels - // TODO: figure out how hardware do 2-byte min/max - auto amax = std::max(static_cast(bf162_amax.x), static_cast(bf162_amax.y)); - auto amin = std::min(static_cast(bf162_amin.x), static_cast(bf162_amin.y)); - constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4)); - amax = warp_reduce_max(amax); - amin = warp_reduce_min(amin); - - // Write min/max into the shared memory - if (shared_amaxmin != nullptr) - *shared_amaxmin = __nv_bfloat162(amax, amin); - __syncwarp(); - - // Calculate log amin/amax float - const auto& log_amax = log2f_approx(amax); - const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip); - const bool& 
enable_cast = warp_reduce_and(log_amax < kLogThreshold and log_amin < log_amax); - - // Case into LogFMT-10 if satisfied - if (enable_cast) { - const auto step = (log_amax - log_amin) / static_cast(kNumValues - 2); - const auto step_inv = 1.0f / step; - const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv; - const auto fused_rounding = rounding - log_amin * step_inv; - - // Pack every 256 bits into 160 bits - EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only"); - uint32_t encoded[kNumElemsPerInt4 * 2]; - #pragma unroll 1 - for (int i = 0; i < kNumSendUnrolls / 2; ++ i) { - #pragma unroll - for (int k = 0; k < kNumElemsPerInt4; ++ k) { - const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]); - encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0)); - encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0)); - } - st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27); - st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31); - st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26); - st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30); - st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u)); - } - tma_store_fence(); - __syncwarp(); - } - - // Return TMA copy bytes - return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)): - (32 * (kNumSendUnrolls * sizeof(int4))); -} - -template -__forceinline__ __device__ void logfmt_check_amaxmin(uint8_t* meta_buffer, float2* shared_log_amax, - float2* shared_log_amin, int* shared_cast_info, - const int lane_id) { - constexpr float kLogThreshold = 0; - constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))` - - bool enable_cast = true; - if (lane_id < kNumLanes) { - // Calculate log amin/amax float - auto amaxmin2 = reinterpret_cast(meta_buffer)[lane_id]; - const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2); - float log_amax[2], log_amin[2]; - #pragma unroll - for (int i = 0; i < 2; ++ i) { - auto amax = static_cast(bf162_amaxmin[i].x); - auto amin = static_cast(bf162_amaxmin[i].y); - log_amax[i] = log2f_approx(amax); - log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip); - enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i]; - } - shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]); - shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]); - } - - const auto& casted = warp_reduce_and(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u; - const auto& num_casted_prefix = __popc(warp_reduce_or(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1)); - - if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0) - shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 
1u : 0u); - __syncwarp(); -} - -template -__forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float* accum, - const float& log_amax, const float& log_amin, - const bool& enable_cast, const float& weight) { - if (enable_cast) { - constexpr int kNumBits = 10; - constexpr int kNumValues = 1 << (kNumBits - 1); - - const auto& step = (log_amax - log_amin) / static_cast(kNumValues - 2); - auto decode = [=](const uint32_t &encoded, const uint32_t &sign) { - const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin); - return sign ? -decoded : decoded; - }; - - EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only"); - #pragma unroll - for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) { - uint32_t concat[6]; - concat[0] = ld_buffer[i * 5]; - #pragma unroll - for (int k = 1; k < 5; ++ k) - concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5)); - concat[5] = ld_buffer[i * 5 + 4] >> 7; - - const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16; - #pragma unroll - for (int k = 0; k < 5; ++ k) { - accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight; - accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight; - accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight; - } - accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight; - } - } else { - #pragma unroll - for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) { - auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k); - accum[k * 2 + 0] += static_cast(bf16_pack.x) * weight; - accum[k * 2 + 1] += static_cast(bf16_pack.y) * weight; - } - } -} - template __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, @@ -917,7 +791,7 @@ combine(void* combined_x, } } -void combine(void* combined_x, +void combine(bool enable_v2, void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, const int* src_info, const int64_t* layout_range, @@ -927,7 +801,25 @@ void combine(void* combined_x, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, int num_device_sms, - cudaStream_t stream, int phases, bool zero_copy) { + cudaStream_t stream, int phases, bool zero_copy, + uint32_t* src_signals, uint32_t src_signal_expect_value) { + if (enable_v2) { + return combine_v2( + combined_x, + rdma_recv_x, rdma_recv_flag, rdma_send_x, + x, topk_idx, topk_weights, + src_info, layout_range, + combine_wait_recv_cost_stats, + next_clean, num_next_clean_int, + num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + use_logfmt, + workspace, num_device_sms, + stream, phases, zero_copy, + src_signals, src_signal_expect_value + ); + } + constexpr int kNumMaxTopk = 9; const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; diff --git a/csrc/kernels/internode_ll_common.cuh b/csrc/kernels/internode_ll_common.cuh new file mode 100644 index 00000000..25962dcd --- /dev/null +++ b/csrc/kernels/internode_ll_common.cuh @@ -0,0 +1,169 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "ibgda_device.cuh" + +namespace deep_ep { +namespace internode_ll { + +template +__forceinline__ __device__ int 
logfmt_encode(void* buffer, nv_bfloat162 *shared_amaxmin, const int& lane_id) { + constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); + constexpr float kLogThreshold = 0; + constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))` + constexpr int kNumBits = 10; + constexpr int kNumValues = 1 << (kNumBits - 1); + + int4 int4_values[kNumSendUnrolls]; + const auto& uint32_values = reinterpret_cast(int4_values); + const auto& bf162_values = reinterpret_cast(int4_values); + + // Calculate lane offset + const auto& ld_buffer = reinterpret_cast(static_cast(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4))); + const auto& st_buffer = reinterpret_cast(static_cast(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16)); + + // Local log amax + auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16); + auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16); + uint32_t local_signs = 0; + #pragma unroll + for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) { + // TODO: eliminate bank conflicts + uint32_values[k] = ld_buffer[k]; + local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2); + local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1); + uint32_values[k] &= 0x7fff7fff; + + bf162_amax = __hmax2(bf162_amax, bf162_values[k]); + bf162_amin = __hmin2(bf162_amin, bf162_values[k]); + } + + // Reduce per 128 channels + // TODO: figure out how hardware do 2-byte min/max + auto amax = std::max(static_cast(bf162_amax.x), static_cast(bf162_amax.y)); + auto amin = std::min(static_cast(bf162_amin.x), static_cast(bf162_amin.y)); + constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4)); + amax = warp_reduce_max(amax); + amin = warp_reduce_min(amin); + + // Write min/max into the shared memory + if (shared_amaxmin != nullptr) + *shared_amaxmin = __nv_bfloat162(amax, amin); + __syncwarp(); + + // Calculate log amin/amax float + const auto& log_amax = log2f_approx(amax); + const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip); + const bool& enable_cast = warp_reduce_and(log_amax < kLogThreshold and log_amin < log_amax); + + // Case into LogFMT-10 if satisfied + if (enable_cast) { + const auto step = (log_amax - log_amin) / static_cast(kNumValues - 2); + const auto step_inv = 1.0f / step; + const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv; + const auto fused_rounding = rounding - log_amin * step_inv; + + // Pack every 256 bits into 160 bits + EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only"); + uint32_t encoded[kNumElemsPerInt4 * 2]; + #pragma unroll 1 + for (int i = 0; i < kNumSendUnrolls / 2; ++ i) { + #pragma unroll + for (int k = 0; k < kNumElemsPerInt4; ++ k) { + const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]); + encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0)); + encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0)); + } + st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27); + st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31); + st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26); + st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | 
(encoded[14] << 30); + st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u)); + } + tma_store_fence(); + __syncwarp(); + } + + // Return TMA copy bytes + return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)): + (32 * (kNumSendUnrolls * sizeof(int4))); +} + +template +__forceinline__ __device__ void logfmt_check_amaxmin(uint8_t* meta_buffer, float2* shared_log_amax, + float2* shared_log_amin, int* shared_cast_info, + const int lane_id) { + constexpr float kLogThreshold = 0; + constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))` + + bool enable_cast = true; + if (lane_id < kNumLanes) { + // Calculate log amin/amax float + auto amaxmin2 = reinterpret_cast(meta_buffer)[lane_id]; + const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2); + float log_amax[2], log_amin[2]; + #pragma unroll + for (int i = 0; i < 2; ++ i) { + auto amax = static_cast(bf162_amaxmin[i].x); + auto amin = static_cast(bf162_amaxmin[i].y); + log_amax[i] = log2f_approx(amax); + log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip); + enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i]; + } + shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]); + shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]); + } + + const auto& casted = warp_reduce_and(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u; + const auto& num_casted_prefix = __popc(warp_reduce_or(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1)); + + if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0) + shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u); + __syncwarp(); +} + +template +__forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float* accum, + const float& log_amax, const float& log_amin, + const bool& enable_cast, const float& weight) { + if (enable_cast) { + constexpr int kNumBits = 10; + constexpr int kNumValues = 1 << (kNumBits - 1); + + const auto& step = (log_amax - log_amin) / static_cast(kNumValues - 2); + auto decode = [=](const uint32_t &encoded, const uint32_t &sign) { + const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin); + return sign ? 
-decoded : decoded; + }; + + EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only"); + #pragma unroll + for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) { + uint32_t concat[6]; + concat[0] = ld_buffer[i * 5]; + #pragma unroll + for (int k = 1; k < 5; ++ k) + concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5)); + concat[5] = ld_buffer[i * 5 + 4] >> 7; + + const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16; + #pragma unroll + for (int k = 0; k < 5; ++ k) { + accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight; + accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight; + accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight; + } + accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight; + } + } else { + #pragma unroll + for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) { + auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k); + accum[k * 2 + 0] += static_cast(bf16_pack.x) * weight; + accum[k * 2 + 1] += static_cast(bf16_pack.y) * weight; + } + } +} + +} // namespace internode_ll +} // namespace deep_ep diff --git a/csrc/kernels/internode_ll_v2.cuh b/csrc/kernels/internode_ll_v2.cuh new file mode 100644 index 00000000..a79e7a3a --- /dev/null +++ b/csrc/kernels/internode_ll_v2.cuh @@ -0,0 +1,1477 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "ibgda_device.cuh" + +#include "internode_ll_v2_inc.cuh" + +constexpr int DST_SIGNAL_EXPECT_VALUE = 1000000; + +namespace deep_ep { +namespace internode_ll { + +constexpr int kNumMaxWarpGroups = 32; + +#define ENABLE_DEBUG_TIMING_TENSOR 0 + +#if ENABLE_DEBUG_TIMING_TENSOR +constexpr int DT_MAX_NUM_EVENT_GROUPS = 10; +constexpr int DT_MAX_NUM_EVENTS_PER_GROUP = 100; +constexpr int DT_MAX_NUM_MODES = 2; +constexpr int DT_MAX_NUM_SMS = 200; +constexpr int DT_MAX_NUM_WARPS_PER_SM = 100; +__forceinline__ __device__ void write_debug_time( + int* debug_tensor, + uint32_t t_start, + int event_group_id, + int event_id, + int mode_id, + int sm_id, + int warp_id +) { + if (get_lane_id() == 0) { + uint32_t t_delta = ((uint32_t)clock()) - t_start; + + int idx = ( + event_group_id * (DT_MAX_NUM_EVENTS_PER_GROUP * DT_MAX_NUM_MODES * DT_MAX_NUM_SMS * DT_MAX_NUM_WARPS_PER_SM) + + event_id * (DT_MAX_NUM_MODES * DT_MAX_NUM_SMS * DT_MAX_NUM_WARPS_PER_SM) + + mode_id * (DT_MAX_NUM_SMS * DT_MAX_NUM_WARPS_PER_SM) + + sm_id * (DT_MAX_NUM_WARPS_PER_SM) + + warp_id + ); + + debug_tensor[idx] = t_delta; + } +} +#endif + +template +__forceinline__ __device__ void dispatch_send( + int subroutine_thread_id, int num_warp_groups, + + // copied args + void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + // int* rdma_recv_count, // NOTE removed + // void* rdma_x, // NOTE removed + void* x, const int64_t* topk_idx, // NOTE rm `const` of x + int* atomic_counter_per_expert, + // int* atomic_finish_counter_per_expert, // NOTE removed + int* next_clean, int num_next_clean_int, + int num_tokens, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + // int num_send_warp_groups, int num_recv_warp_groups, // NOTE removed + int num_warps_per_group, + bool 
round_scale, int phases, + uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int64_t* layout_range_buffer, int* negotiate_offset_of_expert_buffer, int* remote_start_offset_buffer, + int* debug_tensor +) { + uint32_t t_start = clock(); + + using Consts = DispatchConstsTemplate; + EP_DEVICE_ASSERT(Consts::num_bytes_per_msg % sizeof(int4) == 0); + + // NOTE copied from dispatch body + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); + const auto warp_id = subroutine_thread_id / 32, lane_id = get_lane_id(); + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_local_experts = num_experts / num_ranks; + // unused + // const auto warp_group_id = warp_id / num_warps_per_group; + // const auto sub_warp_id = warp_id % num_warps_per_group; + + // NOTE removed + // const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + // Expert counts + // __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_send START\n", rank, sm_id, subroutine_thread_id); } + + if ((sm_id == 0) and (warp_id == 0)) { + // The first SM is also responsible for cleaning the next buffer + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + // TODO do we really need this? since `next_clean` will be used only in the next round of kernels + // not needed in per-token signal approach +// // Notify before executing `int_p` +// __syncwarp(); +// #pragma unroll +// for (int i = lane_id; i < num_experts; i += 32) +// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + // Reserve remote locations + { + EP_DEVICE_ASSERT(num_ranks <= num_sms); + EP_DEVICE_ASSERT(num_local_experts <= num_warps * 32); + const int dst_rank = sm_id; + const int dst_expert_local_idx = subroutine_thread_id; + + if ((dst_rank < num_ranks) and (dst_expert_local_idx < num_local_experts)) { + const auto dst_global_expert_idx = dst_rank * num_local_experts + dst_expert_local_idx; + + const int num_tokens_to_send = count_per_expert[dst_global_expert_idx]; + + // 1. Compete to get a range of locations to set data to + // TODO maybe do not need `release` (but yes need `sys`) + int remote_start_offset; + { + const auto dst_ptr = reinterpret_cast(negotiate_offset_of_expert_buffer); + const auto dst_p2p_ptr = reinterpret_cast(nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank)); + remote_start_offset = atomic_add_release_sys_global(dst_p2p_ptr + dst_expert_local_idx, num_tokens_to_send); + } + + // 2. Write metadata to remote + // TODO is this strong enough + { + const auto dst_ptr = reinterpret_cast(layout_range_buffer + dst_expert_local_idx * num_ranks + rank); + const auto dst_p2p_ptr = reinterpret_cast(nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank)); + const auto raw_val = pack2(num_tokens_to_send, remote_start_offset); + // TODO use which? + // st_volatile_global(dst_p2p_ptr, -raw_val-1); + *dst_p2p_ptr = -raw_val-1; + +// printf("[R%d,S%d,T%d] st-layout dst_ptr=%lld delta_addr=%d raw_val=%lld\n", +// rank, sm_id, subroutine_thread_id, dst_ptr, (int) (((uint64_t)dst_ptr) - ((uint64_t)layout_range_buffer)), raw_val); + } + + // 2. Write metadata to local + // TODO is this strong enough + remote_start_offset_buffer[dst_global_expert_idx] = -remote_start_offset-1; + } + } + + // There are 2 kinds of warps in this part: + // 1. 
The first-kind warps for FP8 cast and sending top-k tokens + // 2. The last warp for reading `topk_idx` and count for per-expert information + + // NOTE remove the last warp (and thus the if) + // if (warp_id < num_warps - 1) { + + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden"); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % Consts::kNumPerChannels == 0, "Invalid vectorization"); + + // NOTE no need "-1" b/c we do not reserve one warp for counting anymore + // const auto num_threads = (num_warps - 1) * 32; + // const auto num_threads = num_warps * 32; // not used + + // unused + // const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + // NOTE + // before: one SM = one token, one warp = one dst rank of that token, only use first 8 warps of the SM (?) + // after: flatten all (warp_id, sm_id), + // then one warp = one pseudo_token_idx (i.e. one dst rank of one token) + // + // NOTE: deliberately be (warp_id, sm_id) instead of (sm_id, warp_id) + // to allow work be distributed to all SMs when few work + // TODO is these ordering suboptimal for nvlink write or gmem read? + // TODO may use multi warp to send one token + const int flat_worker_id = warp_id * num_sms + sm_id; + const int flat_worker_num = num_warps * num_sms; + for ( + // "tesfl" := "token_idx_and_dst_expert_and_dst_slot_idx_flat_list" + int tesfl_idx = flat_worker_id, debug_iter_idx = 0; + tesfl_idx < num_tokens * num_topk; + tesfl_idx += flat_worker_num, debug_iter_idx += 1 + ) { +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_send tesfl_idx=%d START \n", rank, sm_id, subroutine_thread_id, tesfl_idx); } +#if ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 0, + /* event_id */ debug_iter_idx, + /* mode_id */ 0, + sm_id, warp_id + ); +#endif + + // TODO do prefetching if needed + // NOTE ldg is for read-only data cache, if token_idx_and_dst_expert_and_dst_slot_idx_flat_list is somehow overlapped in the future we should change it + const auto token_idx_and_dst_expert_and_slot_idx = __ldg(token_idx_and_dst_expert_and_dst_slot_idx_flat_list + tesfl_idx); + const auto ptr = (int16_t*) &token_idx_and_dst_expert_and_slot_idx; + const int token_idx = ptr[0], dst_expert_idx = ptr[1], slot_idx = ptr[2]; + // if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_send tesfl_idx=%d token_idx=%d dst_expert_idx=%d slot_idx=%d \n", + // rank, sm_id, subroutine_thread_id, tesfl_idx, token_idx, dst_expert_idx, slot_idx); } + // const auto dst_rank = dst_expert_idx / num_local_experts; + + // TODO can speedup by prefetching, delayed checking, etc + // TODO is this load strong enough? 
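+            // The reservation step above publishes offsets encoded as `-(offset) - 1`,
+            // so a zero-initialized slot unambiguously means "not written yet"; spin
+            // until the value turns non-zero, then undo the encoding below.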
+ int remote_start_offset; + while ((remote_start_offset = ld_volatile_global(remote_start_offset_buffer + dst_expert_idx)) == 0); + remote_start_offset = -remote_start_offset - 1; + +#if ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 1, + /* event_id */ debug_iter_idx, + /* mode_id */ 0, + sm_id, warp_id + ); +#endif + + // NOTE changed, see "before-after" above + // for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + + // const auto x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; + + // NOTE do not use `rdma_x` but use `x` + // const auto rdma_x_src_idx = reinterpret_cast(static_cast(rdma_x) + token_idx * Consts::num_bytes_per_msg); + const auto x_src_idx = reinterpret_cast(reinterpret_cast(x) + token_idx * Consts::num_bytes_per_msg); + + // const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + // const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + Consts::hidden_bytes); + + // Overlap top-k index read and source token index writes + // NOTE the parallel strategy is changed + // auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; + + // NOTE (0828) require users to set this value + // NOTE do not use `rdma_x` but use `x` + // NOTE use lane_id instead of local_thread id + // NOTE and the new code will write `x_src_idx` *MULTIPLE* times w/ same value, thus wasting but correct + // subroutine_thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; + // lane_id == 0 ? (*x_src_idx = token_idx) : 0; + + // NOTE no read or cast in fp4 + // FP8 cast +// EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce"); +// #pragma unroll +// for (int i = subroutine_thread_id; i < hidden_bf16_int4; i += num_threads) { +// // Read +// auto int4_value = __ldg(x_int4 + i); +// +// if constexpr (kUseFP8) { +// // Calculate local amax +// auto bf16_values = reinterpret_cast(&int4_value); +// float fp32_values[kNumElemsPerRead]; +// float amax = kFP8Margin, scale, scale_inv; +// #pragma unroll +// for (int j = 0; j < kNumElemsPerRead; ++ j) { +// fp32_values[j] = static_cast(bf16_values[j]); +// amax = fmaxf(amax, fabsf(fp32_values[j])); +// } +// +// // Reduce amax and scale +// EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); +// amax = warp_reduce_max<16>(amax); +// calculate_fp8_scales(amax, scale, scale_inv, round_scale); +// if (lane_id == 0 or lane_id == 16) +// rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; +// +// // Cast into send buffer +// vec_t int2_value; +// auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); +// #pragma unroll +// for (int j = 0; j < kNumElemsPerRead; j += 2) { +// float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; +// fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); +// } +// rdma_x_vec[i] = int2_value; +// } else { +// // Reinterpret-cast is for C++14 compatibility +// rdma_x_vec[i] = *reinterpret_cast(&int4_value); +// } +// } + + // NOTE this cannot be removed even if we do not do casting + // b/c we need to write to `rdma_x_src_idx` + // (but we may optimize it later) + // asm volatile("bar.sync 1, %0;" :: "r"(num_threads)); + + // Issue IBGDA sends +// if (dst_expert_idx >= 0) { + + // NOTE: let external give this + // int slot_idx = lane_id == 0 ? 
atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + // slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + // NOTE do not use `rdma_x` but use `x` + // const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto src_ptr = reinterpret_cast(x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * Consts::num_bytes_per_msg + + // NOTE modified rm + // rank * num_max_dispatch_tokens_per_rank * Consts::num_bytes_per_msg + + remote_start_offset * Consts::num_bytes_per_msg + + slot_idx * Consts::num_bytes_per_msg; + const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + +// if (dst_p2p_ptr == 0) { +// nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, Consts::num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); +// } else { + + // NOTES: only 2 load iterations for 7K hidden with 8 unrolls + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + + // NOTE do *not* send the first int4, which is handled via the signal + const int4* body_src_int4_ptr = src_int4_ptr + 1; + const int4* body_dst_int4_ptr = dst_int4_ptr + 1; + constexpr int body_num_int4_per_msg = Consts::num_int4_per_msg - 1; + + // UNROLLED_WARP_COPY(8, lane_id, Consts::num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + // UNROLLED_WARP_COPY(8, lane_id, body_num_int4_per_msg, body_dst_int4_ptr, body_src_int4_ptr, ld_nc_global, st_na_global); + + constexpr int num_threads_for_copy = 32; + constexpr int loop_num = ceil_div(body_num_int4_per_msg, num_threads_for_copy); + EP_STATIC_ASSERT(loop_num == 8, "unexpected loop_num"); + int4 body_buf[loop_num]; + #pragma unroll + for (int i = 0; i < loop_num; ++i) { + int offset = lane_id + i * num_threads_for_copy; + if (offset < body_num_int4_per_msg) { + body_buf[i] = ld_nc_global(body_src_int4_ptr + offset); + } + } + #pragma unroll + for (int i = 0; i < loop_num; ++i) { + int offset = lane_id + i * num_threads_for_copy; + if (offset < body_num_int4_per_msg) { + st_na_global(body_dst_int4_ptr + offset, body_buf[i]); + } + } + + // Send per-token signal + // NOTE only first 4B of 16B has value, the other 12B is not needed + __syncwarp(); + if (lane_id == 0) { +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] st-token-signal START dst_rank=%d addr=%p delta_addr=%d token_idx=%d\n", +// rank, sm_id, subroutine_thread_id, +// dst_rank, (int*)dst_ptr, (int)((int64_t)dst_ptr - (int64_t)rdma_recv_x), token_idx); } + + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -token_idx - 1); + } +// } + +#if ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 2, + /* event_id */ debug_iter_idx, + /* mode_id */ 0, + sm_id, warp_id + ); +#endif + + // not needed in per-token signal approach +// // Increase counter after finishing +// __syncwarp(); +// lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; +// } + + // NOTE: put this check this late to let dst_expert_idx be loaded + // for negative ones (if any), filter them out in previous kernels + EP_DEVICE_ASSERT(dst_expert_idx >= 0); + + // NOTE mv from do-once to do-per-local-expert + // TODO what does this do? do we break something, b/c we let multi SM cooperate? 
+ // (seems it is safe, b/c our next step will check gmem?) + // __syncthreads(); + + // not needed in per-token signal approach +// // NOTE mv from do-once to do-per-local-expert +// // +// // NOTE +// // before: one (sm_id, warp_group_id) = one responsible_expert_idx = send counter to that (dst rank, dst local expert) +// // thus use one thread per warp_group +// // after: reuse the (cooperate_idx, dst_rank) assignment and send counter to that (dsk_rank, const local_expert_idx) +// // thus use one thread per SM +// // +// // Issue count sends +// EP_DEBUG_DEVICE_ASSERT(num_sms >= num_ranks); +// // NOTE changed +// // if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { +// if ((cooperate_idx == 0) and (lane_id == 0)) { +// // NOTE changed +// // const auto dst_rank = responsible_expert_idx / num_local_experts; +// // const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; +// // const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; +// const auto dst_expert_local_idx = local_expert_idx; +// const auto responsible_expert_idx = dst_expert_idx; +// const int num_tokens_sent = num_tokens_of_dst_expert; +// +// // Wait local sends issued and send expert counts +// while ( +// ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != +// // NOTE changed +// // FINISHED_SUM_TAG * 2 +// FINISHED_SUM_TAG + num_tokens_sent +// ); +// auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank); +// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); +// if (dst_p2p_ptr == 0) { +// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); +// } else { +// st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -num_tokens_sent - 1); +// } +// +// // Clean workspace for next use +// atomic_counter_per_expert[responsible_expert_idx] = 0; +// atomic_finish_counter_per_expert[responsible_expert_idx] = 0; +// +// // NOTE packed_recv_count zeroing is removed +// // // Clean `packed_recv_count` +// // if (dst_rank == 0) +// // packed_recv_count[dst_expert_local_idx] = 0; +// } +// // TODO what does this do? 
+// __syncwarp(); + } + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_send END\n", rank, sm_id, subroutine_thread_id); } + +// } else if (warp_id == num_warps - 1) { +// EP_DEVICE_ASSERT(num_sms > 1); +// if (sm_id == 0) { +// // The first SM is also responsible for checking QPs +// EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts); +// +// // NOTE (the `next_clean` + notify part is moved) +// } +// +// // This SM should be responsible for some destination experts, read `topk_idx` for them +// int expert_count[kNumMaxWarpGroups] = {0}; +// const auto expert_begin_idx = sm_id * num_warp_groups; +// const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); +// +// // Per lane count +// #pragma unroll 8 +// for (int i = lane_id; i < num_tokens * num_topk; i += 32) { +// auto idx = static_cast(__ldg(topk_idx + i)); +// if (idx >= expert_begin_idx and idx < expert_end_idx) +// expert_count[idx - expert_begin_idx] ++; +// } +// +// // Warp reduce +// #pragma unroll +// for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { +// auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); +// if (lane_id == 0) { +// shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; +// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); +// } +// } +// } +} + +template +__forceinline__ __device__ void dispatch_recv( + int subroutine_thread_id, int num_warp_groups, + + // copied args + void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + // int* rdma_recv_count, // NOTE removed + // void* rdma_x, // NOTE removed + const void* x, const int64_t* topk_idx, + int* atomic_counter_per_expert, + // int* atomic_finish_counter_per_expert, // NOTE removed + int* next_clean, int num_next_clean_int, + int num_tokens, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + // int num_send_warp_groups, int num_recv_warp_groups, // NOTE removed + int num_warps_per_group, + bool round_scale, int phases, + uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int64_t* layout_range_buffer, int* negotiate_offset_of_expert_buffer, int* remote_start_offset_buffer, + int* debug_tensor +) { + uint32_t t_start = clock(); + + using Consts = DispatchConstsTemplate; + + // NOTE copied from dispatch body + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); // unused + const auto warp_id = subroutine_thread_id / 32, lane_id = get_lane_id(); + const auto num_warps = num_warp_groups * num_warps_per_group; // unused + const auto num_local_experts = num_experts / num_ranks; + // const auto warp_group_id = warp_id / num_warps_per_group; + // const auto sub_warp_id = warp_id % num_warps_per_group; + + // NOTE rm + // const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + + // May extract UE8M0 from the scales + using scale_t = std::conditional_t; + using packed_t = std::conditional_t; + EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); + EP_STATIC_ASSERT(!(kUseFP8 && kUseNVFP4), "FP8 and NVFP4 cannot be used together"); + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_recv START\n", rank, sm_id, 
subroutine_thread_id); } + +// NOTE packed_recv_count zeroing is removed +// // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible +// if (phases & LOW_LATENCY_SEND_PHASE) +// cg::this_grid().sync(); + + // TODO a lot of SM is wasted, optimize it later + // TODO at least make dispatch_recv have 16 instead of 8 warps + // + // NOTE + // before: one (sm_id, warp_group_id) = one responsible_expert_idx = handle all tokens for one (src_rank, local_expert_idx) + // after: reshape (warp_id, sm_id) into (cooperate_idx, src_rank) + // then all num_cooperate warps handle tokens from one src_rank + const int num_cooperate_parts = num_sms * num_warps / num_ranks; + EP_DEVICE_ASSERT(num_sms * num_warps == num_cooperate_parts * num_ranks); // even division + const int flatten_id = warp_id * num_sms + sm_id; + const int cooperate_idx = flatten_id / num_ranks; + const int src_rank = flatten_id % num_ranks; + + // Receiving and packing + // NOTE if -> for + // if (responsible_expert_idx < num_experts) { + EP_DEVICE_ASSERT(num_warp_groups == 1); // not consider multi warp_group case below + for (int local_expert_idx = 0; local_expert_idx < num_local_experts; ++local_expert_idx) { +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_recv local_expert_idx=%d START \n", rank, sm_id, subroutine_thread_id, local_expert_idx); } +#if ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 0, + /* event_id */ local_expert_idx, + /* mode_id */ 1, + sm_id, warp_id + ); +#endif + + // NOTE modified + // const auto src_rank = responsible_expert_idx / num_local_experts; + // const auto local_expert_idx = responsible_expert_idx % num_local_experts; + + // NOTE MODIFIED + const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * Consts::num_bytes_per_msg; + // this is removed + // + src_rank * num_max_dispatch_tokens_per_rank * Consts::num_bytes_per_msg; +// const auto recv_x_int4 = static_cast(packed_recv_x) + +// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * Consts::hidden_int4; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + const auto num_aligned_scales = align(Consts::num_scales, sizeof(float) / sizeof(scale_t)); + const auto recv_x_scales = static_cast(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; + + int num_recv_tokens, token_start_offset; + if (lane_id == 0) { +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] ld-layout START\n", rank, sm_id, subroutine_thread_id); } + +// auto loop_start_time = clock64(); + int64_t layout; + while((layout = ld_volatile_global(layout_range_buffer + local_expert_idx * num_ranks + src_rank)) == 0) { +// if ((clock64() - loop_start_time) >= 20000000000ULL) { +// printf("[R%d,S%d,T%d] ld-layout STUCK\n", rank, sm_id, subroutine_thread_id); +// loop_start_time = clock64(); // reset warning +// } + } + layout = -layout - 1; + unpack2(layout, num_recv_tokens, token_start_offset); + +#if ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 1, + /* event_id */ local_expert_idx, + /* mode_id */ 1, + sm_id, warp_id + ); +#endif + + if (cooperate_idx == 0) { + // TODO may not need to do this extra copy - directly use the `layout_range_buffer` + 
recv_range[src_rank] = layout; + // TODO may also not need to do this extra copy - directly use the `negotiate_offset_of_expert_buffer` + atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + } + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] ld-layout END num_recv_tokens=%d token_start_offset=%d\n", rank, sm_id, subroutine_thread_id, num_recv_tokens, token_start_offset); } + + if ((dst_signals != nullptr) and (cooperate_idx == 0)) { + atomic_add_release_global(dst_signals + local_expert_idx, ((src_rank == 0) ? DST_SIGNAL_EXPECT_VALUE: 0) - num_recv_tokens); + } + } + num_recv_tokens = __shfl_sync(0xffffffff, num_recv_tokens, 0); + token_start_offset = __shfl_sync(0xffffffff, token_start_offset, 0); + + // NOTE no longer have per-expert signals +// // Shared between sub-warps in warp groups +// __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; +// +// // Wait tokens to arrive +// // NOTES: using sub-warp 1 to overlap with sub-warp 0 +// int num_recv_tokens, recv_token_begin_idx; +// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); +// if (sub_warp_id == 1 and lane_id == 0) { +// // auto start_time = clock64(); // not used +// while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); +// // auto wait_recv_cost = clock64() - start_time; // not used +// num_recv_tokens = -num_recv_tokens - 1; +// recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); +// shared_num_recv_tokens[warp_group_id] = num_recv_tokens; +// shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; +// recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); +// +// // not handled +// // // Add stats for diagnosis +// // if (cumulative_local_expert_recv_stats != nullptr) +// // atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); +// // if (dispatch_wait_recv_cost_stats != nullptr) +// // atomicAdd(reinterpret_cast(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost); +// } +// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); +// num_recv_tokens = shared_num_recv_tokens[warp_group_id]; +// recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + // Copy tokens + // for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { + for ( + int i_raw = cooperate_idx, debug_inner_idx = 0; + i_raw < num_recv_tokens; + i_raw += num_cooperate_parts, debug_inner_idx++ + ) { + const int token_idx = i_raw + token_start_offset; + +// // Copy source info + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + token_idx * Consts::num_bytes_per_msg); +// if (lane_id == 0) +// recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + + // Read signal + Copy source info + if (lane_id == 0) { +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] ld-token-signal START addr=%p delta_addr=%d token_idx=%d\n", +// rank, sm_id, subroutine_thread_id, src_src_idx, (int)((int64_t)src_src_idx - (int64_t)rdma_recv_x), token_idx); } + +// auto loop_start_time = clock64(); + int recv_src_idx; + while ((recv_src_idx = ld_acquire_sys_global(src_src_idx)) == 0) { +// if ((clock64() - loop_start_time) >= 20000000000ULL) { +// printf("[R%d,S%d,T%d] ld-token-signal STUCK\n", rank, sm_id, subroutine_thread_id); +// loop_start_time = clock64(); // reset warning +// } + } + recv_src_idx = -recv_src_idx-1; + +#if 
ENABLE_DEBUG_TIMING_TENSOR + write_debug_time( + debug_tensor, t_start, + /* event_group_id */ 2, + /* event_id */ local_expert_idx * 10 + debug_inner_idx, + /* mode_id */ 1, + sm_id, warp_id + ); +#endif + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] ld-token-signal END recv_src_idx=%d\n", rank, sm_id, subroutine_thread_id, recv_src_idx); } + + // cleanup (will be used in the next round) + *src_src_idx = 0; + + recv_src_info[token_idx] = recv_src_idx; + } + __syncwarp(); + + // do not need to copy real data now +// // Copy data +// // NOTES: only 2 load iterations for 7K hidden with 7 unrolls + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); +// const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * Consts::hidden_int4; +// UNROLLED_WARP_COPY(7, lane_id, Consts::hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + // Copy scales + if constexpr (kUseFP8) { + // NOTE simply remove to simplify code + EP_DEVICE_ASSERT(false); +// EP_DEVICE_ASSERT(Consts::num_scales <= 64); +// // Equivalent CuTe layout: +// // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1)) +// const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + Consts::hidden_bytes); +// const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t)); +// const auto token_idx = recv_token_begin_idx + i; +// const auto token_stride = num_elems_per_pack; +// const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; +// if (lane_id < Consts::num_scales) { +// const auto pack_idx = lane_id / num_elems_per_pack; +// const auto elem_idx = lane_id % num_elems_per_pack; +// auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id)); +// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; +// } +// if (lane_id + 32 < Consts::num_scales) { +// const auto pack_idx = (lane_id + 32) / num_elems_per_pack; +// const auto elem_idx = (lane_id + 32) % num_elems_per_pack; +// auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id + 32)); +// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; +// } + } else if constexpr (kUseNVFP4) { + // The physical layout is (l, rm, rk, 32, 4, 4). 
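+ // NOTE worked example of the index math below (illustration only; assumes
+ // num_elems_per_pack == 4, and the static asserts below imply Consts::num_scales == 448
+ // since loop_num == 14 and 14 * 32 == num_scales):
+ //   token_idx = 130 -> rm = 130 / 128 = 1, rm_res = 130 % 128 = 2
+ //   scale j   = 37  -> pack_idx = 37 / 4 = 9, elem_idx = 37 % 4 = 1
+ //   dst index = rm * token_stride * 128 + pack_idx * pack_stride * 128
+ //               + rm_res * pack_stride + elem_idx
+ // i.e. tokens are grouped into 128-row blocks, scales into packs of
+ // num_elems_per_pack, and each (block, pack) tile is stored as
+ // (row-in-block, elem-in-pack).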
+ const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + Consts::hidden_bytes); + const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t)); + // const auto token_idx = recv_token_begin_idx + i; // NOTE changed + const auto token_stride = Consts::num_scales * sizeof(scale_t); + const auto pack_stride = num_elems_per_pack; + const auto rm = token_idx / 128; + const auto rm_res = token_idx % 128; + + // TODO use int4 read + constexpr int loop_num = ceil_div(Consts::num_scales, 32); + EP_STATIC_ASSERT(loop_num == 14, "unexpected loop_num"); + EP_STATIC_ASSERT(loop_num * 32 == Consts::num_scales, "expect even division"); + uint8_t buf[loop_num]; + #pragma unroll + for (int loop_idx = 0; loop_idx < loop_num; ++loop_idx) { + const int j = lane_id + loop_idx * 32; + buf[loop_idx] = ld_nc_global(src_scales + j); + } + #pragma unroll + for (int loop_idx = 0; loop_idx < loop_num; ++loop_idx) { + const int j = lane_id + loop_idx * 32; + const auto pack_idx = j / num_elems_per_pack; + const auto elem_idx = j % num_elems_per_pack; + recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = buf[loop_idx]; + } + } + + if (dst_signals != nullptr) { + __syncwarp(); + if (lane_id == 0) { + atomic_add_release_global(dst_signals + local_expert_idx, 1); + } + } + } + } + +// if (subroutine_thread_id % 32 == 0) { printf("[R%d,S%d,T%d] dispatch_recv END\n", rank, sm_id, subroutine_thread_id); } +} + +template +__global__ +// __launch_bounds__(1024, 1) +__maxnreg__(48) // TODO +void +dispatch_v2(void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_general_signal, // NOTE renamed from `rdma_recv_count` + // void* rdma_x, // NOTE removed + void* x, const int64_t* topk_idx, // NOTE rm `const` of x + int* atomic_counter_per_expert, + // int* atomic_finish_counter_per_expert, // NOTE removed + int* next_clean, int num_next_clean_int, + int num_tokens, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + // NOTE split num_warp_groups + int num_send_warp_groups, int num_recv_warp_groups, + int num_send_warps_per_group, int num_recv_warps_per_group, + bool round_scale, int phases, + uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int* remote_start_offset_buffer, + int* debug_tensor + ) { + const auto sm_id = static_cast(blockIdx.x); + const auto num_send_threads = num_send_warp_groups * num_send_warps_per_group * 32; + const auto raw_thread_id = static_cast(threadIdx.x); + + // NOTE Please keep in sync: Config.get_nvl_buffer_size_hint, LowLatencyLayout.constructor, internode_ll_v2 + // + // (num_local_experts, num_ranks). written by REMOTE gpus, read by curr gpu. + // arr[local_expert_idx, src_rank] := the (num_tokens, start_offset) layout information of that src_rank + // similar to `packed_recv_layout_range`, but written remotely + int64_t* layout_range_buffer = (int64_t*) rdma_general_signal; + // (num_local_experts,). use by REMOTE gpus. 
all gpus atomic-add on it to get a slice of locations to send data to + int* negotiate_offset_of_expert_buffer = (int*) (((uint8_t*)rdma_general_signal) + num_experts * sizeof(int64_t)); + + if ((sm_id == 0) and (raw_thread_id == 0)) { + // assert alignment + EP_DEVICE_ASSERT(((int64_t)layout_range_buffer) % 16 == 0); + EP_DEVICE_ASSERT(((int64_t)negotiate_offset_of_expert_buffer) % 16 == 0); + } + + if (raw_thread_id < num_send_threads) { + if (phases & LOW_LATENCY_SEND_PHASE) { + const auto send_thread_id = raw_thread_id; + dispatch_send( + send_thread_id, num_send_warp_groups, + + // forward args + packed_recv_x, packed_recv_x_scales, + packed_recv_src_info, packed_recv_layout_range, + packed_recv_count, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, + rdma_recv_x, + x, topk_idx, + atomic_counter_per_expert, + next_clean, num_next_clean_int, + num_tokens, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + num_send_warps_per_group, + round_scale, phases, + dst_signals, + count_per_expert, token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + layout_range_buffer, negotiate_offset_of_expert_buffer, remote_start_offset_buffer, + debug_tensor + ); + } + } else { + if (phases & LOW_LATENCY_RECV_PHASE) { + const auto recv_thread_id = raw_thread_id - num_send_threads; + dispatch_recv( + recv_thread_id, num_recv_warp_groups, + + // forward args + packed_recv_x, packed_recv_x_scales, + packed_recv_src_info, packed_recv_layout_range, + packed_recv_count, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, + rdma_recv_x, + x, topk_idx, + atomic_counter_per_expert, + next_clean, num_next_clean_int, + num_tokens, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + num_recv_warps_per_group, + round_scale, phases, + dst_signals, + count_per_expert, token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + layout_range_buffer, negotiate_offset_of_expert_buffer, remote_start_offset_buffer, + debug_tensor + ); + } + } + +// NOTE removed +// // Sending phase +// if ((phases & LOW_LATENCY_SEND_PHASE) == 0) +// goto LOW_LATENCY_DISPATCH_RECV; +// +// // Receiving phase +// LOW_LATENCY_DISPATCH_RECV: +// if ((phases & LOW_LATENCY_RECV_PHASE) == 0) +// return; +} + +void dispatch_v2(void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, int* rdma_recv_count, + // void* rdma_x, // NOTE removed + void* x, const int64_t* topk_idx, // NOTE rm `const` of x + int* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + bool use_fp8, bool round_scale, bool use_ue8m0, + void* workspace, int num_device_sms, + cudaStream_t stream, int phases, + bool use_nvfp4, uint32_t* dst_signals, + uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list, + int* remote_start_offset_buffer, int* zeroed_buffer_for_atomic_counter_per_expert, + int* debug_tensor) { + + EP_HOST_ASSERT(false); // should re-create deep_ep.cpp cudaMallocAndZero before using this + + constexpr int kNumMaxTopK = 9; + + // NOTE simple renaming + int* rdma_general_signal = rdma_recv_count; + + // NOTE MODIFIED + // const int num_warp_groups = ceil_div(num_experts, num_device_sms); + const int num_warp_groups = 2; + + // NOTE temporarily reduce 
num warps per group to avoid workload imbalance in dispatch_send + // TODO may increase it later e.g. for dispatch_recv + int num_send_warps_per_group = 32 / num_warp_groups; + int num_recv_warps_per_group = num_send_warps_per_group; + EP_HOST_ASSERT(num_warp_groups > 0 and num_send_warps_per_group > 0 and num_recv_warps_per_group > 0); + + // NOTE temp hack + if (phases == LOW_LATENCY_SEND_PHASE) { +// printf("HACK: give all warps to send!\n"); + num_send_warps_per_group = 32; + num_recv_warps_per_group = 0; + } else if (phases == LOW_LATENCY_RECV_PHASE) { +// printf("HACK: give all warps to recv!\n"); + num_send_warps_per_group = 0; + num_recv_warps_per_group = 32; + } else { + // do nothing + } + + // NOTE no longer need one SM to send all topk destinations + // EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); + + // Workspace checks + // auto atomic_counter_per_expert = static_cast(workspace); // NOTE let users pass a zeroed buffer + // auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; // NOTE removed + // EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + + // TODO inefficient, may change it + // NOTE add + EP_HOST_ASSERT(num_warp_groups >= 2); + const int num_send_warp_groups = num_warp_groups - 1; + const int num_recv_warp_groups = 1; + + const auto num_warps = num_send_warp_groups * num_send_warps_per_group + num_recv_warp_groups * num_recv_warps_per_group; + const auto num_sms = ceil_div(num_experts, num_warp_groups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + + // FP8 checks + if (use_ue8m0) + EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`"); + + EP_HOST_ASSERT(use_nvfp4 and (not use_fp8) and (not use_ue8m0)); +// auto dispatch_func = dispatch_v2