diff --git a/Cargo.lock b/Cargo.lock index c9f6377..4c0ce74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "anyhow" version = "1.0.98" @@ -67,6 +76,21 @@ dependencies = [ "rayon", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -171,6 +195,19 @@ dependencies = [ "rayon", ] +[[package]] +name = "candle-moe" +version = "0.0.1" +dependencies = [ + "anyhow", + "bindgen_cuda", + "candle-core", + "candle-nn", + "candle-transformers", + "cudarc", + "half", +] + [[package]] name = "candle-nn" version = "0.8.4" @@ -197,6 +234,25 @@ dependencies = [ "half", ] +[[package]] +name = "candle-transformers" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020" +dependencies = [ + "byteorder", + "candle-core", + "candle-nn", + "fancy-regex", + "num-traits", + "rand", + "rayon", + "serde", + "serde_json", + "serde_plain", + "tracing", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -318,6 +374,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + [[package]] name = "gemm" version = "0.17.1" @@ -959,6 +1026,23 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1028,6 +1112,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 62d015f..6b1287f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "candle-rotary", 
"candle-flash-attn-v1", "candle-cublaslt", + "candle-moe", ] resolver = "2" diff --git a/candle-moe/.gitignore b/candle-moe/.gitignore new file mode 100644 index 0000000..5a8bfa6 --- /dev/null +++ b/candle-moe/.gitignore @@ -0,0 +1,17 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +# RustRover +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ diff --git a/candle-moe/Cargo.toml b/candle-moe/Cargo.toml new file mode 100644 index 0000000..78a48e1 --- /dev/null +++ b/candle-moe/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "candle-moe" +description = "fused MoE layer for the candle ML framework." +homepage = "https://github.com/huggingface/candle-extensions/candle-moe/" +version.workspace = true +edition.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +candle = { version = "0.8", package = "candle-core", features = ["cuda"] } +cudarc = { version = "0.13.3", features = ["cuda-12080"], default-features = false } +half = { workspace = true } + +[build-dependencies] +anyhow = { workspace = true } +bindgen_cuda = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +candle-nn = { version = "0.8", features = ["cuda"] } +candle-transformers = { version = "0.8" } diff --git a/candle-moe/README.md b/candle-moe/README.md new file mode 100644 index 0000000..725257d --- /dev/null +++ b/candle-moe/README.md @@ -0,0 +1,5 @@ +# candle-moe + +fused MoE kernel in Candle backend + +This layer is adapted from https://huggingface.co/kernels-community/moe. diff --git a/candle-moe/build.rs b/candle-moe/build.rs new file mode 100644 index 0000000..f6005df --- /dev/null +++ b/candle-moe/build.rs @@ -0,0 +1,65 @@ +// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel. +// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment +// variable in order to cache the compiled artifacts and avoid recompiling too often. 
+use anyhow::{Context, Result};
+use std::path::PathBuf;
+
+const KERNEL_FILES: [&str; 2] = [
+    "kernels/topk_softmax_kernels.cu",
+    "kernels/moe_align_sum_kernels.cu",
+];
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+    for kernel_file in KERNEL_FILES.iter() {
+        println!("cargo:rerun-if-changed={kernel_file}");
+    }
+    println!("cargo:rerun-if-changed=kernels/moe_wna16_utils.h");
+
+    let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
+    let build_dir = match std::env::var("CANDLE_MOE_BUILD_DIR") {
+        Err(_) =>
+        {
+            #[allow(clippy::redundant_clone)]
+            out_dir.clone()
+        }
+        Ok(build_dir) => {
+            let path = PathBuf::from(build_dir);
+            let current_dir = std::env::current_dir()?;
+            path.canonicalize().unwrap_or_else(|_| {
+                panic!(
+                    "Directory doesn't exist: {} (the current directory is {})",
+                    &path.display(),
+                    current_dir.display()
+                )
+            })
+        }
+    };
+
+    let kernels: Vec<_> = KERNEL_FILES.iter().collect();
+    let builder = bindgen_cuda::Builder::default()
+        .kernel_paths(kernels)
+        .out_dir(build_dir.clone())
+        .arg("-std=c++17")
+        .arg("-O3")
+        .arg("--compiler-options")
+        .arg("-fPIC")
+        .arg("-U__CUDA_NO_HALF_OPERATORS__")
+        .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+        .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+        .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+        .arg("--expt-relaxed-constexpr")
+        .arg("--expt-extended-lambda")
+        .arg("--use_fast_math")
+        .arg("--ptxas-options=-v")
+        .arg("--verbose");
+
+    let out_file = build_dir.join("libmoe.a");
+    builder.build_lib(out_file);
+
+    println!("cargo:rustc-link-search={}", build_dir.display());
+    println!("cargo:rustc-link-lib=moe");
+    println!("cargo:rustc-link-lib=dylib=cudart");
+
+    Ok(())
+}
diff --git a/candle-moe/kernels/cuda_compat.h b/candle-moe/kernels/cuda_compat.h
new file mode 100644
index 0000000..82e5561
--- /dev/null
+++ b/candle-moe/kernels/cuda_compat.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#ifdef USE_ROCM
+  #include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define WARP_SIZE warpSize
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
+    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
+#else
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+    __shfl_xor(var, lane_mask, width)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
+    __shfl_down_sync(uint32_t(-1), var, lane_delta)
+#else
+  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif
diff --git a/candle-moe/kernels/moe_align_sum_kernels.cu b/candle-moe/kernels/moe_align_sum_kernels.cu
new file mode 100644
index 0000000..32968c6
--- /dev/null
+++ b/candle-moe/kernels/moe_align_sum_kernels.cu
@@ -0,0 +1,222 @@
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "cuda_compat.h" + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace vllm { +namespace moe { + +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} // namespace + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) + token_cnts_t* tokens_cnts = + (token_cnts_t*)(shared_mem + num_experts + + 1); // 2d tensor with shape (blockDim.x + 1, num_experts) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. 
+ */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + +template +__global__ void moe_sum_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., topk, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t x = 0.0; +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]); + } + out[token_idx * d + idx] = x; + } +} + +} // namespace moe +} // namespace vllm + +#define CALL_MOE_ALIGN_BLOCK_SIZE_KERNEL(T) \ + vllm::moe::moe_align_block_size_kernel<<<1, num_thread, shared_mem_i32, stream>>>( \ + reinterpret_cast(topk_ids), \ + reinterpret_cast(sorted_token_ids), \ + reinterpret_cast(experts_ids), \ + reinterpret_cast(num_tokens_post_pad), \ + num_experts, \ + block_size, \ + numel \ + ); + +extern "C" void moe_align_block_size( + void *topk_ids, + int64_t num_experts, + int64_t block_size, + int64_t numel, + void *sorted_token_ids, + void *experts_ids, + void *num_tokens_post_pad, + uint32_t dtype +) { + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_i32 = + ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + const cudaStream_t stream = 0; + + if (dtype == 0) { + CALL_MOE_ALIGN_BLOCK_SIZE_KERNEL(half); + } else if (dtype == 1) { + CALL_MOE_ALIGN_BLOCK_SIZE_KERNEL(__nv_bfloat16); + } else { + CALL_MOE_ALIGN_BLOCK_SIZE_KERNEL(float); + } +} + +#define CALL_MOE_SUM_KERNEL(T, TOPK) \ + vllm::moe::moe_sum_kernel<<>>( \ + reinterpret_cast(output), \ + reinterpret_cast(input), \ + hidden_size \ + ); + +extern "C" void moe_sum( + void *input, // [num_tokens, topk, hidden_size] + void *output, // [num_tokens, hidden_size] + + uint64_t hidden_size, + uint64_t num_tokens, + uint64_t topk, + uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 +) { + dim3 grid(num_tokens); + dim3 block(MIN(hidden_size, 1024)); + const cudaStream_t stream = 0; + + switch (topk) { + case 2: + if (dtype == 0) { + CALL_MOE_SUM_KERNEL(half, 2); + } else if (dtype == 1) { + CALL_MOE_SUM_KERNEL(__nv_bfloat16, 2); + } else { + CALL_MOE_SUM_KERNEL(float, 2); + } + break; + case 3: + if (dtype == 0) { + CALL_MOE_SUM_KERNEL(half, 3); + } else if (dtype == 1) { + CALL_MOE_SUM_KERNEL(__nv_bfloat16, 3); + } else { + CALL_MOE_SUM_KERNEL(float, 3); + } + break; + case 4: + if (dtype == 0) { + CALL_MOE_SUM_KERNEL(half, 4); + } else if (dtype == 1) { + CALL_MOE_SUM_KERNEL(__nv_bfloat16, 4); + } else { + CALL_MOE_SUM_KERNEL(float, 4); + } + break; + default: + if (dtype == 0) { + CALL_MOE_SUM_KERNEL(half, 1); + } else if (dtype == 1) { + CALL_MOE_SUM_KERNEL(__nv_bfloat16, 1); + } else { + CALL_MOE_SUM_KERNEL(float, 1); + } + break; + } +} diff --git a/candle-moe/kernels/moe_wna16.cu b/candle-moe/kernels/moe_wna16.cu new file mode 100644 index 0000000..3dd1ab1 --- /dev/null +++ b/candle-moe/kernels/moe_wna16.cu @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include "moe_wna16_utils.h" +#include "cuda_compat.h" + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +template +__global__ void moe_wna16_gemm_kernel( + const scalar_t* __restrict__ input, scalar_t* __restrict__ output, + + const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales, + const uint32_t* __restrict__ qzeros, + + const float* __restrict__ 
topk_weights, + const int32_t* __restrict__ sorted_token_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ num_tokens_post_pad, + + uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m, + uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M, + uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp, + bool mul_topk_weight) { +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr (std::is_same::value) { + return; + } else { +#endif + + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + + if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return; + + const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x; + const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K; + + const int32_t expert_id = expert_ids[blockIdx.x]; + + int32_t num_valid_tokens = 0; + extern __shared__ uint16_t block_input_tmp[]; + scalar_t* block_input = reinterpret_cast(block_input_tmp); + scalar_t2* block_input_half2 = reinterpret_cast(block_input); + + // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory + for (int m = 0; m < BLOCK_SIZE_M; m++) { + const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m; + const int32_t token_index = sorted_token_ids[offset_m]; + if (token_index / top_k >= size_m) break; + + num_valid_tokens = m + 1; + if (blockIdx.z == 0 && offset_n < size_n) + output[token_index * size_n + offset_n] = Dtype::int2num(0); + + if (expert_id != -1) { + int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); + for (int i = 0; i < k_per_thread; i++) { + int k = BLOCK_SIZE_N * i + threadIdx.x; + if (k >= BLOCK_SIZE_K) break; + if (offset_k + k >= size_k) break; + + // load input to shared memory + // use a special layout to fit the layout of dequanted-weight + int origin_k; + if constexpr (bit == 4) { + // [0, 4, 1, 5, 2, 6, 3, 7] + int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order; + } else { + // [0, 2, 1, 3] + int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order; + } + + origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K; + block_input[m * BLOCK_SIZE_K + k] = input[origin_k]; + } + } + } + + if (expert_id == -1) return; + __syncthreads(); + if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return; + + float res[64]; // assume BLOCK_SIZE_M <= 64 + scalar_t2 res2; + scalar_t2 scale_f2; + scalar_t2 qzero_f2; + + // note that (size_n * size_k * expert_id) may greater than 2 ** 31 + constexpr int8_t pack_factor = 32 / bit; + const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id; + const uint32_t* expert_qweight = qweight + expert_offset / pack_factor; + const scalar_t* expert_scales = scales + expert_offset / group_size; + const uint32_t* expert_qzeros = + qzeros + expert_offset / group_size / pack_factor; + + // load 4*int32 one time: 4 int32 = 128 bit = 1 float4 + // weight would be loaded in loop + uint32_t expert_qweight_tmp[4]; + float4* expert_qweight_tmp_float4 = + reinterpret_cast(expert_qweight_tmp); + + // load all required scales one time + scalar_t expert_scales_groups[GROUPS]; + int scales_offset_tmp = + (offset_n * size_k + offset_k) / group_size / GROUPS; + if constexpr (GROUPS == 1) { + *expert_scales_groups = expert_scales[scales_offset_tmp]; + } else if constexpr (GROUPS == 2) { + float* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + 
reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 4) { + float2* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } else if constexpr (GROUPS == 8) { + float4* expert_scales_groups_tmp = + reinterpret_cast(expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(expert_scales)[scales_offset_tmp]; + } + + // load all required qzeros one time + uint8_t expert_qzeros_groups[GROUPS]; + if (!has_zp) { + if constexpr (bit == 4) { + qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + } else { + qzero_f2 = Dtype::num2num2(Dtype::int2num(128)); + } + } else { + int qzeros_offset_tmp = + (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) + + offset_k / group_size / GROUPS; + if constexpr (GROUPS == 1) { + uint8_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 2) { + uint16_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 4) { + uint32_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 8) { + uint64_t* expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(expert_qzeros)[qzeros_offset_tmp]; + } + } + + for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) { + int k = offset_k + tmp_k * pack_factor; + if (k >= size_k) break; + const int32_t weight_offset = offset_n * size_k + k; + + if (tmp_k % 4 == 0) { + *expert_qweight_tmp_float4 = reinterpret_cast( + expert_qweight)[weight_offset / pack_factor / 4]; + } + + if (tmp_k % (group_size / pack_factor) == 0) { + scalar_t scale_f = + expert_scales_groups[tmp_k / (group_size / pack_factor)]; + scale_f2 = Dtype::num2num2(scale_f); + + if (has_zp) { + uint8_t qzero = + expert_qzeros_groups[tmp_k / (group_size / pack_factor)]; + if constexpr (bit == 4) { + qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF; + } + qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero)); + } + } + + scalar_t2 weight_half2[16 / bit]; + dequant(expert_qweight_tmp[tmp_k % 4], weight_half2); + + for (int m = 0; m < num_valid_tokens; m++) { + res2 = {}; + +#pragma unroll + for (int i = 0; i < 16 / bit; i++) { + int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i; + res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2), + block_input_half2[offset_input], res2); + } + + if (tmp_k == 0) { + res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } else { + res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } + } + } + + for (int m = 0; m < num_valid_tokens; ++m) { + const int32_t token_index = + sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m]; + if (mul_topk_weight) { + res[m] *= topk_weights[token_index]; + } + atomicAdd(&output[token_index * size_n + offset_n], + Dtype::float2num(res[m])); + } + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } +#endif +} + +template +void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output, + const uint32_t* b_qweight, const scalar_t* b_scales, + const uint32_t* b_qzeros, const float* topk_weights, + const int32_t* sorted_token_ids, + const int32_t* 
expert_ids, + const int32_t* num_tokens_post_pad, int num_experts, + int group_size, int num_token_blocks, int top_k, + int size_m, int size_n, int size_k, int BLOCK_SIZE_M, + int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit, + bool has_zp, bool mul_topk_weight) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_SIZE_N; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = num_token_blocks; + gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N); + gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K); + + auto kernel = moe_wna16_gemm_kernel; + if (bit == 4) { + if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } else { + if (BLOCK_SIZE_K / group_size == 1) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 2) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 4) { + kernel = moe_wna16_gemm_kernel; + } else if (BLOCK_SIZE_K / group_size == 8) { + kernel = moe_wna16_gemm_kernel; + } + } + + const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2; + const cudaStream_t stream = 0; + + kernel<<>>( + input, output, b_qweight, b_scales, b_qzeros, topk_weights, + sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, + group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, has_zp, mul_topk_weight); +} + +#define CALL_MOE_WNA16_GEMM(T) \ + run_moe_wna16_gemm( \ + reinterpret_cast(input), \ + reinterpret_cast(output), \ + reinterpret_cast(b_qweight), \ + reinterpret_cast(b_scales), \ + reinterpret_cast(b_qzeros), \ + reinterpret_cast(topk_weights), \ + reinterpret_cast(sorted_token_ids), \ + reinterpret_cast(expert_ids), \ + reinterpret_cast(num_tokens_post_pad), \ + num_experts, \ + group_size, \ + num_token_blocks, \ + top_k, \ + size_m, \ + size_n, \ + size_k, \ + BLOCK_SIZE_M, \ + BLOCK_SIZE_N, \ + BLOCK_SIZE_K, \ + bit, \ + has_zp, \ + mul_topk_weight); + +extern "C" void moe_wna16_gemm( + void *input, + void *output, + void *b_qweight, + void *b_scales, + void *b_qzeros, + void *topk_weights, + void *sorted_token_ids, + void *expert_ids, + void *num_tokens_post_pad, + int64_t top_k, + int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, + int64_t bit, + int32_t num_experts, + int32_t size_m, + int32_t size_n, + int32_t size_k, + int32_t group_size, + int64_t EM, + bool has_zp, + bool mul_topk_weight, + uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 +) { + if (size_m <= BLOCK_SIZE_M) { + EM = min(EM, size_m * BLOCK_SIZE_M * top_k); + } + const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M; + + if (dtype == 0) { + CALL_MOE_WNA16_GEMM(half); + } else if (dtype == 1) { + CALL_MOE_WNA16_GEMM(__nv_bfloat16); + } +} diff --git a/candle-moe/kernels/moe_wna16_utils.h b/candle-moe/kernels/moe_wna16_utils.h new file mode 100644 index 0000000..4396b80 --- /dev/null +++ b/candle-moe/kernels/moe_wna16_utils.h @@ -0,0 +1,200 @@ + +#include +#include + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const 
float x) { + return __float2half(x); + } + + static __host__ __device__ half inline int2num(const float x) { + return __int2half_rn(x); + } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } + + static __host__ __device__ half2 inline float22num2(const float2 x) { + return __float22half2_rn(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } + + static __host__ __device__ nv_bfloat16 inline int2num(const float x) { + return __int2bfloat16_rn(x); + } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } + + static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) { + return __float22bfloat162_rn(x); + } +#endif +}; + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* res) {} + +template <> +__device__ inline void dequant(int q, half2* res) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + q >>= 8; + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + + res[0] = __hsub2(*reinterpret_cast(&lo0), + *reinterpret_cast(&SUB)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hsub2(*reinterpret_cast(&lo1), + *reinterpret_cast(&SUB)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* res) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + res[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + res[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int lo1 = lop3 < (0xf0 & 0xcc) 
| 0xaa > (q, MASK, EX); + q >>= 4; + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + res[0] = __hfma2(*reinterpret_cast(&lo0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[1] = __hfma2(*reinterpret_cast(&hi0), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[2] = __hfma2(*reinterpret_cast(&lo1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + res[3] = __hfma2(*reinterpret_cast(&hi1), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* res) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(res); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} +#endif diff --git a/candle-moe/kernels/topk_softmax_kernels.cu b/candle-moe/kernels/topk_softmax_kernels.cu new file mode 100644 index 0000000..e852a78 --- /dev/null +++ b/candle-moe/kernels/topk_softmax_kernels.cu @@ -0,0 +1,496 @@ +/* + * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#include +#include + +#include "cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { +namespace moe { + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N +> +class alignas(Alignment) AlignedArray { + float data[N]; +}; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. 
+template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. 
This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + int* source_rows, const int k, const int start_expert, const int end_expert) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. 
+ const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. 
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. 
+template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indicies, \ + token_expert_indices, num_tokens, topk, 0, num_experts, \ + stream); + +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, + float* topk_weights, + int* topk_indicies, + int* token_expert_indices, + const int num_tokens, + const int num_experts, + const int topk, + cudaStream_t stream +) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + } + } +} + +} // namespace moe +} // namespace vllm + +extern "C" void topk_softmax( + void *gating_output, // [num_tokens, num_experts] + void *topk_weights, // [num_tokens, topk] + void *topk_indices, // [num_tokens, topk] + void *token_expert_indices, // [num_tokens, topk] + + int32_t num_experts, + int64_t num_tokens, + int32_t topk +) { + const cudaStream_t stream = 0; + + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output), + reinterpret_cast(topk_weights), + reinterpret_cast(topk_indices), + reinterpret_cast(token_expert_indices), + num_tokens, + num_experts, + topk, + stream + ); +} diff --git a/candle-moe/src/ffi.rs b/candle-moe/src/ffi.rs new file mode 100644 index 0000000..2767393 --- /dev/null +++ b/candle-moe/src/ffi.rs @@ -0,0 +1,65 @@ +use core::ffi::{c_int, c_long, c_void}; + +unsafe extern "C" { + pub(crate) fn topk_softmax( + gating_output: *const c_void, + topk_weight: *const c_void, + topk_indices: *const c_void, + token_expert_indices: *const c_void, + + num_experts: c_int, + num_tokens: c_long, + topk: c_int, + ); + + pub(crate) fn moe_sum( + 
input: *const c_void, + output: *const c_void, + hidden_size: c_int, + num_token: c_long, + topk: c_int, + dtype: u32, + ); + + #[allow(dead_code)] + pub(crate) fn moe_align_block_size( + topk_ids: *const c_void, + num_experts: c_long, + block_size: c_long, + numel: c_long, + sorted_token_ids: *const c_void, + experts_ids: *const c_void, + num_tokens_post_pad: *const c_void, + dtype: u32, + ); + + #[allow(dead_code)] + pub(crate) fn moe_wna16_gemm( + input: *const c_void, + output: *const c_void, + b_qweight: *const c_void, + b_scales: *const c_void, + b_qzeros: *const c_void, + topk_weights: *const c_void, + sorted_token_ids: *const c_void, + expert_ids: *const c_void, + num_tokens_post_pad: *const c_void, + + top_k: c_long, + BLOCK_SIZE_M: c_long, + BLOCK_SIZE_N: c_long, + BLOCK_SIZE_K: c_long, + bit: c_long, + + num_experts: c_int, + size_m: c_int, + size_n: c_int, + size_k: c_int, + group_size: c_int, + EM: c_long, + has_zp: bool, + mul_topk_weight: bool, + + dtype: u32, + ); +} diff --git a/candle-moe/src/lib.rs b/candle-moe/src/lib.rs new file mode 100644 index 0000000..38b2144 --- /dev/null +++ b/candle-moe/src/lib.rs @@ -0,0 +1,223 @@ +mod ffi; + +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::{DType, Result, Storage, Tensor}; +use half::{bf16, f16}; + +pub fn apply_topk_softmax_< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, +>( + gating_output: &Tensor, + topk_weight: &Tensor, + topk_indices: &Tensor, + token_expert_indices: &Tensor, +) -> Result<()> { + let (g, g_l) = gating_output.storage_and_layout(); + let g: &candle::CudaStorage = match &*g { + Storage::Cuda(g) => g, + _ => candle::bail!("gating_output must be a cuda tensor"), + }; + + let (w, w_l) = topk_weight.storage_and_layout(); + let w = match &*w { + Storage::Cuda(w) => w, + _ => candle::bail!("topk_weight must be a cuda tensor"), + }; + + let (i, i_l) = topk_indices.storage_and_layout(); + let i = match &*i { + Storage::Cuda(i) => i, + _ => candle::bail!("topk_indices must be a cuda tensor"), + }; + + let (ei, ei_l) = token_expert_indices.storage_and_layout(); + let ei: &candle::CudaStorage = match &*ei { + Storage::Cuda(ei) => ei, + _ => candle::bail!("token_expert_indices must be a cuda tensor"), + }; + + let g_rank = g_l.stride().len(); + let w_rank = w_l.stride().len(); + let i_rank = i_l.stride().len(); + let ei_rank = ei_l.stride().len(); + + if g_rank != 2 || w_rank != 2 || i_rank != 2 || ei_rank != 2 { + candle::bail!( + "apply_topk_softmax_inplace expects input tensors of rank 2 (w: {w_l:?}, i: {i_l:?}, ei: {ei_l:?}, g: {g_l:?})" + ) + } + + // Get cuda slices for all tensors + let g = g.as_cuda_slice::()?; + let w = w.as_cuda_slice::()?; + let i = i.as_cuda_slice::()?; + let ei = ei.as_cuda_slice::()?; + + // Get cuda views for all tensors + let g = g.slice(g_l.start_offset()..); + let w = w.slice(w_l.start_offset()..); + let i = i.slice(i_l.start_offset()..); + let ei = ei.slice(ei_l.start_offset()..); + + let (num_tokens, top_k) = w_l.shape().dims2()?; + let (_, num_experts) = g_l.shape().dims2()?; + + let is_pow2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if !is_pow2 || num_experts > 256 { + candle::bail!( + "num_experts should be power of 2 and smaller than 256 (num_experts: {num_experts:?})" + ) + } + + if (num_tokens, top_k) != i_l.shape().dims2()? { + candle::bail!( + "shape mismatch topk_indices {:?}, expected {:?}", + i_l.shape(), + (num_tokens, top_k) + ) + } + + if (num_tokens, top_k) != ei_l.shape().dims2()? 
{ + candle::bail!( + "shape mismatch token_expert_indices {:?}, expected {:?}", + ei_l.shape(), + (num_tokens, top_k) + ) + } + + let gate_ptr = *g.device_ptr() as *const core::ffi::c_void; + let weight_ptr = *w.device_ptr() as *const core::ffi::c_void; + let indices_ptr = *i.device_ptr() as *const core::ffi::c_void; + let expert_indices_ptr = *ei.device_ptr() as *const core::ffi::c_void; + + unsafe { + ffi::topk_softmax( + gate_ptr, + weight_ptr, + indices_ptr, + expert_indices_ptr, + num_experts as i32, + num_tokens as i32, + top_k as i32, + ) + } + + Ok(()) +} + +pub fn apply_topk_softmax_inplace( + gating_output: &Tensor, + topk_weight: &Tensor, + topk_indices: &Tensor, + token_expert_indices: &Tensor, +) -> Result<()> { + match topk_weight.dtype() { + DType::F16 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + DType::BF16 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + DType::F32 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + dt => { + candle::bail!( + "apply_topk_softmax_inplace is only supported for f32, f16 and bf16 ({dt:?})" + ) + } + } +} + +pub fn apply_moe_sum_< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, +>( + input: &Tensor, + output: &Tensor, + num_token: usize, + topk: usize, + dtype: u32, +) -> Result<()> { + let (i, i_l) = input.storage_and_layout(); + let i: &candle::CudaStorage = match &*i { + Storage::Cuda(i) => i, + _ => candle::bail!("input must be a cuda tensor"), + }; + + let (o, o_l) = output.storage_and_layout(); + let o: &candle::CudaStorage = match &*o { + Storage::Cuda(o) => o, + _ => candle::bail!("output must be a cuda tensor"), + }; + + let i_rank = i_l.stride().len(); + let o_rank = o_l.stride().len(); + + if i_rank != 3 { + candle::bail!("input should be rank 3 (input: {i_l:?})") + } + + if o_rank != 2 { + candle::bail!("output should be rank 2 (input: {o_l:?})") + } + + // Get cuda slices for all tensors + let i = i.as_cuda_slice::()?; + let o = o.as_cuda_slice::()?; + + // Get cuda views for all tensors + let i = i.slice(i_l.start_offset()..); + let o = o.slice(o_l.start_offset()..); + + let (num_tokens, _, hidden_size) = i_l.shape().dims3()?; + + if (num_tokens, hidden_size) != o_l.shape().dims2()? 
{ + candle::bail!( + "shape mismatch output {:?}, expected {:?}", + o_l.shape(), + (num_tokens, hidden_size) + ) + } + + let input_ptr = *i.device_ptr() as *const core::ffi::c_void; + let output_ptr = *o.device_ptr() as *const core::ffi::c_void; + + unsafe { + ffi::moe_sum( + input_ptr, + output_ptr, + hidden_size as i32, + num_token as i32, + topk as i32, + dtype, + ) + } + + Ok(()) +} + +pub fn apply_moe_sum_inplace( + input: &Tensor, + output: &Tensor, + num_token: usize, + topk: usize, + dtype: u32, +) -> Result<()> { + match input.dtype() { + DType::F16 => apply_moe_sum_::(input, output, num_token, topk, dtype), + DType::BF16 => apply_moe_sum_::(input, output, num_token, topk, dtype), + DType::F32 => apply_moe_sum_::(input, output, num_token, topk, dtype), + dt => { + candle::bail!("apply_moe_sum_inplace is only supported for f32, f16 and bf16 ({dt:?})") + } + } +} diff --git a/candle-moe/tests/moe_sum_tests.rs b/candle-moe/tests/moe_sum_tests.rs new file mode 100644 index 0000000..98666ef --- /dev/null +++ b/candle-moe/tests/moe_sum_tests.rs @@ -0,0 +1,19 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor}; + +#[test] +fn moe_sum() -> Result<()> { + let device = Device::new_cuda(0)?; + + let seq_len = 8; + let top_k: usize = 2; + let hidden_size = 4; + + let input = + Tensor::randn(0.0, 1.0, (seq_len, top_k, hidden_size), &device)?.to_dtype(DType::F16)?; + let output = Tensor::zeros((seq_len, hidden_size), DType::F16, &device)?; + + candle_moe::apply_moe_sum_inplace(&input, &output, seq_len, top_k, 1)?; + + Ok(()) +} diff --git a/candle-moe/tests/topk_softmax_tests.rs b/candle-moe/tests/topk_softmax_tests.rs new file mode 100644 index 0000000..de769f2 --- /dev/null +++ b/candle-moe/tests/topk_softmax_tests.rs @@ -0,0 +1,57 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor}; +use candle_transformers::models::deepseek2::{TopKLastDimOp, TopKOutput}; + +fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|row| { + row.iter() + .map(|val| (val * b).round() / b) + .collect::>() + }) + .collect::>>(); + Ok(t) +} + +#[test] +fn topk_softmax() -> Result<()> { + let device = Device::new_cuda(0)?; + + let seq_len = 8; + let num_experts = 4; + let top_k = 2; + + let weights = Tensor::randn(0.0, 1.0, (seq_len, num_experts), &device)?.to_dtype(DType::F32)?; + let softmax_weights = candle_nn::ops::softmax_last_dim(&weights)?; + + let TopKOutput { + values: expected_values, + indices: expected_indices, + } = softmax_weights.topk(top_k)?; + + let topk_weight = Tensor::zeros((seq_len, top_k), DType::F32, &device)?; + let topk_indices = Tensor::zeros((seq_len, top_k), DType::U32, &device)?; + let token_expert_indices = Tensor::zeros((seq_len, top_k), DType::U32, &device)?; + + candle_moe::apply_topk_softmax_inplace( + &weights, + &topk_weight, + &topk_indices, + &token_expert_indices, + )?; + + assert_eq!( + to_vec2_round(expected_values, 3)?, + to_vec2_round(topk_weight, 3)? + ); + + assert_eq!( + expected_indices.to_vec2::()?, + topk_indices.to_vec2::()?, + ); + + Ok(()) +}
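
Review note (not part of the patch): below is a minimal sketch of how the two exposed ops might compose in a MoE forward pass, based only on the `apply_topk_softmax_inplace` / `apply_moe_sum_inplace` signatures and the tests above. The `route_and_sum` helper, the `hidden` and `expert_out` tensors, and the placeholder expert computation are hypothetical; the dtype code mapping (0 => f16, 1 => bf16, 2 => f32) follows the comment in `moe_align_sum_kernels.cu`.

```rust
use candle::{DType, Result, Tensor};

/// Hypothetical glue code: route tokens with the fused top-k softmax, then
/// reduce the per-expert outputs with the fused MoE sum.
fn route_and_sum(hidden: &Tensor, gating_output: &Tensor, top_k: usize) -> Result<Tensor> {
    let device = gating_output.device();
    let (num_tokens, _num_experts) = gating_output.dims2()?;
    let (_, hidden_size) = hidden.dims2()?;

    // The kernels write into pre-allocated buffers in place.
    // Router logits are assumed to be f32 here, matching the f32 weight buffer.
    let topk_weights = Tensor::zeros((num_tokens, top_k), DType::F32, device)?;
    let topk_indices = Tensor::zeros((num_tokens, top_k), DType::U32, device)?;
    let token_expert_indices = Tensor::zeros((num_tokens, top_k), DType::U32, device)?;

    // Fused softmax + top-k over the router logits ([num_tokens, num_experts]).
    candle_moe::apply_topk_softmax_inplace(
        gating_output,
        &topk_weights,
        &topk_indices,
        &token_expert_indices,
    )?;

    // ... run the selected experts here, producing `expert_out`
    // with shape [num_tokens, top_k, hidden_size] (placeholder below) ...
    let expert_out = Tensor::zeros((num_tokens, top_k, hidden_size), hidden.dtype(), device)?;

    // dtype code expected by the kernel: 0 => f16, 1 => bf16, 2 => f32.
    let dtype_code = match expert_out.dtype() {
        DType::F16 => 0,
        DType::BF16 => 1,
        _ => 2,
    };

    // Sum over the top_k dimension: [num_tokens, top_k, hidden_size] -> [num_tokens, hidden_size].
    let out = Tensor::zeros((num_tokens, hidden_size), expert_out.dtype(), device)?;
    candle_moe::apply_moe_sum_inplace(&expert_out, &out, num_tokens, top_k, dtype_code)?;
    Ok(out)
}
```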