Skip to content

Commit

Permalink
Update make_create_train_state_fn to support multiple model inputs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583801783
  • Loading branch information
jpuigcerver authored and copybara-github committed Nov 19, 2023
1 parent efcf732 commit f8b56f6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
17 changes: 8 additions & 9 deletions vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,7 @@ def make_create_train_state_fn(
*,
model: nn.Module,
optimizer_config: Dict[str, Any],
input_shape: Tuple[int, ...],
input_axis_resources: PartitionSpec,
input_shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...],
train_steps: int,
seed: int = 0,
extra_rng_keys: Tuple[str, ...]) -> Callable[[], TrainState]:
Expand All @@ -344,8 +343,7 @@ def make_create_train_state_fn(
Args:
model: Linen module representing the model.
optimizer_config: A ConfigDict with the optimizer configuration.
input_shape: Shape of the inputs to the model.
input_axis_resources: PartitionSpec for the inputs of the model.
input_shape_dtypes: ShapeDtypeStructs representing the inputs to the model.
train_steps: Total number of training steps.
seed: PRNG seed.
extra_rng_keys: Sequence of RNG keys used by the model, in addition to
Expand All @@ -360,9 +358,11 @@ def make_create_train_state_fn(

def initialize():
rngs = utils.make_rngs(rng_keys, seed)
inputs = jnp.zeros(input_shape, dtype=jnp.float32)
inputs = partitioning.with_sharding_constraint(inputs, input_axis_resources)
variables = model.init(rngs, inputs)
inputs = tuple(jnp.zeros(x.shape, dtype=x.dtype)
for x in input_shape_dtypes)
inputs = tuple(partitioning.with_sharding_constraint(x, s.sharding)
for x, s in zip(inputs, input_shape_dtypes))
variables = model.init(rngs, *inputs)
rngs.pop('params') # This PRNGKey is not used anymore.
return TrainState.create(
apply_fn=model.apply, tx=tx, rngs=rngs, **variables)
Expand Down Expand Up @@ -699,8 +699,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str,
train_state_initialize_fn = make_create_train_state_fn(
model=create_flax_model(config=config.model, deterministic=False),
optimizer_config=config.optimizer,
input_shape=datataset_element_shape_dtype['image'].shape,
input_axis_resources=datataset_element_shape_dtype['image'].sharding,
input_shape_dtypes=(datataset_element_shape_dtype['image'],),
train_steps=train_steps,
extra_rng_keys=tuple(config.get('extra_rng_keys', [])),
seed=config.get('seed', 0))
Expand Down
7 changes: 5 additions & 2 deletions vmoe/train/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,14 @@ class MakeCreateTrainStateFnTest(absltest.TestCase):
def test(self, mock_create_optimizer):
mock_create_optimizer.return_value = trainer.optimizer.optax.adam(
learning_rate=0.1)
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('d',))
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(('d',)))
shape_dtype = jax.ShapeDtypeStruct(
shape=(8, 32, 32, 3), dtype=jnp.float32, sharding=sharding)
train_state_init_fn = trainer.make_create_train_state_fn(
model=trainer.nn.Conv(features=16, kernel_size=(3, 3)),
optimizer_config={},
input_shape=(8, 32, 32, 3),
input_axis_resources=PartitionSpec(('d',)),
input_shape_dtypes=(shape_dtype,),
train_steps=10,
seed=0,
extra_rng_keys=('foo',))
Expand Down

0 comments on commit f8b56f6

Please sign in to comment.