Skip to content

Commit 183006a

Browse files
patnotz and Google-ML-Automation
authored and committed
Add a Flax NNX layer and supporting code
PiperOrigin-RevId: 820445685
1 parent 2121dbf commit 183006a

File tree

9 files changed

+1309
-0
lines changed

9 files changed

+1309
-0
lines changed

jax_tpu_embedding/sparsecore/examples/models/shakespeare/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,17 @@ pytype_strict_library(
5757
pypi_requirement("jax"),
5858
],
5959
)
60+
61+
# Flax NNX Shakespeare example model (uses the sparsecore embedding layer).
pytype_strict_library(
    name = "flax_nnx_model",
    srcs = [
        "flax_nnx_model.py",
    ],
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/flax/nnx:embed",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
        pypi_requirement("flax/nnx"),
        pypi_requirement("jax"),
    ],
)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Shakespeare model using embedding layer."""
15+
16+
from flax import nnx
17+
import jax
18+
import jax.numpy as jnp
19+
from jax_tpu_embedding.sparsecore.lib.flax.nnx import embed
20+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
21+
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
22+
23+
# Alias for the nested-container type used for feature specs below.
Nested = embedding.Nested


################################################################################
# Define the model.
################################################################################
class Model(nnx.Module):
  """Shakespeare model using embedding layer.

  Looks up sparsecore embedding activations for a single feature, flattens
  them per example, and applies two dense layers to produce vocabulary
  logits.
  """

  def __init__(
      self,
      feature_specs: Nested[embedding_spec.FeatureSpec],
      global_batch_size: int,
      vocab_size: int,
      seq_len: int,
      embedding_size: int,
      enable_minibatching: bool = False,
      feature_name: str = 'shakespeare_feature',
      mesh: jax.sharding.Mesh | None = None,
      sharding_axis: str = 'sparsecore_sharding',
      seed: int = 42,
  ):
    """Initializes the model.

    Args:
      feature_specs: Nested feature specs for the embedding lookup.
      global_batch_size: Global (across all devices) batch size.
      vocab_size: Size of the output vocabulary (dense logits dimension).
      seq_len: Number of tokens per example.
      embedding_size: Dimension of each embedding vector.
      enable_minibatching: Whether to enable minibatching in the sparsecore
        embedding layer.
      feature_name: Key used to select this model's feature from the
        embedding-layer activations.
      mesh: Device mesh used by the embedding layer and for sharding
        constraints.
      sharding_axis: Name of the mesh axis used for sparsecore sharding.
      seed: Seed for parameter initialization. Defaults to the previously
        hard-coded value (42) for backward compatibility.
    """
    self.feature_specs = feature_specs
    self.global_batch_size = global_batch_size
    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.embedding_size = embedding_size
    self.enable_minibatching = enable_minibatching
    self.feature_name = feature_name
    self.mesh = mesh
    self.sharding_axis = sharding_axis
    rngs = nnx.Rngs(params=seed)
    self.embedding_layer = embed.SparseCoreEmbed(
        feature_specs=self.feature_specs,
        mesh=self.mesh,
        sharding_axis=self.sharding_axis,
        rngs=rngs,
        enable_minibatching=enable_minibatching,
    )
    # Dense stack: flattened activations (seq_len * embedding_size) ->
    # embedding_size -> vocab_size logits.
    self.dense_layer_1 = nnx.Linear(
        in_features=seq_len * embedding_size,
        out_features=embedding_size,
        rngs=rngs,
    )
    self.dense_layer_2 = nnx.Linear(
        in_features=embedding_size,
        out_features=vocab_size,
        rngs=rngs,
    )

  def add_sharding_constraint(
      self, x: jax.Array, names: tuple[str | None, ...]
  ) -> jax.Array:
    """Adds a sharding constraint to the array.

    Ensures that the sharding information is not lost during compilation.
    This may not be necessary but it helps SPMD and ensures that the
    sharding information is as expected.

    Args:
      x: The array to add the sharding constraint to.
      names: The mesh axes for the partition spec (a ``None`` entry leaves
        the corresponding dimension unconstrained).

    Returns:
      The array with the sharding constraint added.
    """
    return jax.lax.with_sharding_constraint(
        x,
        jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec(*names)
        ),
    )

  def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput):
    """Runs the forward pass and returns per-example vocabulary logits."""
    # Run the embedding layer.
    x = self.embedding_layer(embedding_lookup_inputs)

    # Unpack the activations for this model's feature and flatten each
    # example's (seq_len, embedding_size) activations into one vector.
    x = x[self.feature_name]
    x = jnp.reshape(x, (self.global_batch_size, -1))
    x = self.add_sharding_constraint(x, (self.sharding_axis,))

    # Apply the dense portion of the model, re-constraining the sharding
    # after each layer.
    x = self.dense_layer_1(x)
    x = self.add_sharding_constraint(x, (self.sharding_axis,))
    x = self.dense_layer_2(x)
    x = self.add_sharding_constraint(x, (self.sharding_axis,))

    return x

jax_tpu_embedding/sparsecore/lib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pytype_strict_library(
2727
"//jax_tpu_embedding/sparsecore/lib/core", # buildcleaner: keep
2828
"//jax_tpu_embedding/sparsecore/lib/fdo", # buildcleaner: keep
2929
"//jax_tpu_embedding/sparsecore/lib/flax", # buildcleaner: keep
30+
"//jax_tpu_embedding/sparsecore/lib/flax/nnx", # buildcleaner: keep
3031
"//jax_tpu_embedding/sparsecore/lib/nn", # buildcleaner: keep
3132
"//jax_tpu_embedding/sparsecore/lib/proto", # buildcleaner: keep
3233
],
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS")
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library")

package(
    default_applicable_licenses = ["//:license"],
    default_visibility = EXTERNAL_USERS,
)

# Flax NNX sparsecore embedding layer.
pytype_strict_library(
    name = "embed",
    srcs = [
        "embed.py",
    ],
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
        "//jax_tpu_embedding/sparsecore/utils",
        pypi_requirement("flax/nnx"),
        pypi_requirement("jax"),
        pypi_requirement("optax"),
    ],
)

# Library target.
pytype_strict_library(
    name = "nnx",
    srcs = ["__init__.py"],
    visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"],
    deps = [
        ":embed",  # buildcleaner: keep
    ],
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# Empty file needed by setuptools.find_packages to recognize this as a package.

0 commit comments

Comments
 (0)