Skip to content

Commit ea02f58

Browse files
authored
remove with_sharding_constraint (sgl-project#308)
1 parent 6881ae9 commit ea02f58

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

python/sgl_jax/srt/layers/sampler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from jax import lax
55
from jax import numpy as jnp
66
from jax import random
7-
from jax.sharding import Mesh, NamedSharding
8-
from jax.sharding import PartitionSpec as P
7+
from jax.sharding import Mesh
98

109
from sgl_jax.srt.layers.binary_search import topk_mask, topp_mask
1110
from sgl_jax.srt.layers.logits_processor import LogitsProcessorOutput
@@ -30,8 +29,6 @@ def _regular_sampling(self, operands):
3029
"""Regular sampling branch"""
3130
logits, sampling_metadata, positions, rng, mesh, use_sort_for_toppk_minp = operands
3231

33-
logits = lax.with_sharding_constraint(logits, NamedSharding(mesh, P(None, None)))
34-
3532
# Validate broadcast compatibility for temperature division
3633
logits_batch_size = logits.shape[0]
3734
temperatures_shape = sampling_metadata.temperatures.shape

0 commit comments

Comments
 (0)