@@ -478,9 +478,48 @@ def matmul_batch_invariant(a, b, *, out=None):
     elif a.ndim == 3 and b.ndim == 3:
         # Handle batched case like bmm
         return bmm_batch_invariant(a, b, out=out)
+    elif a.ndim == 3 and b.ndim == 2:
+        # Handle 3D x 2D: common for linear layers
+        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
+        # Reshape to 2D, do mm, reshape back
+        batch, seq, hidden = a.shape
+        a_2d = a.reshape(-1, hidden)
+        result_2d = matmul_persistent(a_2d, b)
+        result = result_2d.reshape(batch, seq, -1)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    elif a.ndim == 2 and b.ndim == 3:
+        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
+        # By broadcasting `a` to 3D, we can reuse the batched matrix
+        # multiplication logic.
+        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
+        return bmm_batch_invariant(a_expanded, b, out=out)
+    elif a.ndim == 4 and b.ndim == 4:
+        # Handle 4D attention tensors: [batch, heads, seq, dim]
+        # Reshape to 3D, process, reshape back
+        batch, heads, seq_a, dim_a = a.shape
+        _, _, dim_b, seq_b = b.shape
+
+        # Reshape to [batch*heads, seq_a, dim_a]
+        a_3d = a.reshape(batch * heads, seq_a, dim_a)
+        b_3d = b.reshape(batch * heads, dim_b, seq_b)
+
+        # Do batched matmul
+        result_3d = bmm_batch_invariant(a_3d, b_3d)
+
+        # Reshape back to [batch, heads, seq_a, seq_b]
+        result = result_3d.reshape(batch, heads, seq_a, seq_b)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
     else:
         raise ValueError(
-            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
+            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
+            f"3D x 2D, 2D x 3D, and 4D x 4D, "
             f"got shapes {a.shape} and {b.shape}"
         )
 
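All three new branches reduce to shapes the existing 2D mm / 3D bmm kernels already handle, so the dispatch only has to reproduce `torch.matmul`'s broadcasting semantics exactly. A minimal shape sanity check with plain torch ops (as a reference for the semantics, not the batch-invariant kernels themselves):

```python
import torch

# 3D x 2D: equivalent to flattening to 2D, doing mm, reshaping back
a3, b2 = torch.randn(4, 16, 32), torch.randn(32, 8)
ref = torch.matmul(a3, b2)                                     # (4, 16, 8)
via_2d = torch.matmul(a3.reshape(-1, 32), b2).reshape(4, 16, -1)
assert torch.allclose(ref, via_2d)

# 2D x 3D: equivalent to broadcasting `a` across the batch dim, then bmm
a2, b3 = torch.randn(16, 32), torch.randn(4, 32, 8)
ref = torch.matmul(a2, b3)                                     # (4, 16, 8)
via_bmm = torch.bmm(a2.unsqueeze(0).expand(4, -1, -1), b3)
assert torch.allclose(ref, via_bmm)

# 4D x 4D: equivalent to folding (batch, heads) into one bmm batch dim
a4, b4 = torch.randn(2, 3, 16, 32), torch.randn(2, 3, 32, 8)
ref = torch.matmul(a4, b4)                                     # (2, 3, 16, 8)
via_3d = torch.bmm(a4.reshape(6, 16, 32), b4.reshape(6, 32, 8)).reshape(2, 3, 16, 8)
assert torch.allclose(ref, via_3d)
```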
@@ -667,7 +706,8 @@ def rms_norm_batch_invariant(
 
 
 def linear_batch_invariant(input, weight, bias=None):
-    output = mm_batch_invariant(input, weight.t())
+    output = matmul_batch_invariant(input, weight.t())
+
     if bias is not None:
         output = output + bias
     return output
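Switching `linear_batch_invariant` from `mm_batch_invariant` to `matmul_batch_invariant` means a linear layer can now accept the usual `(batch, seq, hidden)` activations directly (via the new 3D x 2D branch) instead of requiring a pre-flattened 2D input. A sketch of the intended semantics against plain torch, assuming the `(out_features, in_features)` weight layout of `torch.nn.Linear`; the tensors here are illustrative only:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 32)   # (batch, seq, hidden): 3D input now accepted directly
w = torch.randn(8, 32)       # (out_features, in_features), nn.Linear layout
b = torch.randn(8)

# linear_batch_invariant computes matmul(input, weight.t()) + bias, which for a
# 3D input takes the 3D x 2D branch added above. Reference with plain torch:
ref = F.linear(x, w, b)      # (4, 16, 8)
manual = torch.matmul(x, w.t()) + b
assert torch.allclose(ref, manual)
```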