In my tests on 8 H100 GPUs I did not see any speedup. However, when I ran the script kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py, I did see a 2-3x speedup. I then compared the column-major MoE against the stock vLLM MoE on 2 H100 GPUs and measured roughly a 25% speedup, but the gain disappeared on 8 H100 GPUs. Is this result expected?
To reproduce:
# Docker
docker run -it --gpus all --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
# Install the vllm library
pip install vllm==v0.4.0.post1
# Clone the vllm repo for the benchmark scripts
git clone https://github.com/vllm-project/vllm.git
cd vllm/benchmarks/; git checkout tags/v0.4.0.post1
# Download the Mixtral weights (run the following lines in Python)
from huggingface_hub import snapshot_download, login
hf_token = "You should use your own token!"
login(token=hf_token)
snapshot_download(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Run benchmark with vllm MoE
# If you see an error about quantization_param_path, comment it out and run again.
python benchmark_throughput.py \
--model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len=256 --output-len=64 \
--tensor-parallel-size=8 --num-prompts 400 --worker-use-ray
# Run benchmark with column-major MoE
# Open the installed fused MoE script,
vim /usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/fused_moe/fused_moe.py
# and replace its contents with the code from Column-Major_fused_moe.py below
python benchmark_throughput.py \
--model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len=256 --output-len=64 \
--tensor-parallel-size=8 --num-prompts 400 --worker-use-ray
I got roughly the same throughput from the two benchmark runs above.
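To separate the kernel itself from end-to-end serving overhead (scheduling, sampling, the tensor-parallel all-reduce), a standalone harness along the lines of the sketch below can also be run. This is my own snippet, not part of vLLM: the shapes are assumptions for one TP-8 shard of Mixtral-8x7B (hidden size 4096, 14336/8 intermediate per GPU, 8 experts, top-k 2), and it imports fused_moe straight from the file edited above, so swapping that file's contents between runs compares the two kernels:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, topk = 8, 2                     # experts, experts routed per token
K = 4096                           # Mixtral hidden size
N = 2 * (14336 // 8)               # gate+up projection width per TP-8 shard (assumed)
M = 256                            # tokens per forward pass

hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, N, K, dtype=torch.float16, device="cuda")
w2 = torch.randn(E, K, N // 2, dtype=torch.float16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

# Warm-up also triggers Triton compilation.
for _ in range(10):
    fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
end.record()
torch.cuda.synchronize()
print(f"fused_moe: {start.elapsed_time(end) / 100:.3f} ms per call")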
Here is the code for Column-Major_fused_moe.py:
"""Column-major Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip
import time
MEASURE_TIME = False
logger = init_logger(__name__)
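# col_major maps the linear program id to a (pid_m, pid_n) output tile.
# Consecutive program ids advance pid_m first, so the (EM, N) output tiles
# are walked in column-major order whenever grid_m == grid_n (note that the
# modulo uses grid_n while the division uses grid_m).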
@triton.jit()
def col_major(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)
pid_m = (pid % grid_n)
pid_n = pid // grid_m
return pid_m, pid_n
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_weight,
stride_token_id,
# Meta-parameters
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
block_m, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
pid = tl.program_id(axis=0)
pid_m, pid_n = col_major(pid,
EM, N,
block_m, block_n,)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * block_m >= num_tokens_post_padded:
return
offs_token_id = pid_m * block_m + tl.arange(0, block_m)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N
offs_k = tl.arange(0, block_k)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[block_m, block_n]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
for k in range(0, tl.cdiv(K, block_k)):
# Load the next block of A and B, generate a mask by checking the K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * block_k),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * block_k,
other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += block_k * stride_ak
b_ptrs += block_k * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * block_n + tl.arange(0, block_n)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
EM = sorted_token_ids.shape[0]
N = B.shape[1]
grid = lambda META: (triton.cdiv(EM, META['block_m']) *
triton.cdiv(N, META['block_n']), )
fused_moe_kernel[grid](
A,
B,
C,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
topk_weights.stride(1), # New argument
sorted_token_ids.stride(0), # New argument
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
**config,
)
def get_config_file_name(E: int, N: int) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
return f"E={E},N={N},device_name={device_name}.json"
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
return None
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor, # Not in Pytorch-Labs col_major
topk: int, # Pytorch-Labs pass topk_weights and topk_ids
renormalize: bool, # Not in Pytorch-Labs col_major
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None, # Not in Pytorch-Labs col_major
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
# print("hidden_states shape or BatchSize =",hidden_states.shape)
# print("w1 shape =",w1.shape)
# print("w2 shape =",w2.shape)
    if MEASURE_TIME:
        torch.cuda.synchronize()  # do not count prior asynchronous GPU work
        start_time = time.perf_counter()
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
config = {
'block_m': 128,
'block_n': 128,
'block_k': 64,
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['block_m'], E)
# print(f"First MoE: a.shape = {hidden_states.shape}, b.shape = {w1.shape}, c.shape = {intermediate_cache1.shape},\n"
# f"Second MoE:a.shape = {intermediate_cache2.shape}, b.shape = {w2.shape}, c.shape = {intermediate_cache3.shape}")
invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
if inplace:
result = torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
else:
        result = torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                           dim=1)
    if MEASURE_TIME:
        torch.cuda.synchronize()  # wait for the asynchronous kernels to finish
        end_time = time.perf_counter()
        elapsed_time = (end_time - start_time) * 1_000  # convert to milliseconds
        print(f"Batch size = {M}, elapsed_time = {elapsed_time} ms")
return result
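One more thing worth verifying before comparing throughput numbers is that the patched file is really the one the running workers import. A quick check, assuming the module path from the vim command above:

import vllm.model_executor.layers.fused_moe.fused_moe as fm
print(fm.__file__)                # should point at the dist-packages file edited above
print(hasattr(fm, "col_major"))   # True only when the column-major version is in place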