|
| 1 | +# Copyright 2024 The JAX SC Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Shakespeare model using embedding layer.""" |
| 15 | + |
| 16 | +from flax import nnx |
| 17 | +import jax |
| 18 | +import jax.numpy as jnp |
| 19 | +from jax_tpu_embedding.sparsecore.lib.flax.nnx import embed |
| 20 | +from jax_tpu_embedding.sparsecore.lib.nn import embedding |
| 21 | +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec |
| 22 | + |
| 23 | +Nested = embedding.Nested |
| 24 | + |
| 25 | + |
| 26 | +################################################################################ |
| 27 | +# Define the model. |
| 28 | +################################################################################ |
| 29 | +class Model(nnx.Module): |
| 30 | + """Shakespeare model using embedding layer.""" |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + feature_specs: Nested[embedding_spec.FeatureSpec], |
| 35 | + global_batch_size: int, |
| 36 | + vocab_size: int, |
| 37 | + seq_len: int, |
| 38 | + embedding_size: int, |
| 39 | + enable_minibatching: bool = False, |
| 40 | + feature_name: str = 'shakespeare_feature', |
| 41 | + mesh: jax.sharding.Mesh | None = None, |
| 42 | + sharding_axis: str = 'sparsecore_sharding', |
| 43 | + ): |
| 44 | + self.feature_specs = feature_specs |
| 45 | + self.global_batch_size = global_batch_size |
| 46 | + self.vocab_size = vocab_size |
| 47 | + self.seq_len = seq_len |
| 48 | + self.embedding_size = embedding_size |
| 49 | + self.enable_minibatching = enable_minibatching |
| 50 | + self.feature_name = feature_name |
| 51 | + self.mesh = mesh |
| 52 | + self.sharding_axis = sharding_axis |
| 53 | + rngs = nnx.Rngs(params=42) |
| 54 | + self.embedding_layer = embed.SparseCoreEmbed( |
| 55 | + feature_specs=self.feature_specs, |
| 56 | + mesh=self.mesh, |
| 57 | + sharding_axis=self.sharding_axis, |
| 58 | + rngs=rngs, |
| 59 | + enable_minibatching=enable_minibatching, |
| 60 | + ) |
| 61 | + e = self.embedding_size |
| 62 | + v = self.vocab_size |
| 63 | + s = self.seq_len |
| 64 | + self.dense_layer_1 = nnx.Linear( |
| 65 | + in_features=s * e, |
| 66 | + out_features=e, |
| 67 | + rngs=rngs, |
| 68 | + ) |
| 69 | + self.dense_layer_2 = nnx.Linear( |
| 70 | + in_features=e, |
| 71 | + out_features=v, |
| 72 | + rngs=rngs, |
| 73 | + ) |
| 74 | + |
| 75 | + def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]): |
| 76 | + # Add a sharding constraint to the array. |
| 77 | + # |
| 78 | + # Add a sharding constraint to the array to ensure that the sharding |
| 79 | + # information is not lost during compilation. This may not be necessary but |
| 80 | + # it helps SPMD and ensures that the sharding information is as expected. |
| 81 | + # |
| 82 | + # Args: |
| 83 | + # x: The array to add the sharding constraint to. |
| 84 | + # names: The mesh axes for the partition spec. |
| 85 | + # |
| 86 | + # Returns: |
| 87 | + # The array with the sharding constraint added. |
| 88 | + return jax.lax.with_sharding_constraint( |
| 89 | + x, |
| 90 | + jax.sharding.NamedSharding( |
| 91 | + self.mesh, jax.sharding.PartitionSpec(*names) |
| 92 | + ), |
| 93 | + ) |
| 94 | + |
| 95 | + def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput): |
| 96 | + # Run the embedding layer. |
| 97 | + x = self.embedding_layer(embedding_lookup_inputs) |
| 98 | + |
| 99 | + # Unpack the activations. |
| 100 | + x = x[self.feature_name] |
| 101 | + x = jnp.reshape(x, (self.global_batch_size, -1)) |
| 102 | + x = self.add_sharding_constraint(x, (self.sharding_axis,)) |
| 103 | + |
| 104 | + # Apply the dense portion of the model. |
| 105 | + x = self.dense_layer_1(x) |
| 106 | + x = self.add_sharding_constraint(x, (self.sharding_axis,)) |
| 107 | + x = self.dense_layer_2(x) |
| 108 | + x = self.add_sharding_constraint(x, (self.sharding_axis,)) |
| 109 | + |
| 110 | + return x |
0 commit comments