35 changes: 24 additions & 11 deletions python/sgl_jax/srt/layers/embeddings.py
@@ -22,8 +22,9 @@
import numpy as np
from flax import nnx
from flax.nnx.nn import dtypes
from flax.nnx.nn.linear import default_embed_init
from flax.typing import PromoteDtypeFn
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


class Embed(nnx.Module):
@@ -35,7 +36,7 @@ class Embed(nnx.Module):
dtype: the dtype of the embedding vectors (default: float32).
param_dtype: the dtype of the embedding parameters.
promote_dtype: the dtype promotion function.
embedding_init: embedding initializer.
kernel_axes: the axes of kernel weights.
rngs: rng keys.
"""

@@ -46,8 +47,9 @@ def __init__(
dtype: jnp.dtype | None = None,
param_dtype: jnp.dtype = jnp.bfloat16,
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
embedding_init: nnx.Initializer = default_embed_init,
kernel_axes: tuple[str, ...] = (None, "tensor"),
rngs: nnx.Rngs = None,
mesh: jax.sharding.Mesh = None,
):
"""
Sets up the embedding parameters for the model.
@@ -67,13 +69,19 @@ def __init__(
rngs: Random number generator state for parameter initialization.
"""
self.embedding = nnx.Param(
embedding_init(jax.random.PRNGKey(0), (num_embeddings, features), param_dtype)
jax.random.normal(
jax.random.PRNGKey(0),
(num_embeddings, features),
dtype=param_dtype,
out_sharding=P(*kernel_axes),
),
)

self.kernel_axes = kernel_axes
self.num_embeddings = num_embeddings
self.features = features
self.dtype = dtype or self.embedding.value.dtype
self.promote_dtype = promote_dtype
self.mesh = mesh

def __call__(self, inputs: jax.Array) -> jax.Array:
"""Embeds the inputs along the last dimension.
@@ -92,7 +100,11 @@ def __call__(self, inputs: jax.Array) -> jax.Array:
(embedding,) = self.promote_dtype((self.embedding.value,), dtype=self.dtype, inexact=False)
if self.num_embeddings == 1:
return jnp.broadcast_to(embedding, inputs.shape + (self.features,))
return jnp.take(embedding, inputs, axis=0)

output_pspec = P(*([None] * inputs.ndim), self.kernel_axes[-1])
output_sharding = NamedSharding(self.mesh, output_pspec)
output = embedding.at[inputs].get(out_sharding=output_sharding)
return output

def attend(self, query: jax.Array) -> jax.Array:
"""Attend over the embedding using a query array.
@@ -126,7 +138,7 @@ def __init__(
dtype: jnp.dtype | None = None,
param_dtype: jnp.dtype = jnp.bfloat16,
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
embedding_init: nnx.Initializer = default_embed_init,
kernel_axes: tuple[str, ...] = ("tensor", None),
rngs: nnx.Rngs = None,
use_bias: bool = False,
):
@@ -151,16 +163,17 @@ def __init__(
dtype=dtype,
param_dtype=param_dtype,
promote_dtype=promote_dtype,
embedding_init=embedding_init,
kernel_axes=kernel_axes,
rngs=rngs,
)
if use_bias:
self.bias = nnx.Param(
nnx.with_partitioning(nnx.initializers.constant(0.0), (None, "tensor"))(
jax.random.normal(
jax.random.PRNGKey(0),
(self.num_embeddings, self.features),
param_dtype,
)
dtype=param_dtype,
out_sharding=P(None, "tensor"),
),
)
else:
self.bias = None
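For context, a minimal self-contained sketch of the initialize-then-gather pattern the new `Embed` code follows: create the table directly with a named sharding via `jax.random.normal(..., out_sharding=...)`, then gather rows with a sharded output via `.at[...].get(out_sharding=...)`. The mesh construction (a single explicit `"tensor"` axis over all local devices) and the shapes are illustrative assumptions, not part of this PR; the `out_sharding` arguments mirror the diff and assume a JAX version with explicit-sharding support.

```python
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding
from jax.sharding import PartitionSpec as P

# Illustrative 1-D mesh; the "tensor" axis name matches kernel_axes in the diff.
mesh = jax.make_mesh(
    (len(jax.devices()),), ("tensor",), axis_types=(AxisType.Explicit,)
)

with jax.sharding.use_mesh(mesh):
    # Parameter materialized directly into its sharding, as Embed.__init__ now does.
    table = jax.random.normal(
        jax.random.PRNGKey(0),
        (1024, 256),
        dtype=jnp.bfloat16,
        out_sharding=P(None, "tensor"),
    )
    token_ids = jnp.zeros((2, 8), dtype=jnp.int32)
    # Sharded gather: the feature dimension of the output stays split over
    # "tensor", mirroring embedding.at[inputs].get(out_sharding=...) in __call__.
    out = table.at[token_ids].get(
        out_sharding=NamedSharding(mesh, P(None, None, "tensor"))
    )
    # out.shape == (2, 8, 256)
```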
18 changes: 13 additions & 5 deletions python/sgl_jax/srt/layers/layernorm.py
@@ -5,9 +5,10 @@
import jax.numpy as jnp
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.nn import dtypes, initializers
from flax.typing import Array, Axes, Dtype, Initializer
from flax.nnx.nn import dtypes
from flax.typing import Array, Axes, Dtype
from jax import lax
from jax.sharding import PartitionSpec as P


def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]:
@@ -34,7 +35,6 @@ def __init__(
dtype: Dtype | None = None,
param_dtype: Dtype = jnp.float32,
use_scale: bool = True,
scale_init: Initializer = initializers.ones,
reduction_axes: Axes = -1,
feature_axes: Axes = -1,
axis_name: str | None = None,
@@ -46,7 +46,16 @@ def __init__(

self.scale: nnx.Param[jax.Array] | None
if use_scale:
self.scale = nnx.Param(scale_init(jax.random.PRNGKey(0), feature_shape, param_dtype))
self.scale = nnx.Param(
jax.random.normal(
jax.random.PRNGKey(0),
feature_shape,
dtype=param_dtype,
out_sharding=P(
None,
),
),
)
else:
self.scale = None

@@ -55,7 +64,6 @@ def __init__(
self.dtype = dtype
self.param_dtype = param_dtype
self.use_scale = use_scale
self.scale_init = scale_init
self.reduction_axes = reduction_axes
self.feature_axes = feature_axes
self.axis_name = axis_name
29 changes: 22 additions & 7 deletions python/sgl_jax/srt/layers/linear.py
@@ -4,6 +4,8 @@
from flax import nnx
from jax import lax
from jax import numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def _canonicalize_tuple(x):
@@ -39,33 +41,46 @@ def __init__(
params_dtype: jnp.dtype | None = jnp.bfloat16,
kernel_axes: Sequence[str] | None = None,
rngs: nnx.Rngs = None,
mesh: jax.sharding.Mesh = None,
):
"""Initialize parameters and quantization method."""
self.skip_bias_add = skip_bias_add
self.params_dtype = params_dtype
self.kernel_axes = kernel_axes
self.mesh = mesh
self.weight = nnx.Param(
nnx.with_partitioning(nnx.initializers.normal(), kernel_axes)(
jax.random.PRNGKey(0), (input_size, output_size), params_dtype
)
jax.random.normal(
jax.random.PRNGKey(0),
(input_size, output_size),
dtype=params_dtype,
out_sharding=P(*kernel_axes),
),
)
if use_bias:
self.bias = nnx.Param(
nnx.with_partitioning(nnx.initializers.zeros_init(), (kernel_axes[-1],))(
jax.random.PRNGKey(0), (output_size,), params_dtype
)
jax.random.normal(
jax.random.PRNGKey(0),
(output_size,),
dtype=params_dtype,
out_sharding=P(
kernel_axes[-1],
),
),
)
else:
self.bias = None

def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
"""Forward pass of the linear layer."""
bias = self.bias if not self.skip_bias_add else None
# Access the underlying JAX array using .value property
output_pspec = P(*([None] * (x.ndim - 1)), self.kernel_axes[-1])
output_sharding = NamedSharding(self.mesh, output_pspec)
output = lax.dot_general(
x,
self.weight.value,
(((x.ndim - 1,), (0,)), ((), ())),
preferred_element_type=self.params_dtype,
out_sharding=output_sharding,
)
if bias is not None:
output = output + bias.value
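Similarly, a small sketch (same illustrative mesh assumptions as above, repeated so it runs on its own) of the sharded matmul the updated `LinearBase.__call__` performs; the shapes are made up for illustration, while the `dot_general` dimension numbers and `out_sharding` usage follow the diff.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import AxisType, NamedSharding
from jax.sharding import PartitionSpec as P

# Illustrative explicit 1-D mesh with the "tensor" axis used by kernel_axes.
mesh = jax.make_mesh(
    (len(jax.devices()),), ("tensor",), axis_types=(AxisType.Explicit,)
)

with jax.sharding.use_mesh(mesh):
    x = jax.random.normal(jax.random.PRNGKey(1), (4, 128), dtype=jnp.bfloat16)
    # Column-parallel weight, initialized straight into P(None, "tensor").
    w = jax.random.normal(
        jax.random.PRNGKey(0),
        (128, 256),
        dtype=jnp.bfloat16,
        out_sharding=P(None, "tensor"),
    )
    # Contract x's last dim against w's first dim and request a tensor-sharded
    # output, as the out_sharding argument added to dot_general does above.
    y = lax.dot_general(
        x,
        w,
        (((x.ndim - 1,), (0,)), ((), ())),
        preferred_element_type=jnp.bfloat16,
        out_sharding=NamedSharding(mesh, P(None, "tensor")),
    )
    # y.shape == (4, 256), with the last dim sharded over "tensor".
```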
5 changes: 2 additions & 3 deletions python/sgl_jax/srt/layers/logits_processor.py
@@ -403,9 +403,8 @@ def get_top_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata):

return input_top_logprobs_val, input_top_logprobs_idx

@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: jax.Array, logits_metadata: LogitsMetadata
self, last_logits: jax.Array, logits_metadata: LogitsMetadata
) -> jax.Array:
"""
compute logprobs for the output token from the given logits.
@@ -424,7 +423,7 @@ def compute_temp_top_p_normalized_logprobs(

probs = nn.softmax(last_logits, axis=-1)
del last_logits
probs = top_p_normalize_probs_jax(probs, logits_metadata.top_p)
probs = top_p_normalize_probs_jax(probs, logits_metadata.top_p, self.mesh)
return jnp.log(probs)
else:
return nn.log_softmax(last_logits, axis=-1)
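Finally, since `compute_temp_top_p_normalized_logprobs` now routes through `top_p_normalize_probs_jax` with the mesh, here is a purely illustrative stand-in showing what top-p renormalization computes: keep the nucleus of most-likely tokens whose cumulative probability reaches `top_p`, zero out the rest, and renormalize. The function name and implementation below are assumptions for illustration only, not the repo's `top_p_normalize_probs_jax`, and no sharding is involved.

```python
import jax
import jax.numpy as jnp


def top_p_renormalize(probs: jax.Array, top_p: jax.Array) -> jax.Array:
    """Illustrative stand-in, not the repo's top_p_normalize_probs_jax."""
    # Sort probabilities in descending order and accumulate their mass.
    sorted_probs = jnp.sort(probs, axis=-1)[..., ::-1]
    cumulative = jnp.cumsum(sorted_probs, axis=-1)
    # A token is kept while the mass accumulated *before* it is still < top_p,
    # which always keeps at least the most likely token.
    keep_sorted = (cumulative - sorted_probs) < top_p[..., None]
    # The smallest kept probability becomes the cutoff in the original order.
    cutoff = jnp.min(
        jnp.where(keep_sorted, sorted_probs, jnp.inf), axis=-1, keepdims=True
    )
    kept = jnp.where(probs >= cutoff, probs, 0.0)
    return kept / jnp.sum(kept, axis=-1, keepdims=True)


logits = jnp.array([[2.0, 1.0, 0.5, -1.0]])
probs = jax.nn.softmax(logits, axis=-1)
print(top_p_renormalize(probs, jnp.array([0.8])))
```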