
Conversation


@z52527 z52527 commented Jan 6, 2026

Problem

Apply num_groups different linear transformations to the corresponding slices of the input:

Input:  x of shape (B * num_groups, input_dim)
Output: y of shape (B * num_groups, output_dim)

After viewing x and y as (B, num_groups, input_dim) and (B, num_groups, output_dim), each group n is transformed as y[b, n, :] = x[b, n, :] @ W[n, :, :], where W stacks one (input_dim, output_dim) matrix per group.

Reference Implementation

The straightforward approach uses a loop over groups:

x = x.reshape(B, num_groups, D_in)        # view flat (B * num_groups, D_in) input as 3D
x_split = torch.split(x, 1, dim=1)        # num_groups tensors of shape (B, 1, D_in)

out_list = []
for i in range(num_groups):
    x_i = x_split[i].squeeze(1)           # (B, D_in)
    out_i = linear_layers[i](x_i)         # (B, D_out), one linear layer per group
    out_list.append(out_i)

output = torch.stack(out_list, dim=1).reshape(-1, D_out)   # (B * num_groups, D_out)

Optimized Implementation

Use torch.bmm with strided output to fuse all GEMMs into one kernel:

x = x.reshape(B, num_groups, D_in)
output = torch.empty(B, num_groups, D_out,
                     dtype=x.dtype, device=x.device)   # pre-allocate final layout
torch.bmm(x.permute(1, 0, 2), weight,                  # weight: (num_groups, D_in, D_out)
          out=output.permute(1, 0, 2))                 # cuBLAS writes to strided memory
return output.view(-1, D_out)                          # O(1) view, no copy

Key feature: cuBLAS strided batched GEMM supports a strided output via its ldc/strideC parameters, so the kernel can write directly into the transposed memory layout with no extra transpose or copy.
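
A minimal, self-contained sketch that demonstrates the trick and checks it against a per-group matmul (shapes and variable names here are illustrative, not taken from the PR's module):

import torch

B, num_groups, D_in, D_out = 4, 3, 8, 16
x = torch.randn(B * num_groups, D_in)
weight = torch.randn(num_groups, D_in, D_out)   # one (D_in, D_out) matrix per group

# Optimized path: one batched GEMM writing into a permuted (strided) view.
x3 = x.reshape(B, num_groups, D_in)
output = torch.empty(B, num_groups, D_out, dtype=x.dtype)
torch.bmm(x3.permute(1, 0, 2), weight, out=output.permute(1, 0, 2))

# Reference path: per-group matmul.
expected = torch.stack([x3[:, n, :] @ weight[n] for n in range(num_groups)], dim=1)
assert torch.allclose(output, expected, atol=1e-5)

The permuted out view has shape (num_groups, B, D_out) with a non-unit batch stride, which is exactly the strided-C case the batched GEMM can handle.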

Performance Results

Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16
Device: NVIDIA H100

                      Speedup
Forward               1.46x
Forward + Backward    1.41x

Device: NVIDIA A100

                      Speedup    TFLOPS
Forward               1.67x      246.7
Forward + Backward    1.34x      238.0

@z52527 z52527 self-assigned this Jan 6, 2026

JacoCheung commented Jan 7, 2026

@z52527,
Could you generalize BmmImpl so that it can handle activations laid out as either [batch_count, batch_size, input_dim] or [batch_size, batch_count, input_dim]? Even though the input arrives flattened as [batch_count*batch_size, input_dim], your implementation assumes it is [batch_size, batch_count, input_dim].
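
A minimal sketch of what such a generalization might look like (illustrative only; the grouped_linear name and groups_first flag are assumptions, not the PR's actual BmmImpl API):

import torch

def grouped_linear(x, weight, groups_first: bool):
    """x: flat activation of shape (batch_count * batch_size, D_in),
    weight: (batch_count, D_in, D_out).
    groups_first=True  -> flat layout is [batch_count, batch_size, D_in]
    groups_first=False -> flat layout is [batch_size, batch_count, D_in]
    """
    num_groups, D_in, D_out = weight.shape
    if groups_first:
        # Groups are already the leading dimension: a plain contiguous bmm suffices.
        x3 = x.reshape(num_groups, -1, D_in)                  # (G, B, D_in)
        out = torch.bmm(x3, weight)                           # (G, B, D_out)
        return out.reshape(-1, D_out)
    else:
        # Groups are interleaved along dim 1: run bmm on a permuted view and
        # write into a strided out tensor so no transpose copy is needed.
        x3 = x.reshape(-1, num_groups, D_in)                  # (B, G, D_in)
        out = torch.empty(x3.shape[0], num_groups, D_out,
                          dtype=x.dtype, device=x.device)
        torch.bmm(x3.permute(1, 0, 2), weight, out=out.permute(1, 0, 2))
        return out.reshape(-1, D_out)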
