We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b9f4027 commit d5e817dCopy full SHA for d5e817d
python/sgl_jax/srt/layers/moe.py
@@ -385,8 +385,7 @@ def _gmm_compute_with_sharded_weights(
385
)
386
387
# Need to reduce over the intermediate dimension which is sharded on "tensor" axis
388
- mesh_shape_tensor = jax.lax.psum(1, "tensor")
389
- if mesh_shape_tensor > 1:
+ if self.tp_size > 1:
390
intermediate_output = jax.lax.psum(intermediate_output, "tensor")
391
392
return intermediate_output
0 commit comments