diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml index 05e97e5..7539415 100644 --- a/.github/workflows/pr-test-xpu.yml +++ b/.github/workflows/pr-test-xpu.yml @@ -1,73 +1,73 @@ -name: PR Test (XPU) - -on: - pull_request: - branches: [main] - workflow_dispatch: - -concurrency: - group: pr-test-xpu-${{ github.ref }} - cancel-in-progress: true - -jobs: - build-and-test: - if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci') - runs-on: sglang-pvc - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - - name: Build Docker image - run: | - docker build \ - --build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \ - --build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \ - --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:kernel . - - - name: Run container - run: | - docker run -dt \ - --device /dev/dri/ \ - --name ci_sglang_xpu \ - -e HF_TOKEN=$(cat ~/huggingface_token.txt) \ - xpu_sglang:kernel - - - name: Install Dependency - timeout-minutes: 20 - run: | - docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install --upgrade pip - docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install pytest expecttest ray huggingface_hub - docker exec ci_sglang_xpu /bin/bash -c '/miniforge3/envs/py3.10/bin/huggingface-cli login --token ${HF_TOKEN} ' - docker exec ci_sglang_xpu /bin/bash -c "ln -sf /miniforge3/envs/py3.10/bin/python3 /usr/bin/python3" - - - name: Run Sglang Kernel Cases - timeout-minutes: 20 - run: | - docker exec -w /root/sglang ci_sglang_xpu \ - /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 run_suite.py --suite per-commit " - - - name: Run Sglang Kernel Benchmarks - timeout-minutes: 20 - run: | - docker exec -w /root/sglang ci_sglang_xpu \ - /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py " - - - name: Run E2E Bfloat16 tests - timeout-minutes: 20 - run: | - echo "[PlaceHolder for E2E Test...]" - - - name: Run E2E Qunatization tests - timeout-minutes: 20 - run: | - echo "[PlaceHolder for E2E Test...]" - - - name: Cleanup container - if: always() - run: | - docker rm -f ci_sglang_xpu || true +name: PR Test (XPU) + +on: + pull_request: + branches: [main] + workflow_dispatch: + +concurrency: + group: pr-test-xpu-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci') + runs-on: sglang-pvc + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Build Docker image + run: | + docker build \ + --build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \ + --build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \ + --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:kernel . 
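+          # Note: the two --build-arg values above pull the kernel source from the PR's head branch and repository (fork-aware), so the xpu_sglang:kernel image under test contains the submitted changes.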
+ + - name: Run container + run: | + docker run -dt \ + --device /dev/dri/ \ + --name ci_sglang_xpu \ + -e HF_TOKEN=$(cat ~/huggingface_token.txt) \ + xpu_sglang:kernel + + - name: Install Dependency + timeout-minutes: 20 + run: | + docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install --upgrade pip + docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install pytest expecttest ray huggingface_hub + docker exec ci_sglang_xpu /bin/bash -c '/miniforge3/envs/py3.10/bin/huggingface-cli login --token ${HF_TOKEN} ' + docker exec ci_sglang_xpu /bin/bash -c "ln -sf /miniforge3/envs/py3.10/bin/python3 /usr/bin/python3" + + - name: Run Sglang Kernel Cases + timeout-minutes: 20 + run: | + docker exec -w /root/sglang ci_sglang_xpu \ + /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 run_suite.py --suite per-commit " + + - name: Run Sglang Kernel Benchmarks + timeout-minutes: 20 + run: | + docker exec -w /root/sglang ci_sglang_xpu \ + /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py && python3 bench_fused_moe.py " + + - name: Run E2E Bfloat16 tests + timeout-minutes: 20 + run: | + echo "[PlaceHolder for E2E Test...]" + + - name: Run E2E Qunatization tests + timeout-minutes: 20 + run: | + echo "[PlaceHolder for E2E Test...]" + + - name: Cleanup container + if: always() + run: | + docker rm -f ci_sglang_xpu || true diff --git a/.gitignore b/.gitignore index ee45cb7..7115e66 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,6 @@ *.pyo build + +# vscode .vscode/ diff --git a/benchmark/bench_fused_moe.py b/benchmark/bench_fused_moe.py new file mode 100644 index 0000000..31e5272 --- /dev/null +++ b/benchmark/bench_fused_moe.py @@ -0,0 +1,300 @@ +# python3 benchmark/bench_fused_moe.py +from itertools import product + +import torch +import triton +from sgl_kernel import fused_experts, topk_softmax +from torch.nn import functional as F + +shape_configs = [ + # # Qwen/Qwen2-57B-A14B-Instruct, tp = 1 + # { + # "num_experts": 64, + # "topk": 8, + # "hidden_size": 3584, + # "shard_intermediate_size": 5120, + # "dtype": torch.bfloat16, + # "block_shape": None, + # }, + # # Qwen/Qwen2-57B-A14B-Instruct, tp = 2 + # { + # "num_experts": 64, + # "topk": 8, + # "hidden_size": 3584, + # "shard_intermediate_size": 2560, + # "dtype": torch.bfloat16, + # "block_shape": None, + # }, + # Qwen/Qwen2-57B-A14B-Instruct, tp = 4 + { + "num_experts": 64, + "topk": 8, + "hidden_size": 3584, + "shard_intermediate_size": 1280, + "dtype": torch.bfloat16, + "block_shape": None, + }, + # Qwen/Qwen2-57B-A14B-Instruct, tp = 8 + { + "num_experts": 64, + "topk": 8, + "hidden_size": 3584, + "shard_intermediate_size": 640, + "dtype": torch.bfloat16, + "block_shape": None, + }, + # # DeepSeek-V3-0324, tp = 1 + # { + # "num_experts": 257, + # "topk": 8, + # "hidden_size": 7168, + # "shard_intermediate_size": 4096, + # "dtype": torch.bfloat16, + # "block_shape": [128, 128], + # }, + # # DeepSeek-V3-0324, tp = 2 + # { + # "num_experts": 257, + # "topk": 8, + # "hidden_size": 7168, + # "shard_intermediate_size": 2048, + # "dtype": torch.bfloat16, + # "block_shape": [128, 128], + # }, + # # DeepSeek-V3-0324, tp = 4 + # { + # "num_experts": 257, + # "topk": 8, + # "hidden_size": 7168, + # "shard_intermediate_size": 1024, + # "dtype": torch.bfloat16, + # "block_shape": [128, 128], + # }, + # # DeepSeek-V3-0324, tp = 8 + # { + # "num_experts": 257, + # "topk": 8, + # "hidden_size": 7168, + # "shard_intermediate_size": 512, + 
# "dtype": torch.bfloat16, + # "block_shape": [128, 128], + # }, + # # Mixtral-8x7B-Instruct-v0.1, tp = 1 + # { + # "num_experts": 8, + # "topk": 2, + # "hidden_size": 4096, + # "shard_intermediate_size": 28672, + # "dtype": torch.bfloat16, + # "block_shape": None, + # }, + # # Mixtral-8x7B-Instruct-v0.1, tp = 2 + # { + # "num_experts": 8, + # "topk": 2, + # "hidden_size": 4096, + # "shard_intermediate_size": 14336, + # "dtype": torch.bfloat16, + # "block_shape": None, + # }, + # Mixtral-8x7B-Instruct-v0.1, tp = 4 + { + "num_experts": 8, + "topk": 2, + "hidden_size": 4096, + "shard_intermediate_size": 7168, + "dtype": torch.bfloat16, + "block_shape": None, + }, + # Mixtral-8x7B-Instruct-v0.1, tp = 8 + { + "num_experts": 8, + "topk": 2, + "hidden_size": 4096, + "shard_intermediate_size": 3584, + "dtype": torch.bfloat16, + "block_shape": None, + }, +] + +shape_values = [list(d.values()) for d in shape_configs] +bs = [1, 16, 32] # 128, 256, 512, 1024, 2048, 4096, 8192] +configs = [(k, *v) for k, v in product(bs, shape_values)] + + +def fused_topk_native( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + M, _ = hidden_states.shape + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + topk_weights = F.softmax(gating_output.float(), dim=-1) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +@torch.compile(dynamic=False) +def fused_moe_torch( + x, + w1, + w2, + input_gating, + topk, +) -> torch.Tensor: + + topk_weights, topk_ids = fused_topk_native( + hidden_states=x, + gating_output=input_gating, + topk=topk, + renormalize=True, + ) + w13_weights = w1[topk_ids] + w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) + w2_weights = w2[topk_ids] + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + x1 = F.silu(x1) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) + return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) + + +def fused_moe_torch_compile( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, +): + return fused_moe_torch( + x, + w1, + w2, + input_gating, + topk, + ) + + +def fused_moe_sglang_api( + x, + w1, + w2, + input_gating, + topk, +): + num_tokens = x.shape[0] + topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=x.device) + topk_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device=x.device) + + topk_softmax( + topk_weights, + topk_indices, + input_gating, + renormalize=True, + ) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_indices, + ) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=[ + "num_tokens", + "num_experts", + "topk", + "hidden_size", + "shard_intermediate_size", + "dtype", + "block_shape", + ], + x_vals=configs, + line_arg="provider", + line_vals=[ + "torch_compile", + "sgl_kernel", + ], + line_names=[ + "torch_compile", + "sgl_kernel", + ], + styles=[ + ("blue", "-"), + ("green", "-"), + ], + ylabel="Time (ms)", + plot_name="fused-moe-performance", + args={}, + ) +) +def benchmark( + num_tokens, + 
num_experts, + topk, + hidden_size, + shard_intermediate_size, + dtype, + block_shape, + provider, +): + print( + f"benchmark {provider} with batch_size={num_tokens} hidden_size={hidden_size} shard_intermediate_size={shard_intermediate_size}" + ) + torch.set_default_device("xpu") + torch.xpu.manual_seed_all(0) + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype + ) + + input_gating = torch.randn(num_tokens, num_experts, dtype=dtype) + + if provider == "torch_compile": + api_func = fused_moe_torch_compile + else: + api_func = fused_moe_sglang_api + + api_kwargs = { + "x": x, + "w1": w1, + "w2": w2, + "input_gating": input_gating, + "topk": topk, + } + + # Warmup + for _ in range(10): + _ = api_func(**api_kwargs) + torch.xpu.synchronize() + + bench_lambda = lambda: api_func(**api_kwargs) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles) + torch.xpu.empty_cache() + del x, w1, w2, input_gating + return ms, min_ms, max_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) + print("Benchmark finished!") diff --git a/benchmark/bench_moe_align_block_size.py b/benchmark/bench_moe_align_block_size.py index 2745022..82a84ee 100644 --- a/benchmark/bench_moe_align_block_size.py +++ b/benchmark/bench_moe_align_block_size.py @@ -5,7 +5,6 @@ import triton import triton.language as tl from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size -from vllm import _custom_ops as ops USE_RANDOM_PERM = False @@ -143,102 +142,63 @@ def moe_align_block_size_triton( def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): topk_ids = torch.stack( [ - torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + torch.randperm(num_experts, dtype=torch.int32, device="xpu")[:topk] for _ in range(num_tokens) ] ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids_cuda = torch.empty( + sorted_ids_xpu = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) - sorted_ids_cuda.fill_(topk_ids.numel()) + sorted_ids_xpu.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids_cuda = torch.zeros( + expert_ids_xpu = torch.zeros( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) - num_tokens_post_pad_cuda = torch.empty( + num_tokens_post_pad_xpu = torch.empty( (1), dtype=torch.int32, device=topk_ids.device ) - token_cnts_buffer = torch.zeros( - (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device - ) cumsum_buffer = torch.zeros( - num_experts + 1, dtype=torch.int32, device=topk_ids.device + num_experts + 2, dtype=torch.int32, device=topk_ids.device ) - sorted_ids_triton = torch.empty_like(sorted_ids_cuda) + sorted_ids_triton = torch.empty_like(sorted_ids_xpu) sorted_ids_triton.fill_(topk_ids.numel()) - expert_ids_triton = torch.zeros_like(expert_ids_cuda) - num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - - sorted_ids_vllm = torch.empty_like(sorted_ids_cuda) - sorted_ids_vllm.fill_(topk_ids.numel()) - expert_ids_vllm = torch.zeros_like(expert_ids_cuda) - num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda) + expert_ids_triton = torch.zeros_like(expert_ids_xpu) + num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu) - # compare the performance of 
cuda, triton and vllm implementation + # compare the performance of xpu and triton implementation sgl_moe_align_block_size( topk_ids, - num_experts, + num_experts + 1, block_size, - sorted_ids_cuda, - expert_ids_cuda, - num_tokens_post_pad_cuda, - token_cnts_buffer, + sorted_ids_xpu, + expert_ids_xpu, + num_tokens_post_pad_xpu, cumsum_buffer, + False, ) moe_align_block_size_triton( topk_ids, - num_experts, + num_experts + 1, block_size, sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton, ) - try: - ops.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids_vllm, - expert_ids_vllm, - num_tokens_post_pad_vllm, - ) - print(f"✅ VLLM implementation works with {num_experts} experts!") - vllm_works = True - except RuntimeError as e: - print(f"❌ VLLM implementation failed with {num_experts} experts: {e}") - vllm_works = False - - if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( - num_tokens_post_pad_cuda, num_tokens_post_pad_triton + if torch.allclose(expert_ids_xpu, expert_ids_triton) and torch.allclose( + num_tokens_post_pad_xpu, num_tokens_post_pad_triton ): print("✅ SGL and Triton implementations match") else: print("❌ SGL and Triton implementations do not match") - print("SGL expert_ids:", expert_ids_cuda) + print("SGL expert_ids:", expert_ids_xpu) print("Triton expert_ids:", expert_ids_triton) - print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda) + print("SGL num_tokens_post_pad:", num_tokens_post_pad_xpu) print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) - if ( - vllm_works - and torch.allclose(expert_ids_cuda, expert_ids_vllm) - and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm) - ): - print("✅ SGL and VLLM implementations match") - else: - if not vllm_works: - print("⚠️ VLLM comparison skipped due to failure") - else: - print("❌ SGL and VLLM implementations do not match") - print("SGL expert_ids:", expert_ids_cuda) - print("VLLM expert_ids:", expert_ids_vllm) - print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda) - print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm) - # Test range num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] @@ -249,9 +209,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: - topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda") + topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="xpu") for i in range(num_tokens): - topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[ + topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="xpu")[ :topk ] return topk_ids @@ -262,8 +222,8 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: x_names=["num_tokens", "num_experts", "topk"], x_vals=configs, line_arg="provider", - line_vals=["sgl", "triton", "vllm"], - line_names=["SGL", "Triton", "VLLM"], + line_vals=["sgl", "triton"], + line_names=["SGL", "Triton"], styles=[("blue", "-"), ("red", "-"), ("green", "-")], ylabel="us", plot_name="moe-align-block-size-performance", @@ -281,7 +241,7 @@ def benchmark(num_tokens, num_experts, topk, provider): num_experts, (num_tokens, topk), dtype=torch.int32, - device="cuda", + device="xpu", ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) @@ -306,11 +266,6 @@ def sgl_moe_align_block_size_with_empty( expert_ids, num_tokens_post_pad, ): - token_cnts_buffer = 
torch.empty( - (num_experts + 1) * num_experts, - dtype=torch.int32, - device=topk_ids.device, - ) cumsum_buffer = torch.empty( num_experts + 1, dtype=torch.int32, device=topk_ids.device ) @@ -322,8 +277,8 @@ def sgl_moe_align_block_size_with_empty( sorted_ids.clone(), expert_ids.clone(), num_tokens_post_pad.clone(), - token_cnts_buffer, cumsum_buffer, + False, ) ms, min_ms, max_ms = triton.testing.do_bench( @@ -349,23 +304,6 @@ def sgl_moe_align_block_size_with_empty( ), quantiles=quantiles, ) - else: # vllm - try: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: ops.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids.clone(), - expert_ids.clone(), - num_tokens_post_pad.clone(), - ), - quantiles=quantiles, - ) - except RuntimeError as e: - print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}") - # Return extreme values to indicate failure in the chart - return float("inf"), float("inf"), float("inf") return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/include/sgl_kernel_ops.h b/include/sgl_kernel_ops.h index 31e450d..baa55a0 100644 --- a/include/sgl_kernel_ops.h +++ b/include/sgl_kernel_ops.h @@ -215,8 +215,10 @@ void moe_align_block_size( torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, - torch::Tensor cumsum_buffer); + torch::Tensor cumsum_buffer, + bool pad_sorted_token_ids); + +void moe_sum(torch::Tensor& input, torch::Tensor& output); void topk_softmax( torch::Tensor& topk_weights, @@ -253,6 +255,13 @@ void fp8_blockwise_scaled_grouped_mm( const torch::Tensor& expert_offsets, const torch::Tensor& workspace); +void moe_grouped_mm_nt( + torch::Tensor& output, + const torch::Tensor& activations, + const torch::Tensor& weights, + const torch::Tensor& total_rows_for_experts, + const int64_t n_experts); + void prepare_moe_input( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, diff --git a/include/utils.h b/include/utils.h index f6d829c..5526239 100644 --- a/include/utils.h +++ b/include/utils.h @@ -26,16 +26,6 @@ limitations under the License. CHECK_IS_XPU(x); \ CHECK_IS_CONTIGUOUS(x) -#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) - -#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) - #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define WARP_SIZE 32 diff --git a/python/sgl_kernel/__init__.py b/python/sgl_kernel/__init__.py index 4d5065b..a75a2e7 100755 --- a/python/sgl_kernel/__init__.py +++ b/python/sgl_kernel/__init__.py @@ -50,12 +50,12 @@ from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, - ep_moe_post_reorder, - ep_moe_pre_reorder, - ep_moe_silu_and_mul, fp8_blockwise_scaled_grouped_mm, + fused_experts, moe_align_block_size, moe_fused_gate, + moe_sum, + moe_sum_reduce, prepare_moe_input, topk_softmax, ) diff --git a/python/sgl_kernel/moe.py b/python/sgl_kernel/moe.py index 27f3187..5d6c24a 100755 --- a/python/sgl_kernel/moe.py +++ b/python/sgl_kernel/moe.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import torch @@ -10,8 +10,8 @@ def moe_align_block_size( sorted_token_ids, experts_ids, num_tokens_post_pad, - token_cnts_buffer, cumsum_buffer, + pad_sorted_token_ids=False, ): torch.ops.sgl_kernel.moe_align_block_size.default( topk_ids, @@ -20,8 +20,8 @@ def moe_align_block_size( sorted_token_ids, experts_ids, num_tokens_post_pad, - token_cnts_buffer, cumsum_buffer, + pad_sorted_token_ids, ) @@ -36,6 +36,28 @@ def topk_softmax( ) +def moe_sum_reduce( + input_tensor, + output_tensor, + routed_scaling_factor=0, +): + torch.ops.sgl_kernel.moe_sum_reduce.default( + input_tensor, + output_tensor, + routed_scaling_factor, + ) + + +def moe_sum( + input_tensor: torch.Tensor, + output_tensor: torch.Tensor, +): + torch.ops.sgl_kernel.moe_sum.default( + input_tensor, + output_tensor, + ) + + def moe_fused_gate( input_tensor, bias, @@ -44,6 +66,7 @@ def moe_fused_gate( topk, num_fused_shared_experts=0, routed_scaling_factor=0, + apply_routed_scaling_factor_on_output=False, ): # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion # it split group of expert into num_expert_group, and use top2 expert weight sum in each group @@ -51,8 +74,13 @@ def moe_fused_gate( # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now. # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk - # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts - # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor + # num_fused_shared_experts: if > 0, the last several experts will be + # replaced with shared experts. the shared experts will be divided by the + # routed_scaling_factor - this is intended to cancel out later when routed+shared + # output is scaled so that shared experts are not scaled. 
+ # routed_scaling_factor: if > 0, the experts will be scaled by this factor + # apply_routed_scaling_factor_on_output: if true, output will be + # scaled by the routed_scaling_factor return torch.ops.sgl_kernel.moe_fused_gate.default( input_tensor, bias, @@ -61,70 +89,7 @@ def moe_fused_gate( topk, num_fused_shared_experts, routed_scaling_factor, - ) - - -def ep_moe_pre_reorder( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, -): - return torch.ops.sgl_kernel.ep_moe_pre_reorder.default( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, - ) - - -def ep_moe_silu_and_mul( - gateup_output, - down_input, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, -): - return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default( - gateup_output, - down_input, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - ) - - -def ep_moe_post_reorder( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, -): - return torch.ops.sgl_kernel.ep_moe_post_reorder.default( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, + apply_routed_scaling_factor_on_output, ) @@ -252,3 +217,154 @@ def cutlass_fp4_group_mm( params["blockscale_offsets"], ) return c.to(dtype=out_dtype) + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, + inplace: bool = False, + activation: str = "silu", + use_fp8_w8a8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states [num_tokens, hidden_dim] (torch.Tensor): The input tensor to the MoE layer. + - w1 [num_experts, hidden_dim, output_channel] (torch.Tensor): The first set of expert weights. + - w2 [num_experts, output_channel, hidden_dim] (torch.Tensor): The second set of expert weights. + - topk_weights [num_tokens, topk] (torch.Tensor): The top-k output of the experts. + - topk_ids [num_tokens, topk] (torch.Tensor): The top-k indices of the experts. + - b1 (Optional[torch.Tensor]): Optional bias for w1. + - b2 (Optional[torch.Tensor]): Optional bias for w2. + - inplace (bool): If True, perform operations in-place to save memory. Defaults to False. + - activation (str): The activation function to use ('silu' or 'gelu'). Defaults to 'silu'. + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. 
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. + - no_combine (bool): If True, skip the combine step. Defaults to False. + - routed_scaling_factor (Optional[float]): Optional scaling factor for routed tokens, used by Llama4 only. + - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation + function. + - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation + function. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + assert use_fp8_w8a8 is False, "current MoE does not support use_fp8_w8a8" + assert w1_scale is None, "current MoE does not support w1_scale" + assert w2_scale is None, "current MoE does not support w2_scale" + assert a1_scale is None, "current MoE does not support a1_scale" + assert a2_scale is None, "current MoE does not support a2_scale" + assert block_shape is None, "current MoE does not support block_shape" + + # type check + assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16" + assert w1.dtype == torch.bfloat16, "w1 must be bfloat16" + assert w2.dtype == torch.bfloat16, "w2 must be bfloat16" + + # Shape check + assert hidden_states.ndim == 2, "hidden_states must be 2D" + assert ( + hidden_states.shape[-1] == w1.shape[-1] + ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}" + assert ( + 2 * w2.shape[2] == w1.shape[1] + ), f"w2 shape[2] {w2.shape[2]} must be half of w1 shape[1] {w1.shape[1]}" + assert (topk_ids.shape == topk_weights.shape) and ( + topk_ids.shape[0] == hidden_states.shape[0] + ), f"topk_ids shape {topk_ids.shape} and topk_weights shape {topk_weights.shape} must be equal and match hidden_states shape[0] {hidden_states.shape[0]}" + + num_tokens, _ = hidden_states.shape + + E, _, K = w1.shape + E, OutK, N = w2.shape + assert N * 2 == w1.shape[1], "w1 shape[1] must be 2x of w2 shape[2]" + + M = num_tokens + TopK = topk_ids.shape[1] + + # import pdb; pdb.set_trace() + cache = torch.empty( + M * TopK * max(2 * N, OutK), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = cache[: M * TopK * 2 * N].view((M * TopK, 2 * N)) + intermediate_cache2 = torch.empty( + (M * TopK, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = cache[: M * TopK * OutK].view((M * TopK, OutK)) + + if no_combine: + assert not inplace + out_hidden_states = torch.empty( + (num_tokens, OutK), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + elif inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.zeros_like(hidden_states) + + flat_topk = topk_ids.flatten() + idxs = flat_topk.argsort() + sorted_expert_ids = flat_topk[idxs] + + counts = torch.bincount(sorted_expert_ids, minlength=E) # [E] + token_idxs = idxs // TopK # [num_tokens * TopK] + input_A = torch.empty( + (num_tokens * TopK, K), device=hidden_states.device, dtype=hidden_states.dtype + ) + input_A = hidden_states[token_idxs].squeeze(1) + offset = counts.to(torch.int32) + + torch.ops.sgl_kernel.moe_grouped_mm_nt(intermediate_cache1, input_A, w1, offset, E) + + torch.ops.sgl_kernel.silu_and_mul(intermediate_cache2, intermediate_cache1) + + torch.ops.sgl_kernel.moe_grouped_mm_nt( + intermediate_cache3, intermediate_cache2, w2, offset, E + ) + + flat_weights = topk_weights.to(intermediate_cache3.dtype).flatten()[idxs] # [N] + intermediate_cache3 = 
intermediate_cache3 * flat_weights.unsqueeze(1) + out_hidden_states.scatter_reduce_( + 0, + token_idxs.view(-1, 1).expand(-1, OutK), + intermediate_cache3, + reduce="sum", + ) + + return out_hidden_states diff --git a/src/sycl/GroupGemm.cpp b/src/sycl/GroupGemm.cpp new file mode 100644 index 0000000..048c69b --- /dev/null +++ b/src/sycl/GroupGemm.cpp @@ -0,0 +1,221 @@ +#include +#include +#include + +#include + +#include "Utils.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/util/device_memory.h" +#include "kernels/moe/dispatch_policy.hpp" +#include "kernels/moe/xe_array_epilogue.hpp" +#include "kernels/moe/xe_array_mma.hpp" +#include "kernels/moe/xe_moe_gemm.hpp" + +using namespace cute; + +template +struct MoERunner { + using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + template + typename Gemm::Arguments args_from_options( + const cutlass::KernelHardwareInfo& hw_info, + const typename Gemm::ElementA* A_ptr, + const typename Gemm::ElementB* B_ptr, + typename Gemm::CollectiveEpilogue::ElementOutput* D_ptr, + const int gemm_N, + const int gemm_K, + const int* num_rows_per_expert_device, + const int num_experts) { + typename Gemm::Arguments arguments; + decltype(arguments.fusion_args) fusion_args; + + fusion_args.alpha = 1; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + + using RasterOrderOptions = + typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + static_cast((void*)A_ptr), + static_cast((void*)B_ptr), + nullptr, // static_cast((void*)D_ptr), + static_cast((void*)D_ptr), + fusion_args, + hw_info, + {1, RasterOrderOptions::AlongN}, + num_rows_per_expert_device, + num_experts, + gemm_N, + gemm_K}; + + return arguments; + } + + int init( + int device_id, + const void* activations, + const void* weights, + void* outputs, + const int gemm_n, + const int gemm_k, + const int* num_rows_per_expert_device, + const int num_experts) { // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. 
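+    // Note: as written, the device_id parameter is not copied into hw_info before this query, so the
+    // multiprocessor (EU) count is read for hw_info's current device_id (its default), not the caller's device.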
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + gemm_args = args_from_options( + hw_info, + reinterpret_cast(activations), + reinterpret_cast(weights), + reinterpret_cast(outputs), + gemm_n, + gemm_k, + num_rows_per_expert_device, + num_experts); + TORCH_CHECK(gemm_op.can_implement(gemm_args) == cutlass::Status::kSuccess, "GEMM configuration not supported."); + return Gemm::get_workspace_size(gemm_args); + } + + void run(sycl::queue queue, void* workspace) { + TORCH_CHECK(gemm_op.initialize(gemm_args, workspace) == cutlass::Status::kSuccess, "Failed to initialize GEMM."); + + // Run the GEMM + TORCH_CHECK(gemm_op.run(&queue) == cutlass::Status::kSuccess, "Failed to run GEMM."); + } + + public: + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper< + MMA_Atom, + Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + static constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MoE; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp + // This is a workaround + using EpilogueOp = cutlass::epilogue::fusion:: + LinearCombination; + using CopyOpG2R = XE_2D_U32x8x16_LD_N; + using CopyOpR2G = XE_2D_U16x8x16_ST_N; + + using StrideC = cutlass::detail::TagToStrideC_t; + using FusionCallbacks = typename cutlass::epilogue::collective::detail::FusionOpInfo< + EpilogueOp>::template FusionCallbacks; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::IntelXeXMX16MoE, + TileShape, + float, + StrideC, + scalar_t, + StrideC, + FusionCallbacks, + CopyOpG2R, + void, + void, + CopyOpR2G, + void, + void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + scalar_t, + cutlass::gemm::TagToStrideA_t, + scalar_t, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + void, + void, + cute::identity, // A + GmemTiledCopyB, + void, + void, + cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel:: + GemmMoEUniversal; + + using Gemm = cutlass::gemm::device::GemmMoEUniversalAdapter; + + Gemm gemm_op; + typename Gemm::Arguments gemm_args; + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. 
+ cutlass::KernelHardwareInfo hw_info; +}; + +void moe_grouped_mm_nt( + torch::Tensor& output, + const torch::Tensor& activations, + const torch::Tensor& weights, + const torch::Tensor& total_rows_for_experts, + const int64_t n_experts) { + int total_m = activations.sizes()[0]; + int gemm_k = activations.sizes()[1]; + auto weights_shape = weights.sizes().vec(); + int gemm_n = weights.sizes()[1]; + + TORCH_CHECK(weights_shape.size() == 3, "weights must be 3D"); + TORCH_CHECK(weights_shape[0] == n_experts, "weights must have n_experts as the first dimension"); + TORCH_CHECK(weights_shape[1] == gemm_n, "weights must be gemm_n * gemm_k"); + TORCH_CHECK( + weights_shape[0] == total_rows_for_experts.size(0), + "rows_for_experts must have the same size as the first dimension of weights"); + TORCH_CHECK(output.sizes()[0] == total_m, "output must have the same number of rows as activations"); + TORCH_CHECK(output.sizes()[1] == gemm_n, "output must have the same number of columns as activations"); + TORCH_CHECK(n_experts % 8 == 0, "n_experts must be a multiple of 8 for the current implementation"); + TORCH_CHECK( + activations.scalar_type() == weights.scalar_type(), "activations and weights must have the same data type"); + TORCH_CHECK( + activations.scalar_type() == at::ScalarType::BFloat16, + "Only bfloat16 are supported in moe_grouped_mm_nt currently"); + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = MoERunner; + Kernel kernel; + auto workspace_size = kernel.init( + activations.device().index(), + activations.data_ptr(), + weights.data_ptr(), + output.data_ptr(), + gemm_n, + gemm_k, + total_rows_for_experts.data_ptr(), + n_experts); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(activations.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + kernel.run(queue, workspace.data_ptr()); +} diff --git a/src/sycl/MoEAlign.cpp b/src/sycl/MoEAlign.cpp new file mode 100644 index 0000000..a83f8f8 --- /dev/null +++ b/src/sycl/MoEAlign.cpp @@ -0,0 +1,377 @@ +#include +#include +#include + +#include + +#include "SYCLHelpers.h" +#include "Utils.h" + +#define VEC_SIZE 4 +static constexpr int sub_group_size = 32; + +using Vec = sycl::int4; + +// Utility function: atomic add for SYCL +template +T atomic_add_sycl( + T* ptr, T value, sycl::atomic_ref atomic_ref) { + return atomic_ref.fetch_add(value); +} + +// Utility: next power of 2 +inline size_t next_pow2(size_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + return n + 1; +} + +template +struct CountAndSortExpertTokensFunctor { + CountAndSortExpertTokensFunctor( + const scalar_t* topk_ids, int32_t* sorted_token_ids, int32_t* cumsum_buffer, size_t numel) + : topk_ids(topk_ids), sorted_token_ids(sorted_token_ids), cumsum_buffer(cumsum_buffer), numel(numel) {} + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<1> item) const { + const size_t tid = item.get_global_id(0); + const size_t stride = item.get_global_range(0); + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + sycl::atomic_ref atomicRef( + cumsum_buffer[expert_id]); + int32_t rank_post_pad = atomicRef.fetch_add(1); + sorted_token_ids[rank_post_pad] = i; + } + } + + const scalar_t* topk_ids; + int32_t* sorted_token_ids; + int32_t* cumsum_buffer; + size_t numel; +}; + +template +struct MOEAlignBlockSizeFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + 
MOEAlignBlockSizeFunctor( + const scalar_t* topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size) + : topk_ids(topk_ids), + sorted_token_ids(sorted_token_ids), + expert_ids(expert_ids), + total_tokens_post_pad(total_tokens_post_pad), + num_experts(num_experts), + block_size(block_size), + numel(numel), + cumsum(cumsum), + pad_sorted_token_ids(pad_sorted_token_ids), + scan_size(scan_size) {} + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<1> item) const { + const size_t tid = item.get_local_id(0); + const size_t stride = item.get_local_range(0); + + int32_t* shared_counts = (int32_t*)(slm_.template get_multi_ptr().get()); + int32_t* prefix = shared_counts + num_experts; + int32_t* scan_buf = prefix + num_experts + 1; + int32_t* s_total_tokens_post_pad = + (int32_t*)(total_token_.template get_multi_ptr().get()); + *s_total_tokens_post_pad = 0; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + item.barrier(sycl::access::fence_space::local_space); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + sycl::atomic_ref atomicRef( + shared_counts[expert_id]); + atomicRef.fetch_add(1); + } + + item.barrier(sycl::access::fence_space::local_space); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + + // Blelloch scan + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + item.barrier(sycl::access::fence_space::local_space); + + int offset = 1; + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + item.barrier(sycl::access::fence_space::local_space); + } + + // down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + item.barrier(sycl::access::fence_space::local_space); + + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + item.barrier(sycl::access::fence_space::local_space); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + *s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = *s_total_tokens_post_pad; + } + item.barrier(sycl::access::fence_space::local_space); + + // Write cumsum + if (tid <= num_experts) { + cumsum[tid] = prefix[tid]; + } + + // Fill expert_ids + const int32_t num_blocks = *s_total_tokens_post_pad / block_size; + for (int32_t i = tid; i < num_blocks; i += stride) { + int32_t block_start = i * block_size; + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (prefix[mid] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 2; + } + + if (pad_sorted_token_ids) { + Vec fill_vec{(int)numel, (int)numel, (int)numel, (int)numel}; + int32_t total_vecs = (*s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { 
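+        // Vectorized fill: pad sorted_token_ids with `numel` (one past the last valid token index) so
+        // padded slots can be recognized as padding by downstream consumers.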
+ out_ptr[i] = fill_vec; + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + const size_t scan_size = next_pow2(num_experts); + const size_t shared_mem_size = num_experts + (num_experts + 1) + scan_size + sub_group_size; + slm_ = sycl::local_accessor(shared_mem_size, cgh); + total_token_ = sycl::local_accessor(1, cgh); + } + + const scalar_t* topk_ids; + int32_t* sorted_token_ids; + int32_t* expert_ids; + int32_t* total_tokens_post_pad; + int32_t num_experts; + int32_t block_size; + size_t numel; + int32_t* cumsum; + bool pad_sorted_token_ids; + const int32_t scan_size; + sycl::local_accessor slm_; + sycl::local_accessor total_token_; +}; + +template +struct MOEAlignBlockSizeSmallBatchExpertFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + MOEAlignBlockSizeSmallBatchExpertFunctor( + const scalar_t* topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids) + : topk_ids(topk_ids), + sorted_token_ids(sorted_token_ids), + expert_ids(expert_ids), + total_tokens_post_pad(total_tokens_post_pad), + num_experts(num_experts), + block_size(block_size), + numel(numel), + pad_sorted_token_ids(pad_sorted_token_ids) {} + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<1> item) const { + const size_t tid = item.get_local_id(0); + const size_t stride = item.get_local_range(0); + const size_t block_dim = item.get_local_range(0); + + int32_t* shared_mem = (int32_t*)(slm_.template get_multi_ptr().get()); + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = shared_mem + num_experts + 1; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[num_experts + i] = 0; + } + item.barrier(sycl::access::fence_space::local_space); + + for (size_t i = tid; i < numel; i += stride) { + ++tokens_cnts[(tid + 1) * num_experts + topk_ids[i] + 1]; + } + item.barrier(sycl::access::fence_space::local_space); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= block_dim; ++i) { + tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; + } + } + item.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + div_up(tokens_cnts[block_dim * num_experts + i - 1], block_size) * block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + item.barrier(sycl::access::fence_space::local_space); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid - 1; + } + } + + if (pad_sorted_token_ids) { + Vec fill_vec{(int)numel, (int)numel, (int)numel, (int)numel}; + int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + item.barrier(sycl::access::fence_space::local_space); + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + const int32_t threads_local = std::max((int32_t)num_experts, sub_group_size); + const int32_t shared_mem_size = ((threads_local + 1) * num_experts + (num_experts + 
1)); + slm_ = sycl::local_accessor(shared_mem_size, cgh); + } + + const scalar_t* topk_ids; + int32_t* sorted_token_ids; + int32_t* expert_ids; + int32_t* total_tokens_post_pad; + int32_t num_experts; + int32_t block_size; + size_t numel; + bool pad_sorted_token_ids; + sycl::local_accessor slm_; +}; + +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor cumsum_buffer, + bool pad_sorted_token_ids) { + auto q = sycl::queue(); + + int threads = 1024; + threads = ((threads + sub_group_size - 1) / sub_group_size) * sub_group_size; + + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads_local = std::max((int32_t)num_experts, sub_group_size); + auto range = sycl::nd_range<1>(sycl::range<1>(threads_local), sycl::range<1>(threads_local)); + using SmallKernel = MOEAlignBlockSizeSmallBatchExpertFunctor; + SmallKernel kernel( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + pad_sorted_token_ids); + sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel); + } else { + const size_t scan_size = next_pow2(num_experts); + const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + sub_group_size) * sizeof(int32_t); + using Kernel = MOEAlignBlockSizeFunctor; + Kernel kernel( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr(), + pad_sorted_token_ids, + scan_size); + auto range = sycl::nd_range<1>(sycl::range<1>(threads), sycl::range<1>(threads)); + auto local_range = range.get_local_range(); + sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + + using SortKernel = CountAndSortExpertTokensFunctor; + SortKernel count_and_sort_kernel( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel()); + auto sort_range = sycl::nd_range<1>(sycl::range<1>(num_blocks * block_threads), sycl::range<1>(block_threads)); + sycl_kernel_submit(sort_range.get_global_range(), sort_range.get_local_range(), queue, count_and_sort_kernel); + } + }); +} diff --git a/src/sycl/MoESum.cpp b/src/sycl/MoESum.cpp new file mode 100644 index 0000000..c05cd13 --- /dev/null +++ b/src/sycl/MoESum.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +#include + +#include "SYCLHelpers.h" +#include "Utils.h" + +template +struct MoeSumKernel { + MoeSumKernel(scalar_t* out_, const scalar_t* input_, int hidden_size_) + : out(out_), input(input_), hidden_size(hidden_size_) {} + + void operator()(sycl::nd_item<1> item) const { + int64_t global_idx = item.get_global_id(0); + int64_t token_idx = global_idx / hidden_size; + int idx = global_idx % hidden_size; + + scalar_t x = 0.0; +#pragma unroll + for (int k = 0; k < 2; ++k) { + x += input[token_idx * 2 * hidden_size + k * hidden_size + idx]; + } + out[token_idx * hidden_size + idx] = 
x; + } + + scalar_t* out; + const scalar_t* input; + int hidden_size; +}; + +void moe_sum( + torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] +{ + const int hidden_size = input.size(-1); + const auto num_tokens = output.numel() / hidden_size; + const int topk = input.size(1); + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + sycl::range<1> global(num_tokens); + sycl::range<1> local(std::min(hidden_size, 1024)); + auto range = sycl::nd_range<1>(global * local, local); + + switch (topk) { + case 2: { + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] { + using Kernel = MoeSumKernel; + Kernel kernel(output.data_ptr(), input.data_ptr(), hidden_size); + sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel); + }); + break; + } + case 3: { + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] { + using Kernel = MoeSumKernel; + Kernel kernel(output.data_ptr(), input.data_ptr(), hidden_size); + sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel); + }); + break; + } + case 4: { + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] { + using Kernel = MoeSumKernel; + Kernel kernel(output.data_ptr(), input.data_ptr(), hidden_size); + sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel); + }); + break; + } + default: + at::sum_out(output, input, 1); + break; + } +} diff --git a/src/sycl/MoEOps.cpp b/src/sycl/TopKSoftMax.cpp similarity index 100% rename from src/sycl/MoEOps.cpp rename to src/sycl/TopKSoftMax.cpp diff --git a/src/sycl/Utils.h b/src/sycl/Utils.h index 1869b79..de88185 100644 --- a/src/sycl/Utils.h +++ b/src/sycl/Utils.h @@ -12,6 +12,23 @@ TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define DISPATCH_CASE_FLOAT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) 
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__)) + using DeviceId = at::DeviceIndex; static inline DeviceId dpcppGetDeviceIdOfCurrentQueue() { diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 05ae343..3e1c1cd 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -153,8 +153,7 @@ struct Flash_fwd_params { int* __restrict__ num_splits_dynamic_ptr; bool skip_scheduler_metadata_computation; - int arch; - int num_sm; + torch::TensorOptions tensor_opts; }; template @@ -338,17 +337,17 @@ struct KernelRunner { // Define device-global scratch memory size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); + auto workspace = torch::empty(workspace_size, params.tensor_opts); if (!FMHAChunkPrefillKernel::can_implement(arguments)) { return cutlass::Status::kErrorInvalidProblem; } // Initialize the workspace - (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get())); + (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.data_ptr())); // Convert host-side arguments to device-side arguments to be passed to the kernel - auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.data_ptr()); // Run the Flash Attention implementation. run(params_kernel); @@ -680,7 +679,7 @@ std::vector mha_fwd( TORCH_CHECK( q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + TORCH_CHECK(false, "q_v is not supported yet"); at::Tensor q_v = q_v_.value(); TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); CHECK_DEVICE(q_v); @@ -733,6 +732,8 @@ std::vector mha_fwd( params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); } + params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); + at::Tensor out_accum, softmax_lse_accum; auto outaccum_type = at::ScalarType::Float; diff --git a/src/sycl/kernels/moe/dispatch_policy.hpp b/src/sycl/kernels/moe/dispatch_policy.hpp new file mode 100644 index 0000000..6dcc4ec --- /dev/null +++ b/src/sycl/kernels/moe/dispatch_policy.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "cutlass/gemm/dispatch_policy.hpp" + +namespace cutlass::gemm { + +struct KernelXeMoEGEMM {}; +// partial specialization for KernelXeMoEGEMM +template +struct MainloopIntelXeXMX16MoE : MainloopIntelXeXMX16 {}; +} // namespace cutlass::gemm diff --git a/src/sycl/kernels/moe/xe_array_epilogue.hpp b/src/sycl/kernels/moe/xe_array_epilogue.hpp new file mode 100644 index 0000000..74309f7 --- /dev/null +++ b/src/sycl/kernels/moe/xe_array_epilogue.hpp @@ -0,0 +1,503 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +struct IntelXeXMX16MoE { + static constexpr int SubgroupSize = 16; +}; + +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_> +class CollectiveEpilogue< + IntelXeXMX16MoE, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_> { + public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16Group; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using 
GmemTiledCopyD = cute:: + conditional_t && not cute::is_void_v, CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = ElementAccumulator; + using ElementSource = typename FusionCallbacks::ElementSource; + using ElementScalar = typename FusionCallbacks::ElementScalar; + static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + + static_assert( + cute::is_same_v< + typename FusionCallbacks::Operation, + fusion::LinearCombination>, + "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + using Trait_C = Copy_Traits; + using XE_Copy_C = decltype(make_tiled_copy( + Copy_Atom{}, + Layout{}, + make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); + using Trait_D = Copy_Traits; + using XE_Copy_D = decltype(make_tiled_copy( + Copy_Atom{}, + Layout{}, + make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); + + private: + constexpr static bool is_source_supported = + not cute::is_void_v && FusionCallbacks::Operation::IsSourceSupported; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + + public: + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl : cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + using TensorC = decltype(make_tensor( + make_gmem_ptr(static_cast(nullptr)), make_shape(0, 0, 0), InternalStrideC{})); //(m, n) + using TensorD = decltype(make_tensor( + make_gmem_ptr(static_cast(nullptr)), make_shape(0, 0, 0), InternalStrideD{})); //(m, n) + using EpilogueTensors = cute::tuple; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNL = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto [M, N, L] = problem_shape_MNL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch 
= reinterpret_cast(args.ptr_C); + TensorC mC_mnl = + make_tensor(make_gmem_ptr(ptr_C_first_batch), make_layout(make_shape(M, N, L), InternalStrideC{})); + xe_load_c = {xe_load_c.with(mC_mnl)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + TensorD mD_mnl = + make_tensor(make_gmem_ptr(ptr_D_first_batch), make_layout(make_shape(M, N, L), InternalStrideD{})); + xe_store_d = {xe_store_d.with(mD_mnl)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD}; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + static bool can_implement(ProblemShape problem_shape, Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + + bool implementable = true; + bool fusion_implementable = true; + + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideC{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template < + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma, + class LoadStoreTensor> + CUTLASS_DEVICE void operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx, + LoadStoreTensor const& load_store_tensors) { + (void)tiled_mma; + using namespace cute; + + 
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && BLK_N % ATOM_N == 0 && BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M, N, L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = + local_tile(mD_mnl, take<0, 2>(CtaTileMNK{}), make_coord(m_coord, n_coord, l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0, 2>(SubgroupTileShape{}), make_coord(m_sg, n_sg)); // (SG_M,SG_N) + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. 
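    // Illustration only (not part of this change): how the tile bookkeeping above
    // works out for one hypothetical configuration -- a 256x256x32 workgroup tile,
    // a TiledMma thread layout of ATOM_M=8, ATOM_N=4, ATOM_K=1 and an 8x16x16 MMA
    // atom with SubgroupSize=16. All numbers are assumptions for illustration.
    //   SG_M = 256 / 8 = 32, SG_N = 256 / 4 = 64
    //   FragsM = 32 / 8 = 4,  FragsN = 64 / 16 = 4
    //   FragmentSize = (8 * 16) / 16 = 8 accumulator values per work-item
    // Each sub-group therefore owns a 32x64 tile of D and walks it as a 4x4 grid of
    // 8x16 atom tiles, and ValuesLoaded = 4*4*8*16*8*4*1 = 65536 = 256*256 = MN,
    // which is exactly what the static_assert further below checks.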
+ ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M, N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0, 2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0, 2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + mn_shape, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN"); + + auto synchronize = [&]() {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + if (is_C_load_needed) { + // coordinates for C and D are the same + copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + cst_callbacks.reduce( + nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = + cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); + } + } + } + + cst_callbacks.end(); + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + int32_t const& next_group, ProblemShape_MNKL const& problem_shape_mnkl, const int32_t* num_rows_per_expert) { + auto [M, N, K, L] = problem_shape_mnkl; + int32_t cumulative_M = 0; + for (int i = 0; i < next_group; i++) { + cumulative_M += num_rows_per_expert[i]; + } + M = num_rows_per_expert[next_group]; + + TensorC mC_mnl; + TensorD mD_mnl; + if constexpr (is_source_supported) { + ElementC const* ptr_C_curr_batch = reinterpret_cast((void*)(params.ptr_C)) + cumulative_M * N; + mC_mnl = make_tensor( + make_gmem_ptr(ptr_C_curr_batch), + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1}))); + } + + if constexpr (is_destination_supported) { + ElementD* ptr_D_curr_batch 
= reinterpret_cast((void*)(params.ptr_D)) + cumulative_M * N; + mD_mnl = make_tensor( + make_gmem_ptr(ptr_D_curr_batch), + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1}))); + } + return cute::make_tuple(mC_mnl, mD_mnl); + } + + private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/sycl/kernels/moe/xe_array_mma.hpp b/src/sycl/kernels/moe/xe_array_mma.hpp new file mode 100644 index 0000000..0f01ea5 --- /dev/null +++ b/src/sycl/kernels/moe/xe_array_mma.hpp @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class Schedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopIntelXeXMX16MoE, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16Group; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert( + platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); + + static_assert( + std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert( + std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr int BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr int BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr int BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr int SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr int SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr int SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape, C, C>; + + static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using Copy_A = typename Copy_Traits::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits::template DefaultTiledCopy; + + using TensorMKL = decltype(make_tensor( + make_gmem_ptr(static_cast(nullptr)), make_shape(0, 0, 0), InternalStrideA{})); 
//(m, k) + using TensorNKL = decltype(make_tensor( + make_gmem_ptr(static_cast(nullptr)), make_shape(0, 0, 0), InternalStrideB{})); //(n, k) + using MainloopTensors = cute::tuple; + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + struct Params { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void)workspace; + + auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + ; + auto init_M = get<0>(problem_shape_MNK); + auto init_N = get<1>(problem_shape_MNK); + auto init_K = get<2>(problem_shape_MNK); + + return Params{args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + template + static bool can_implement(ProblemShape problem_shapes, Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + + implementable &= + cutlass::detail::check_alignment(cute::make_shape(M, K, L), InternalStrideA{}); + implementable &= + cutlass::detail::check_alignment(cute::make_shape(N, K, L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class BlkCoord, + class LoadTensors> + CUTLASS_DEVICE void operator()( + FrgTensorD& accum, + TensorA gA, + TensorB gB, + FrgTensorC const& src_accum, + KTileIterator k_tile_iter, + int const& k_tile_count, + BlkCoord const& blk_coord, + int const& K_start, + int const& thread_idx, + Params const& mainloop, + LoadTensors const& load_tensors) { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + (void)thread_idx; + + Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; + Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; + + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + // TODO(Codeplay): see if we can make this nicer + // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each + // subgroup + auto sg = 
compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition global counting tensors for MMA + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_, _, _, 0).shape())); + Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_, _, _, 0).shape())); + + // Retile registers for copies + Tensor tArA = thr_copy_A.retile_D(tCrA); + Tensor tBrB = thr_copy_B.retile_D(tCrB); + + // Retile global counting tensors for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = cute::prefetch_selector, Int>, Num_SGs>(tiled_copy_a); + auto tiled_prefetch_b = cute::prefetch_selector, Int>, Num_SGs>(tiled_copy_b); + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); + print(gA); + print("\n"); + print("tCgA : "); + print(tCgA); + print("\n"); + print("tAgA : "); + print(tAgA); + print("\n"); + + print("===================== B :\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tCgB : "); + print(tCgB); + print("\n"); + print("tBgB : "); + print(tBgB); + print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); + print(MaxThreadsPerBlock); + print("\n"); + print(" SubgroupTileShape : "); + print(SubgroupTileShape{}); + print("\n"); + } +#endif + + // + // Mainloop + // + const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); + copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); + + if (prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + cute::gemm(tiled_mma, tCrA, tCrB, accum); + barrier_wait(barrier_scope); + } + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + Params const& mainloop_params, + int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl, + const int* num_rows_per_expert) { + int32_t cumulative_M = 0; + for (int i = 0; i < next_group; i++) { + cumulative_M += num_rows_per_expert[i]; + } + + const int32_t M = num_rows_per_expert[next_group]; + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = + reinterpret_cast((void*)(mainloop_params.ptr_A)) + cumulative_M * K; + ElementB const* ptr_B_curr_batch = + reinterpret_cast((void*)(mainloop_params.ptr_B)) + next_group * K * N; + + Tensor mA = make_tensor( + make_gmem_ptr(ptr_A_curr_batch), + make_shape(M, K, (int32_t)1), + 
cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + Tensor mB = make_tensor( + make_gmem_ptr(ptr_B_curr_batch), + make_shape(N, K, (int32_t)1), + cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + return cute::make_tuple(mA, mB); + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/sycl/kernels/moe/xe_moe_gemm.hpp b/src/sycl/kernels/moe/xe_moe_gemm.hpp new file mode 100644 index 0000000..6190d1d --- /dev/null +++ b/src/sycl/kernels/moe/xe_moe_gemm.hpp @@ -0,0 +1,1034 @@ +/*************************************************************************************************** + * Copyright 2025 Intel corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/kernel_launch.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/workspace.h" +#include "dispatch_policy.hpp" + +namespace cutlass::gemm::kernel::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Persistent Thread Block (TB) scheduler for MoE GEMM +template +class PersistentTileSchedulerXeMoE { + // + // Data members + // + + private: + uint64_t current_work_linear_idx_ = 0; + uint64_t total_grid_size_ = 0; + int32_t* num_rows_per_expert_ = nullptr; + int32_t K_ = 0; + int32_t N_ = 0; + int32_t num_experts_ = 0; + + // Tracking current group, its starting linear idx and total tiles + struct GroupInfo { + int group_idx = 0; + uint64_t start_linear_idx = 0; + uint64_t total_tiles = 0; + } current_group_info_; + + public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo invalid_work_tile() { + return {-1, -1, -1, false}; + } + + CUTLASS_HOST_DEVICE + bool is_final_split(uint32_t k_tiles_per_output_tile) const { + return true; + } + + CUTLASS_HOST_DEVICE + int32_t reduction_subtile_idx() const { + return -1; + } + }; + + using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + + struct Arguments { + int max_swizzle_size = 1; + // Not applying Heuristics for Grouped problems, since largest dimension can change per group + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + CUTLASS_HOST_DEVICE void configure(int32_t* num_rows_per_expert, int32_t N, int32_t K, int32_t num_experts) { + num_rows_per_expert_ = num_rows_per_expert; + N_ = N; + K_ = K; + num_experts_ = num_experts; + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. 
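  // Illustration only (not part of this change): the persistent-tile pattern this
  // scheduler implements. Only hw_info.sm_count work-groups are launched, and each
  // one strides through the flattened per-expert tile space until no valid tile is
  // left. The member names below mirror the real ones; the loop itself is a sketch
  // of what the consuming kernel does with this scheduler.
  //
  //   for (uint64_t idx = current_work_linear_idx_; ; idx += total_grid_size_) {
  //     WorkTileInfo tile = get_current_work_for_linear_idx(idx);
  //     if (!tile.is_valid()) break;
  //     // ... mainloop + epilogue for (tile.M_idx, tile.N_idx, expert tile.L_idx)
  //   }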
+ template + CUTLASS_HOST_DEVICE static dim3 + get_tiled_cta_shape_mnl(const KernelHardwareInfo& hw_info, ClusterShape cluster_shape) { + uint32_t total_ctas = 0; + uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here + + total_ctas = hw_info.sm_count; + + return Params::get_tiled_cta_shape_mnl(to_gemm_coord(cluster_shape), total_ctas, cta_in_N_dim); + } + + template + static Params to_underlying_arguments( + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo const& hw_info, + Arguments const& arguments, + [[maybe_unused]] void* workspace = nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) { + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(hw_info, cluster_shape); + + Params params; + params.initialize( + problem_blocks, + problem_shapes, + to_gemm_coord(tile_shape), + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order); + + return params; + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static dim3 get_grid_shape( + [[maybe_unused]] Params const& params, + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size = true) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(hw_info, cluster_shape); + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */ true); + } + + static bool can_implement(Arguments const& args) { + return true; + } + + PersistentTileSchedulerXeMoE() = default; + + CUTLASS_DEVICE explicit PersistentTileSchedulerXeMoE(Params const& params_) : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. 
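    // Illustration only (not part of this change): with raster order AlongN and a
    // 2-D launch grid where GridDimX() = 8, the work-group at (BlockIdxX() = 3,
    // BlockIdxY() = 1) starts at linear index 3 + 1 * 8 = 11; with AlongM it would
    // start at 3 * GridDimY() + 1 instead. The grid dimensions here are hypothetical.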
+#if defined(__CUDA_ARCH__) || defined __SYCL_DEVICE_ONLY__ + if (scheduler_params.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = uint64_t(BlockIdxX()) + uint64_t(BlockIdxY()) * uint64_t(GridDimX()); + } else { + current_work_linear_idx_ = uint64_t(BlockIdxX()) * uint64_t(GridDimY()) + uint64_t(BlockIdxY()); + } + + total_grid_size_ = uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()); + +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + CUTLASS_DEVICE + WorkTileInfo get_current_work() { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx) { + if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + return get_work_idx_m_and_n( + linear_idx, + current_group_info_, + scheduler_params.problem_shapes_, + scheduler_params.cta_shape_, + scheduler_params.cluster_shape_, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cta_shape_m_, + scheduler_params.divmod_cta_shape_n_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + } + + CUTLASS_DEVICE + void advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); + } + + // get work_idx_m, work_idx_n from linear_idx while applying swizzle + CUTLASS_DEVICE + WorkTileInfo get_work_idx_m_and_n( + uint64_t linear_idx, + struct GroupInfo& group_info, + GroupProblemShape& problem_shapes, + GemmCoord cta_shape, + GemmCoord cluster_shape, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cta_shape_m, + FastDivmodU64 const& divmod_cta_shape_n, + int32_t log_swizzle_size, + RasterOrder raster_order) { + bool valid_tile = true; + uint64_t ctas_along_m, ctas_along_n; + int total_problem_groups = num_experts_; + ctas_along_m = divmod_cta_shape_m.divide( + cute::shape<0>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_m.divisor - + 1); + ctas_along_n = divmod_cta_shape_n.divide( + cute::shape<1>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_n.divisor - + 1); + + auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + + while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) { + group_info.group_idx++; + + if (group_info.group_idx >= total_problem_groups) return WorkTileInfo::invalid_work_tile(); + + group_info.start_linear_idx += group_info.total_tiles; + ctas_along_m = divmod_cta_shape_m.divide( + cute::shape<0>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide( + cute::shape<1>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + + divmod_cta_shape_n.divisor - 1); + + problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + } + + uint64_t cluster_id, 
cluster_major_offset = 0, cluster_minor_offset = 0; + uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); + divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); + + // With static schedulers, we launch grid such that all cluster are linear (1-D) order, i.e., + // there can only be one cluster in the minor dimension. get_grid_shape() in scheduler params + // put cluster_shape.m/n() as the minor dimension based on raster order AlongN/M resp. + // Therefore, the offset of a CTA (inside a cluster) in the minor dimension can be directly be + // inferred by the blockIdx along the minor dimension. + if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = BlockIdxX(); + } else { + cluster_minor_offset = BlockIdxY(); + } + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + uint64_t curr_group_cluster_blk_major; + if (raster_order == RasterOrder::AlongN) { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n); + } else { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m); + } + cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; + cluster_idx_major = extra % curr_group_cluster_blk_major; + + cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = + static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + cluster_minor_offset); + auto major_work_idx = + static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; + } else { + return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; + } + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + // Performs the reduction across splits for a given output tile. Since this scheduler does + // not split output tiles, no reduction is needed. + template + CUTLASS_DEVICE static void fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed. 
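  // Illustration only (not part of this change): how a linear tile index lands in a
  // group. Assume a 256x256 CTA tile, N = 1280 and hypothetical per-expert row counts
  // num_rows_per_expert = {384, 128, 640}, with no swizzle and a 1x1x1 cluster:
  //   expert 0: ceil(384/256) * ceil(1280/256) = 2 * 5 = 10 tiles -> linear 0..9
  //   expert 1: 1 * 5 =  5 tiles                                  -> linear 10..14
  //   expert 2: 3 * 5 = 15 tiles                                  -> linear 15..29
  // so linear_idx 12 is tile 2 of expert 1. Every such work unit covers one full
  // output tile, which is why continue_current_work() below always returns false.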
+ CUTLASS_DEVICE + static bool continue_current_work(WorkTileInfo&) { + return false; + } + + // The basic tile scheduler does not require any additional workspace + template + static size_t get_workspace_size( + Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + Arguments const&, + void*, + cudaStream_t, + ProblemShape, + KernelHardwareInfo const&, + uint32_t, + const uint32_t = 1, + uint32_t = 1, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape_MNKL problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } + + CUTLASS_HOST_DEVICE + static uint32_t get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + CUTLASS_DEVICE + static bool need_separate_reduction(Params const& params) { + return false; + } + + CUTLASS_DEVICE + bool is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { + return false; + } + + CUTLASS_DEVICE + uint32_t epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { + return 0; + } + + template + CUTLASS_DEVICE void separate_reduction( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) {} + + // Shares the accumulator set with peers in the global workspace + template + CUTLASS_DEVICE static void share( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) {} + + CUTLASS_DEVICE + static bool valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool requires_separate_reduction(Params const& params) { + return false; + } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return cute::make_tuple(work_tile_info, true); + } + + advance_to_next_work(); + return cute::make_tuple(get_current_work(), true); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE WorkTileInfo initial_work_tile_info(ClusterShape) { + return get_current_work(); + } +}; + +} // namespace cutlass::gemm::kernel::detail + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmMoEUniversal { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert( + cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or + cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename 
CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_same_v, "Only Group Scheduler is supported with this code."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::PersistentTileSchedulerXeMoE; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + + using MainloopTensors = typename CollectiveMainloop::MainloopTensors; + using EpilogueTensors = typename CollectiveEpilogue::EpilogueTensors; + + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static_assert(cute::is_same_v>); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + const ElementA** A_ptr; + const ElementB** B_ptr; + const ElementC** C_ptr; + ElementD** D_ptr; + decltype(EpilogueArguments{}.thread) fusion_args; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + const int* M_per_group{nullptr}; + int num_experts; + int N; + int K; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + const int* M_per_group{nullptr}; + int num_experts; + int N; + int K; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
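  // Illustration only (not part of this change): why a placeholder shape is used
  // below. The real per-expert row counts live in a device buffer (args.M_per_group)
  // and are only resolved inside the kernel, so the host side builds a single dummy
  // group shape {256, N, K} just to satisfy collective and scheduler setup. For a
  // hypothetical fused-MoE layer with N = 1280, K = 3584 and 64 experts, the host
  // never sees the per-expert M values, only the device pointer to them.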
+ static Params to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + auto dummy_problem_shape = cute::Shape{256, args.N, args.K}; + auto dummy_group_problem_shape = ProblemShape{1, &dummy_problem_shape, nullptr}; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace_ptr); + + return { + args.mode, + dummy_group_problem_shape, + CollectiveMainloop::to_underlying_arguments( + dummy_group_problem_shape, MainloopArguments{args.A_ptr, nullptr, args.B_ptr, nullptr}, workspace_ptr), + CollectiveEpilogue::to_underlying_arguments( + dummy_group_problem_shape, + EpilogueArguments{args.fusion_args, args.C_ptr, nullptr, args.D_ptr, nullptr}, + workspace_ptr), + hw_info, + scheduler, + workspace, + args.M_per_group, + args.num_experts, + args.N, + args.K}; + } + + static bool can_implement(Arguments const& args) { + bool implementable = true; + + implementable = + implementable && + (args.mode == GemmUniversalMode::kGrouped || + (args.mode == GemmUniversalMode::kBatched && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3)); + + implementable = implementable && TileScheduler::can_implement(args.scheduler); + auto dummy_problem_shape = cute::Shape{256, args.N, args.K}; + auto dummy_group_problem_shape = ProblemShape{1, &dummy_problem_shape, nullptr}; + implementable &= CollectiveMainloop::can_implement( + dummy_group_problem_shape, MainloopArguments{args.A_ptr, nullptr, args.B_ptr, nullptr}); + implementable &= CollectiveEpilogue::can_implement( + dummy_group_problem_shape, EpilogueArguments{args.fusion_args, args.C_ptr, nullptr, args.D_ptr, nullptr}); + + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + workspace_size += + TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + return workspace_size; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + status = + TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape( + params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + static_assert( + cute::rank(InternalStrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert( + cute::rank(InternalStrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert( + cute::rank(InternalStrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert( + cute::rank(InternalStrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + TileScheduler scheduler{params.scheduler}; + const int32_t N = params.N; + const int32_t K = params.K; + scheduler.configure(const_cast(params.M_per_group), params.N, params.K, params.num_experts); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K) + + int thread_idx = int(ThreadIdxX()); + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + bool did_group_change = true; + int32_t curr_group = -1; + using ProblemShapeMNKL = Shape; + ProblemShapeMNKL problem_shape_MNKL; + MainloopTensors AB_tensors; + EpilogueTensors CD_tensors; + + if (work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); + } + + while (work_tile_info.is_valid()) { + auto M = get<0>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + Tensor mA_mkl = cute::get_xe_tensor(make_shape(M, K, L)); //(m,k,l) + Tensor mB_nkl = cute::get_xe_tensor(make_shape(N, K, L)); //(n,k,l) + + auto m_coord = work_tile_info.M_idx; + auto n_coord = work_tile_info.N_idx; + + auto gA_mkl = local_tile(mA_mkl, select<0, 2>(workgroup_shape), make_coord(m_coord, _, 0)); + auto gB_nkl = local_tile(mB_nkl, select<1, 2>(workgroup_shape), make_coord(n_coord, _, 0)); + + CollectiveMainloop collective_mma; + if (did_group_change) { + AB_tensors = collective_mma.update_tensor_shape_stride( + params.mainloop, curr_group, problem_shape_MNKL, params.M_per_group); + } + auto tile_coord = make_coord(m_coord, n_coord, _, 0); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
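        // Illustration only (not part of this change): this scheduler never splits K,
        // so every work unit covers the full K range. For a hypothetical K = 3584 and
        // BLK_K = 32 that is ceil_div(3584, 32) = 112 k-tiles, always starting at
        // k-tile 0 (get_work_k_tile_start() returns 0).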
+ int work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, workgroup_shape); + int work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, make_shape(K)), make_shape(K)); + + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0, 2>(workgroup_shape)); + + // Perform the collective scoped MMA + collective_mma( + accumulators, + gA_mkl, + gB_nkl, + accumulators, + k_tile_iter, + work_k_tile_count, + tile_coord, + K, + thread_idx, + params.mainloop, + AB_tensors); + + TileScheduler::fixup(params.scheduler, work_tile_info, accumulators, -1, -1); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + if (did_group_change) { + CD_tensors = epilogue.update_tensor_shape_stride(curr_group, problem_shape_MNKL, params.M_per_group); + did_group_change = false; + } + + epilogue(problem_shape_MNKL, subgroup_shape, tile_coord, accumulators, tiled_mma, thread_idx, CD_tensors); + } + + // Get next work tile + auto [next_work_tile_info, temp] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + + did_group_change = curr_group != work_tile_info.L_idx; + + if (did_group_change && work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +namespace cutlass::gemm::device { + +template +class GemmMoEUniversalAdapter { + public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + static ComplexTransform const kTransformA = + cute::is_same_v + ? ComplexTransform::kConjugate + : ComplexTransform::kNone; + static ComplexTransform const kTransformB = + cute::is_same_v + ? 
ComplexTransform::kConjugate + : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = + cutlass::gemm::GemmShape(TileShape{}), cute::size<1>(TileShape{}), cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, + ElementA, + typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, + ElementB, + typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail:: + get_alignment_count_from_gmem_tiled_copy(); + static int constexpr kAlignmentD = cutlass::detail:: + get_alignment_count_from_gmem_tiled_copy(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = + cute::max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + + private: + /// Kernel API parameters 
object + Params params_; + + public: + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute(device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, device_kernel, GemmKernel::MaxThreadsPerBlock, smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST( + "GemmUniversal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() + static Status + run(Params& params, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr, bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + const compat::dim3 sycl_block(block.x, block.y, block.z); + const compat::dim3 sycl_grid(grid.x, grid.y, grid.z); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{Status::kSuccess}; + cutlass::arch::synclog_setup(); + + CUTLASS_ASSERT(cuda_adapter == nullptr); + sycl::queue q = *stream; +#if defined(SYCL_INTEL_TARGET) + constexpr bool allow_subgroup_size_prop = true; +#else + constexpr bool allow_subgroup_size_prop = false; +#endif + auto kernel_props = [] { + constexpr bool is_device_agnostic = cute::is_same_v; + if constexpr (!allow_subgroup_size_prop or is_device_agnostic) { + using EmptyProperties = decltype(sycl::ext::oneapi::experimental::properties()); + return compat::experimental::kernel_properties{}; + } else { + return compat::experimental::kernel_properties{ + sycl::ext::oneapi::experimental::sub_group_size}; + } + }(); + compat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch, GemmKernel>(policy, q, params); + EventManager::getInstance().addEvent(event); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status run(sycl::queue* stream) { + return run(params_, stream); + } +}; + +} // namespace cutlass::gemm::device diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 16e8ded..e826b35 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -21,9 +21,6 @@ limitations under the License. #include "sgl_kernel_torch_shim.h" TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { - /* - * From csrc/gemm - */ m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor"); m.impl("awq_dequantize", torch::kXPU, &awq_dequantize); @@ -56,6 +53,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "bool is_neox) -> (Tensor, Tensor)"); m.impl("rotary_embedding", torch::kXPU, &at::native::xpu::rotary_embedding); + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool " + "pad_sorted_token_ids) -> ()"); + m.impl("moe_align_block_size", torch::kXPU, &moe_align_block_size); + + m.def("moe_sum(Tensor input, Tensor! 
output) -> ()"); + m.impl("moe_sum", torch::kXPU, &moe_sum); + + m.def( + "moe_grouped_mm_nt(Tensor output, Tensor activations, Tensor weights, Tensor total_rows_for_experts, int " + "n_experts) -> ()"); + m.impl("moe_grouped_mm_nt", torch::kXPU, &moe_grouped_mm_nt); + // m.def( // "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, // -> Tensor"); diff --git a/tests/run_suite.py b/tests/run_suite.py index 3102c47..1f802f0 100644 --- a/tests/run_suite.py +++ b/tests/run_suite.py @@ -17,6 +17,8 @@ class TestFile: TestFile("test_awq_dequant.py"), TestFile("test_topk_softmax.py"), TestFile("test_flash_attention.py"), + TestFile("test_moe_align.py"), + TestFile("test_moe_gemm.py"), ], } diff --git a/tests/test_moe_align.py b/tests/test_moe_align.py index 3baae0a..77cc515 100644 --- a/tests/test_moe_align.py +++ b/tests/test_moe_align.py @@ -4,7 +4,7 @@ import torch import triton import triton.language as tl -from sgl_kernel import moe_align_block_size +from sgl_kernel import moe_align_block_size, moe_sum def ceil_div(a, b): @@ -138,88 +138,127 @@ def moe_align_block_size_triton( @pytest.mark.parametrize( - "block_size,num_tokens,topk,num_experts", + "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids", list( itertools.product( [32, 64, 128, 256], # block_size [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens [1, 2, 4, 8, 16, 32, 64], # topk [64, 160, 256, 257, 260, 264], # num_experts + [True, False], # pad_sorted_token_ids ) ), ) def test_moe_align_block_size_compare_implementations( - block_size, num_tokens, topk, num_experts + block_size, num_tokens, topk, num_experts, pad_sorted_token_ids ): - topk_ids = torch.stack( - [ - torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] - for _ in range(num_tokens) - ] - ) + topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="xpu"), dim=1)[ + :, :topk + ] - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - sorted_ids_cuda = torch.empty( + sorted_ids_xpu = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) - sorted_ids_cuda.fill_(topk_ids.numel()) + if not pad_sorted_token_ids: + sorted_ids_xpu.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids_cuda = torch.zeros( + expert_ids_xpu = torch.zeros( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) - num_tokens_post_pad_cuda = torch.empty( + num_tokens_post_pad_xpu = torch.empty( (1), dtype=torch.int32, device=topk_ids.device ) - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, - dtype=torch.int32, - device=topk_ids.device, - ) cumsum_buffer = torch.empty( - num_experts + 1, dtype=torch.int32, device=topk_ids.device + num_experts + 2, dtype=torch.int32, device=topk_ids.device ) - sorted_ids_triton = torch.empty_like(sorted_ids_cuda) + sorted_ids_triton = torch.empty_like(sorted_ids_xpu) sorted_ids_triton.fill_(topk_ids.numel()) - expert_ids_triton = torch.zeros_like(expert_ids_cuda) - num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) + expert_ids_triton = torch.zeros_like(expert_ids_xpu) + num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu) moe_align_block_size( topk_ids, - num_experts, + num_experts + 1, block_size, - sorted_ids_cuda, - expert_ids_cuda, - num_tokens_post_pad_cuda, - token_cnts_buffer, + sorted_ids_xpu, + 
expert_ids_xpu, + num_tokens_post_pad_xpu, cumsum_buffer, + pad_sorted_token_ids, ) moe_align_block_size_triton( topk_ids, - num_experts, + num_experts + 1, block_size, sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton, ) - assert torch.allclose(expert_ids_cuda, expert_ids_triton), ( + assert torch.allclose(expert_ids_xpu, expert_ids_triton, atol=0, rtol=0), ( f"Expert IDs mismatch for block_size={block_size}, " f"num_tokens={num_tokens}, topk={topk}\n" - f"CUDA expert_ids: {expert_ids_cuda}\n" + f"xpu expert_ids: {expert_ids_xpu}\n" f"Triton expert_ids: {expert_ids_triton}" ) - assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( + assert torch.allclose( + num_tokens_post_pad_xpu, num_tokens_post_pad_triton, atol=0, rtol=0 + ), ( f"Num tokens post pad mismatch for block_size={block_size}, " f"num_tokens={num_tokens}, topk={topk}\n" - f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" + f"xpu num_tokens_post_pad: {num_tokens_post_pad_xpu}\n" f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}" ) + # Select an expert to check + expert_idx = expert_ids_xpu.max().item() + + # Get the first and last block id where expert_ids_xpu == expert_idx + matching_indices = torch.where(expert_ids_xpu == expert_idx)[0] + block_sorted_start = matching_indices[0].item() * block_size + block_sorted_end = min( + (matching_indices[-1].item() + 1) * block_size, num_tokens_post_pad_xpu.item() + ) + + selected_sorted_ids_xpu = sorted_ids_xpu[ + block_sorted_start:block_sorted_end + ].sort()[0] + selected_sorted_ids_triton = sorted_ids_triton[ + block_sorted_start:block_sorted_end + ].sort()[0] + + assert torch.allclose( + selected_sorted_ids_xpu, + selected_sorted_ids_triton, + atol=0, + rtol=0, + ), ( + f"Sorted IDs mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"xpu sorted_ids: {selected_sorted_ids_xpu}\n" + f"Triton sorted_ids: {selected_sorted_ids_triton}" + ) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): + input = torch.randn((m, topk, k), device="xpu", dtype=dtype) + actual = torch.empty((m, k), device="xpu", dtype=dtype) + + expected = input.sum(dim=1) + moe_sum(input, actual) + + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_moe_gemm.py b/tests/test_moe_gemm.py new file mode 100644 index 0000000..d02c830 --- /dev/null +++ b/tests/test_moe_gemm.py @@ -0,0 +1,100 @@ +import itertools + +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import fused_experts + + +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def create_random_xpu_tensor(shape, dtype, mean=0, std=0.01): + """Create a random xpu tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized xpu tensor + """ + return torch.empty(shape, dtype=dtype, device="xpu").normal_(mean, std) + + +def torch_naive_moe( + a, + w1, + w2, + topk_ids, + topk_weight, + topk, +): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + topk_weight = 
topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + tmp = silu_and_mul(a[mask] @ w1[i].transpose(0, 1)) + # import pdb; pdb.set_trace() + out[mask] = tmp @ w2[i].transpose(0, 1) + + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +@pytest.mark.parametrize( + "num_tokens,topk,num_experts,hidden_size,intermediate_size", + list( + itertools.product( + [1, 4, 33, 64, 222], # num_tokens + [1, 2, 6], # topk + [8, 64], # num_experts + [128, 1024], # hidden_size + [128, 512, 1024], # intermediate_size + ) + ), +) +def test_moe_gemm(num_tokens, topk, num_experts, hidden_size, intermediate_size): + rtol, atol = 1e-1, 1e-2 + a = create_random_xpu_tensor((num_tokens, hidden_size), torch.bfloat16) + w1 = create_random_xpu_tensor( + (num_experts, 2 * intermediate_size, hidden_size), torch.bfloat16 + ) + w2 = create_random_xpu_tensor( + (num_experts, hidden_size, intermediate_size), torch.bfloat16 + ) + score = torch.randn([num_tokens, num_experts], dtype=torch.bfloat16).to("xpu") + + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + torch_output = torch_naive_moe( + a, + w1, + w2, + topk_ids, + topk_weight, + topk, + ) + sglang_output = fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + ) + # import pdb; pdb.set_trace() + torch.testing.assert_close(torch_output, sglang_output, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + pytest.main([__file__])
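
A minimal usage sketch of the newly exposed MoE ops, assuming only the call signatures exercised by the new tests (tests/test_moe_gemm.py and tests/test_moe_align.py); shapes, dtypes, and tensor contents below are arbitrary placeholders, not values from the patch.

# Illustrative sketch: signatures mirror the new tests; everything else is an assumption.
import torch
from sgl_kernel import fused_experts, moe_sum

num_tokens, hidden_size, intermediate_size = 4, 128, 128
num_experts, topk = 8, 2

a = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="xpu")
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16, device="xpu")
w2 = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.bfloat16, device="xpu")

# Routing: softmax over expert logits, then top-k, as done in test_moe_gemm.py.
score = torch.softmax(torch.randn(num_tokens, num_experts, device="xpu"), dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)

# fused_experts covers the per-expert gate/up GEMM, SiLU-and-mul, down GEMM, and the
# top-k weighted reduction; torch_naive_moe in test_moe_gemm.py is the reference path.
out = fused_experts(a, w1, w2, topk_weight, topk_ids)

# moe_sum reduces a (num_tokens, topk, hidden_size) tensor over the topk dimension
# into a preallocated output, matching test_moe_sum in test_moe_align.py.
x = torch.randn(num_tokens, topk, hidden_size, dtype=torch.bfloat16, device="xpu")
summed = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device="xpu")
moe_sum(x, summed)

The comparison against torch_naive_moe in the new test is what pins down these semantics: the reference applies silu_and_mul to a @ w1[i].T per expert, projects through w2[i].T, and sums the top-k contributions weighted by topk_weight.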