```bash
pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
```
```bash
git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..
```
```bash
# GroupedGEMM ops test
python grouped_gemm/ops_test.py
# topK permute & unpermute ops test
python grouped_gemm/permute_test.py
# Sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py
```
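If the build succeeded, a quick smoke test can confirm that the CUDA extension loads and runs. This is a minimal sketch using the `permute`/`unpermute` ops documented below; the tensor shapes are arbitrary.

```py
import torch
from grouped_gemm import permute, unpermute

assert torch.cuda.is_available(), "grouped_gemm ops require a CUDA device"

# Four tokens of width 8, each routed to experts 0 and 1 (topK = 2).
x = torch.randn(4, 8, dtype=torch.float32, device='cuda')
idx = torch.tensor([[0, 1], [1, 0], [0, 1], [1, 0]], dtype=torch.int32, device='cuda')

y, row_id_map = permute(x, idx)
out = unpermute(y, row_id_map, torch.ones_like(idx, dtype=torch.float32))
print("round trip OK:", out.shape)  # (4, 8)
```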
| GPU Arch | FP32 | FP16 | BF16 | FP8 |
|----------|------|------|------|-----|
| SM 70    | Y    | Y    | .    | Y   |
| SM 75    | Y    | Y    | .    | Y   |
| SM 80    | Y    | Y    | Y    | Y   |
| SM 86    | Y    | Y    | Y    | Y   |
| SM 89    | Y    | Y    | Y    | Y   |
| SM 90    | Y    | Y    | Y    | Y   |
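As a practical reading of the matrix, the activation dtype can be gated on the detected compute capability. This sketch is illustrative and not part of the library's API; it only encodes the BF16 row of the table above.

```py
import torch

# BF16 requires SM 80+ per the support matrix; FP16 and FP32 are
# supported on all listed architectures.
major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor
dtype = torch.bfloat16 if sm >= 80 else torch.float16
print(f"SM {sm}: using {dtype} activations")
```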
```py
grouped_gemm.ops.permute(
    input_act: torch.Tensor,
    indices: torch.Tensor,
    max_token_num: int = 0,
) -> tuple
```

Returns a tuple of `(torch.Tensor, torch.Tensor)` containing the two tensors `permuted_act` and `row_id_map`. `permuted_act` is the permutation of the original tensor `input_act`, with its first dimension permuted according to `indices`. `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is used by the following `unpermute` op.
- **input_act** (`torch.Tensor`): shape = `[tokens_num, hidden_size]`. The input activations, where each row (token) is routed to topK experts.
- **indices** (`torch.Tensor`): shape = `[tokens_num, topK_num]`. The topK expert indices for each row (token) of activations. The `int32` type is recommended.
- **max_token_num** (`int`): the maximum number of tokens (rows), used for workspace pre-allocation; see the sketch after this list.
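When the number of tokens varies across steps, passing the largest expected count as `max_token_num` lets the op size its workspace once up front rather than re-allocating as batches grow. A minimal sketch, with illustrative sizes:

```py
import torch
from grouped_gemm import permute

MAX_TOKENS = 4096  # illustrative upper bound on tokens per step

# 1024 tokens, hidden size 4096, routed to the top 2 of 8 experts.
logits = torch.randn(1024, 8, device='cuda')
indices = logits.topk(2, dim=-1).indices.to(torch.int32)
input_act = torch.randn(1024, 4096, dtype=torch.float16, device='cuda')

# The first call pre-allocates the workspace for up to MAX_TOKENS rows;
# later calls with fewer tokens can reuse it.
permuted_act, row_id_map = permute(input_act, indices, MAX_TOKENS)
```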
```py
grouped_gemm.ops.unpermute(
    input_act: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor,
) -> torch.Tensor
```

The mirror operator of `grouped_gemm.ops.permute`.
- **input_act** (`torch.Tensor`): shape = `[tokens_num * topK_num, hidden_size]`. The permuted activations produced by `grouped_gemm.ops.permute`.
- **row_id_map** (`torch.Tensor`): shape = `[tokens_num * topK_num]`. The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`; this is the second output tensor of `grouped_gemm.ops.permute`.
- **probs** (`torch.Tensor`): shape = `[tokens_num, topK_num]`. The weights used to sum up the topK copies of each original token coming back from different experts; see the weighted-reduction sketch after this list.
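Since `unpermute` scales each expert copy by its entry in `probs` before summing, passing router probabilities normalized over the topK dimension yields the weighted combination of the copies. A minimal sketch with illustrative values:

```py
import torch
from grouped_gemm import permute, unpermute

x = torch.tensor([[1., 1.], [2., 2.]], device='cuda')
idx = torch.tensor([[0, 1], [1, 0]], dtype=torch.int32, device='cuda')

# Router weights summing to 1 over topK: each output row is
# 0.75 * (first expert copy) + 0.25 * (second expert copy).
probs = torch.tensor([[0.75, 0.25], [0.75, 0.25]], device='cuda')

y, row_id_map = permute(x, idx)
out = unpermute(y, row_id_map, probs)
print(out)  # equals x, since both copies of a token are identical
```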
```py
import torch
from grouped_gemm import permute, unpermute

# tokens_num = 4, topK = 2: each token is routed to two experts.
indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.float32, device='cuda')
# All-ones probs: unpermute simply sums the two copies of each token.
probs = torch.ones_like(indices, dtype=torch.float32)

permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
```
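Because `topK = 2` and `probs` is all ones here, `unpermute` adds the two identical copies of each token back together, so every output row is exactly twice the corresponding input row; with router probabilities normalized over topK, the output would instead be the weighted combination of the expert copies.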