Commit ded8ada

Add more dims for batch invariant shims (vllm-project#27489)
Signed-off-by: Bram Wasti <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent: 8bff831

File tree: 1 file changed (+42, -2)

vllm/model_executor/layers/batch_invariant.py (42 additions, 2 deletions)

@@ -478,9 +478,48 @@ def matmul_batch_invariant(a, b, *, out=None):
     elif a.ndim == 3 and b.ndim == 3:
         # Handle batched case like bmm
         return bmm_batch_invariant(a, b, out=out)
+    elif a.ndim == 3 and b.ndim == 2:
+        # Handle 3D x 2D: common for linear layers
+        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
+        # Reshape to 2D, do mm, reshape back
+        batch, seq, hidden = a.shape
+        a_2d = a.reshape(-1, hidden)
+        result_2d = matmul_persistent(a_2d, b)
+        result = result_2d.reshape(batch, seq, -1)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    elif a.ndim == 2 and b.ndim == 3:
+        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
+        # By broadcasting `a` to 3D, we can reuse the batched matrix
+        # multiplication logic.
+        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
+        return bmm_batch_invariant(a_expanded, b, out=out)
+    elif a.ndim == 4 and b.ndim == 4:
+        # Handle 4D attention tensors: [batch, heads, seq, dim]
+        # Reshape to 3D, process, reshape back
+        batch, heads, seq_a, dim_a = a.shape
+        _, _, dim_b, seq_b = b.shape
+
+        # Reshape to [batch*heads, seq_a, dim_a]
+        a_3d = a.reshape(batch * heads, seq_a, dim_a)
+        b_3d = b.reshape(batch * heads, dim_b, seq_b)
+
+        # Do batched matmul
+        result_3d = bmm_batch_invariant(a_3d, b_3d)
+
+        # Reshape back to [batch, heads, seq_a, seq_b]
+        result = result_3d.reshape(batch, heads, seq_a, seq_b)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
     else:
         raise ValueError(
-            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
+            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
+            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
         )
 
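
Taken together, the new branches let matmul_batch_invariant stand in for torch.matmul across the shapes that appear in transformer layers. A minimal sketch of how the added paths would be exercised, assuming matmul_batch_invariant is imported from vllm.model_executor.layers.batch_invariant and run on a CUDA device (shapes, device, and dtype below are illustrative, not taken from the commit):

import torch

from vllm.model_executor.layers.batch_invariant import matmul_batch_invariant

# 3D x 2D: the linear-layer case, flattened into one deterministic 2D GEMM
a = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
assert matmul_batch_invariant(a, w).shape == (4, 128, 1024)

# 2D x 3D: `a` is broadcast across the batch dimension of `b`
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 512, 256, device="cuda", dtype=torch.bfloat16)
assert matmul_batch_invariant(x, b).shape == (4, 128, 256)

# 4D x 4D: attention-style [batch, heads, seq, dim] @ [batch, heads, dim, seq]
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
kt = torch.randn(2, 8, 64, 128, device="cuda", dtype=torch.bfloat16)
assert matmul_batch_invariant(q, kt).shape == (2, 8, 128, 128)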

@@ -667,7 +706,8 @@ def rms_norm_batch_invariant(
 
 
 def linear_batch_invariant(input, weight, bias=None):
-    output = mm_batch_invariant(input, weight.t())
+    output = matmul_batch_invariant(input, weight.t())
+
     if bias is not None:
         output = output + bias
     return output
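
The switch from mm_batch_invariant to matmul_batch_invariant is what makes the new 3D x 2D branch reachable from linear layers: callers can now pass unflattened (batch, seq, hidden) activations straight through. A hedged sketch under the same assumptions as above (illustrative shapes; weight follows PyTorch's (out_features, in_features) convention, so weight.t() supplies the (hidden, out) operand):

from vllm.model_executor.layers.batch_invariant import linear_batch_invariant

x = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)     # (batch, seq, hidden)
weight = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)  # (out_features, in_features)
bias = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
y = linear_batch_invariant(x, weight, bias)  # takes the 3D x 2D path -> (4, 128, 1024)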
