Replies: 1 comment
-
I got it working by using lax.pmean. It actually works even if you don't specify it, though. Can someone comment on whether that is expected?

```python
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
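# the flag above makes a single host expose 8 virtual CPU devices, so the
# data-parallel setup below can be exercised without real accelerators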
import jax
jax.config.update("jax_platform_name", "cpu")
from jax import lax, numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt
# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
rep_spec = jax.sharding.PartitionSpec()
dp_spec = jax.sharding.PartitionSpec('data')
model_sharding = jax.NamedSharding(mesh, rep_spec)
data_sharding = jax.NamedSharding(mesh, dp_spec)
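# rep_spec (an empty PartitionSpec) replicates an array on every device,
# while dp_spec splits the leading (batch) axis across the 'data' mesh axis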
# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(.05, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x, rngs=None):
    return self.linear2(self.dropout(nnx.relu(self.linear1(x)), rngs=rngs))
model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
model.train()
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))
# replicate state
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)
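# device_put with the replicated NamedSharding places a full copy of the
# model/optimizer state on every device; nnx.update writes it back in place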
# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)
batch_size = 16
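# decorator stack below: nnx.jit compiles the whole step, nnx.split_rngs
# splits the 'dropout' RNG stream into `batch_size` keys before the mapped
# body runs, and nnx.shard_map runs train_step per device with the model and
# optimizer replicated (rep_spec) and x, y sharded over 'data' (dp_spec)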
@nnx.jit
@nnx.split_rngs(splits=batch_size, only='dropout')
@nnx.shard_map(
  mesh=mesh,
  in_specs=(rep_spec, rep_spec, dp_spec, dp_spec),
  out_specs=rep_spec,
)
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  # average the per-device loss and gradients across the 'data' mesh axis
  loss, grads = lax.pmean((loss, grads), 'data')
  optimizer.update(grads)
  return loss
def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y
for step, (x, y) in enumerate(dataset(1000, batch_size)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')
# dereplicate state
state = nnx.state((model, optimizer))
state = jax.device_get(state)
nnx.update((model, optimizer), state)
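# after device_get the state is ordinary host (numpy) arrays, so the model
# can now be evaluated like any other nnx.Module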
X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
model.eval()
y_pred = model(x_range)
# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()
```
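As a sanity check (not in the snippet above), the per-device copies of a parameter can be compared just before the "dereplicate state" block; they should be identical if the pmean-averaged update keeps the replicas in sync:

```python
# run this just before the "dereplicate state" block, while the parameters
# are still replicated jax.Arrays laid out on the mesh
kernel = model.linear1.kernel.value
copies = [np.asarray(s.data) for s in kernel.addressable_shards]
print(all(np.allclose(copies[0], c) for c in copies[1:]))  # expect True
```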
-
I added dropout to https://github.com/google/flax/blob/main/examples/nnx_toy_examples/04_data_parallel_with_jit.py and got this:
Would it make sense to use nnx.jit and nnx.shard_map?
I haven't figured out how to do it yet.