
Conversation

Collaborator

@airMeng airMeng commented Oct 17, 2025

  • update MoE scatter and gather to the latest main
  • cutlass-based MoE GEMM

@airMeng airMeng marked this pull request as draft October 22, 2025 09:01
@airMeng airMeng force-pushed the moe branch 4 times, most recently from 6589915 to 0004dab on October 28, 2025 08:47
@airMeng airMeng marked this pull request as ready for review November 14, 2025 08:13
@airMeng airMeng requested a review from mingfeima November 14, 2025 08:14
Collaborator

@mingfeima mingfeima left a comment


Let's first separate the PR into:

  • implement MoE with grouped GEMM using cutlass
  • enable operator unit benchmarks for CI

Comment on lines +341 to +352
flat_topk = topk_ids.flatten()
idxs = flat_topk.argsort()
sorted_expert_ids = flat_topk[idxs]

counts = torch.bincount(sorted_expert_ids, minlength=E) # [E]
token_idxs = idxs // TopK # [num_tokens * TopK]
input_A = torch.empty(
    (num_tokens * TopK, K), device=hidden_states.device, dtype=hidden_states.dtype
)
input_A = hidden_states[token_idxs].squeeze(1)
offset = counts.to(torch.int32)

Collaborator


Is this the best we can do with the current cutlass APIs?

If so, do we have a JIRA ticket tracking what we actually need?

This implementation will definitely hurt performance.


@sanchitintel sanchitintel Nov 18, 2025


@airMeng, do you know how FlashInfer uses a cutlass-based FusedMoE for Nvidia GPUs? I think folks on the vLLM team know.

Collaborator Author


FlashInfer uses a cutlass-based FusedMoE without activation reshuffle.
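
To make the trade-off concrete for readers of the thread, here is a rough Python sketch of the two dispatch styles being compared. It only contrasts how the inputs to the grouped GEMM are prepared; it is not taken from this PR or from FlashInfer, and the reshuffle-free variant assumes a grouped GEMM that can consume row indices directly.

import torch

def dispatch_with_reshuffle(hidden_states, topk_ids, num_experts):
    # Sort (token, expert) pairs and materialize a gathered copy of the
    # activations so each expert's rows are contiguous before the GEMM.
    topk = topk_ids.shape[1]
    flat = topk_ids.flatten()
    order = flat.argsort()
    counts = torch.bincount(flat[order], minlength=num_experts)  # rows per expert
    gathered = hidden_states[order // topk]                       # extra copy of activations
    return gathered, counts.to(torch.int32)

def dispatch_without_reshuffle(hidden_states, topk_ids, num_experts):
    # Reshuffle-free style: keep activations in place and pass only the sorted
    # row indices; the grouped GEMM gathers rows while loading its tiles.
    topk = topk_ids.shape[1]
    flat = topk_ids.flatten()
    order = flat.argsort()
    counts = torch.bincount(flat[order], minlength=num_experts)
    row_indices = (order // topk).to(torch.int32)                 # no activation copy
    return row_indices, counts.to(torch.int32)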

Collaborator

@mingfeima mingfeima left a comment


Need some minor changes. Generally LGTM!

Comment on lines +46 to +80
    # # DeepSeek-V3-0324, tp = 1
    # {
    #     "num_experts": 257,
    #     "topk": 8,
    #     "hidden_size": 7168,
    #     "shard_intermediate_size": 4096,
    #     "dtype": torch.bfloat16,
    #     "block_shape": [128, 128],
    # },
    # # DeepSeek-V3-0324, tp = 2
    # {
    #     "num_experts": 257,
    #     "topk": 8,
    #     "hidden_size": 7168,
    #     "shard_intermediate_size": 2048,
    #     "dtype": torch.bfloat16,
    #     "block_shape": [128, 128],
    # },
    # # DeepSeek-V3-0324, tp = 4
    # {
    #     "num_experts": 257,
    #     "topk": 8,
    #     "hidden_size": 7168,
    #     "shard_intermediate_size": 1024,
    #     "dtype": torch.bfloat16,
    #     "block_shape": [128, 128],
    # },
    # # DeepSeek-V3-0324, tp = 8
    # {
    #     "num_experts": 257,
    #     "topk": 8,
    #     "hidden_size": 7168,
    #     "shard_intermediate_size": 512,
    #     "dtype": torch.bfloat16,
    #     "block_shape": [128, 128],
Collaborator


Can we use TP as an argument and then, for each configuration, compute shard_intermediate_size = shard_intermediate_size // tp?
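
A minimal sketch of that suggestion, assuming the benchmark keeps one base entry per model and derives the per-rank shard from a tp argument; BASE_CONFIGS and make_config are illustrative names, not part of the benchmark script.

import torch

BASE_CONFIGS = [
    {
        "model": "DeepSeek-V3-0324",
        "num_experts": 257,
        "topk": 8,
        "hidden_size": 7168,
        "intermediate_size": 4096,  # tp = 1 shard size from the commented-out configs
        "dtype": torch.bfloat16,
        "block_shape": [128, 128],
    },
]

def make_config(base: dict, tp: int) -> dict:
    # Derive the per-rank shard size instead of hard-coding one entry per TP degree.
    cfg = {k: v for k, v in base.items() if k != "intermediate_size"}
    cfg["shard_intermediate_size"] = base["intermediate_size"] // tp
    return cfg

configs = [make_config(base, tp) for base in BASE_CONFIGS for tp in (1, 2, 4, 8)]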

return n + 1;
}

#define CEILDIV(x, y) ((x + y - 1) / y)
Collaborator


A ceiling-division helper is already defined in include/utils.h.

Collaborator Author


Removed; using div_up instead.

Comment on lines +337 to +365
  if (small_batch_expert_mode) {
    const int32_t threads_local = std::max((int32_t)num_experts, sub_group_size);
    auto range = sycl::nd_range<1>(sycl::range<1>(threads_local), sycl::range<1>(threads_local));
    using SmallKernel = MOEAlignBlockSizeSmallBatchExpertFunctor<scalar_t>;
    SmallKernel kernel(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        pad_sorted_token_ids);
    sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel);
  } else {
    const size_t scan_size = next_pow2(num_experts);
    const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + sub_group_size) * sizeof(int32_t);
    using Kernel = MOEAlignBlockSizeFunctor<scalar_t>;
    Kernel kernel(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>(),
        pad_sorted_token_ids,
        scan_size);
Collaborator


thumbs up!
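
For context, a rough, unoptimized Python reference of what this kernel is expected to produce, inferred only from the argument names in the snippet above; the padding sentinel and exact output layout are assumptions, not taken from this PR.

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Sort token slots by expert, then pad each expert's slot count up to a
    # multiple of block_size so the grouped GEMM sees whole blocks per expert.
    flat = topk_ids.flatten()
    order = torch.sort(flat, stable=True).indices
    counts = torch.bincount(flat, minlength=num_experts)
    padded = (counts + block_size - 1) // block_size * block_size
    total = int(padded.sum())

    pad_value = flat.numel()  # assumed sentinel for padded slots
    sorted_token_ids = torch.full((total,), pad_value, dtype=torch.int32)
    expert_ids = torch.empty(total // block_size, dtype=torch.int32)

    out, cur = 0, 0
    for e in range(num_experts):
        n = int(counts[e])
        sorted_token_ids[out : out + n] = order[cur : cur + n].to(torch.int32)
        expert_ids[out // block_size : (out + int(padded[e])) // block_size] = e
        out += int(padded[e])
        cur += n
    num_tokens_post_pad = total
    return sorted_token_ids, expert_ids, num_tokens_post_pad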

Comment on lines +47 to +76
  switch (topk) {
    case 2: {
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] {
        using Kernel = MoeSumKernel<scalar_t, 2>;
        Kernel kernel(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
        sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel);
      });
      break;
    }
    case 3: {
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] {
        using Kernel = MoeSumKernel<scalar_t, 3>;
        Kernel kernel(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
        sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel);
      });
      break;
    }
    case 4: {
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum", [&] {
        using Kernel = MoeSumKernel<scalar_t, 4>;
        Kernel kernel(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
        sycl_kernel_submit(range.get_global_range(), range.get_local_range(), queue, kernel);
      });
      break;
    }
    default:
      at::sum_out(output, input, 1);
      break;
  }
}
Collaborator


Why do we need special treatment for topk = 2, 3, and 4?

Collaborator Author


@mingfeima
Collaborator

Do we have performance numbers? I don't know if we can compare cutlass kernels with xetla kernels (with the so-called persistent weight). If the comparison makes sense, please collect the data. Otherwise, you can just calculate the memory bandwidth.
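
A quick sketch of the bandwidth route, assuming moe_sum is memory bound and using the default branch above (sum over the topk dimension) as the reference semantics; for the real kernel, move the tensors to the XPU device and synchronize around the timed loop.

import time
import torch

def moe_sum_reference(x: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, topk, hidden_size]; same semantics as the default branch above.
    return x.sum(dim=1)

def effective_bandwidth_gbs(fn, x: torch.Tensor, iters: int = 50) -> float:
    # Traffic model: read the full input once, write the reduced output once.
    num_tokens, topk, hidden = x.shape
    bytes_moved = (num_tokens * topk + num_tokens) * hidden * x.element_size()
    fn(x)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    elapsed = (time.perf_counter() - start) / iters
    return bytes_moved / elapsed / 1e9

x = torch.randn(4096, 8, 7168, dtype=torch.bfloat16)
print(f"moe_sum reference: {effective_bandwidth_gbs(moe_sum_reference, x):.1f} GB/s")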

Comment on lines 14 to 73
    if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
    runs-on: sglang-pvc
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Build Docker image
        run: |
          docker build \
            --build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \
            --build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \
            --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:kernel .
      - name: Run container
        run: |
          docker run -dt \
            --device /dev/dri/ \
            --name ci_sglang_xpu \
            -e HF_TOKEN=$(cat ~/huggingface_token.txt) \
            xpu_sglang:kernel
      - name: Install Dependency
        timeout-minutes: 20
        run: |
          docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install --upgrade pip
          docker exec ci_sglang_xpu /miniforge3/envs/py3.10/bin/python3 -m pip install pytest expecttest ray huggingface_hub
          docker exec ci_sglang_xpu /bin/bash -c '/miniforge3/envs/py3.10/bin/huggingface-cli login --token ${HF_TOKEN}'
          docker exec ci_sglang_xpu /bin/bash -c "ln -sf /miniforge3/envs/py3.10/bin/python3 /usr/bin/python3"
      - name: Run Sglang Kernel Cases
        timeout-minutes: 20
        run: |
          docker exec -w /root/sglang ci_sglang_xpu \
            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 run_suite.py --suite per-commit"
      - name: Run Sglang Kernel Benchmarks
        timeout-minutes: 20
        run: |
          docker exec -w /root/sglang ci_sglang_xpu \
            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py && python3 benchmark_fused_moe.py"
      - name: Run E2E Bfloat16 tests
        timeout-minutes: 20
        run: |
          echo "[PlaceHolder for E2E Test...]"
      - name: Run E2E Quantization tests
        timeout-minutes: 20
        run: |
          echo "[PlaceHolder for E2E Test...]"
      - name: Cleanup container
        if: always()
        run: |
          docker rm -f ci_sglang_xpu || true

Check warning

Code scanning / CodeQL

Workflow does not contain permissions (Medium)

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {contents: read}

Copilot Autofix


The best way to fix this problem is to add the explicit permissions block with the minimum set of permissions required for the workflow to function. Since this workflow does not push changes, create releases, or otherwise write to the repository contents, it only needs contents: read permission. For maximum clarity and correctness, insert the following block after the top-level name: key (before or after on: is fine, but prefer after) to apply to all jobs in the workflow:

permissions:
  contents: read

This ensures that GITHUB_TOKEN will only have read access to repository contents, thus following the principle of least privilege and fixing the CodeQL error. No new imports, methods, or definitions are required—just a YAML syntax addition.

Suggested changeset 1
.github/workflows/pr-test-xpu.yml

Autofix patch

Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml
--- a/.github/workflows/pr-test-xpu.yml
+++ b/.github/workflows/pr-test-xpu.yml
@@ -1,5 +1,8 @@
 name: PR Test (XPU)
 
+permissions:
+  contents: read
+
 on:
   pull_request:
     branches: [main]
EOF

@kareemshaik80 kareemshaik80 left a comment


LGTM

Collaborator Author

airMeng commented Nov 20, 2025

@mingfeima @sanchitintel If no more comments I will merge the PR

@mingfeima mingfeima merged commit f6d2976 into main Nov 20, 2025
6 checks passed
@airMeng airMeng deleted the moe branch November 20, 2025 02:12
@mingfeima
Collaborator

@mingfeima @sanchitintel If no more comments I will merge the PR

NP, we don't have to do everything in just one PR.

