Skip to content

Commit d5e817d

Browse files
committed
try tp support etp
1 parent b9f4027 commit d5e817d

File tree

1 file changed

+1
-2
lines changed
  • python/sgl_jax/srt/layers

1 file changed

+1
-2
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,7 @@ def _gmm_compute_with_sharded_weights(
385385
)
386386

387387
# Need to reduce over the intermediate dimension which is sharded on "tensor" axis
388-
mesh_shape_tensor = jax.lax.psum(1, "tensor")
389-
if mesh_shape_tensor > 1:
388+
if self.tp_size > 1:
390389
intermediate_output = jax.lax.psum(intermediate_output, "tensor")
391390

392391
return intermediate_output

0 commit comments

Comments
 (0)