Commit fc7e250
1. Decorate SaveBestCheckpointer class with @gin.configurable.
2. Modify create_orbax_checkpoint_manager function to take a SaveBestCheckpointer instance instead of directly calling __init__.
3. Move make_train_state from checkpoints_test to test_utils.

PiperOrigin-RevId: 587030944
liangyaning33 authored and t5-copybara committed Dec 1, 2023
1 parent 69ee4d5 commit fc7e250
Showing 4 changed files with 99 additions and 4 deletions.
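
Taken together, these changes let SaveBestCheckpointer's constructor arguments be supplied through gin, and make create_orbax_checkpoint_manager honor those bindings by building a real instance rather than reading __init__'s signature defaults. A minimal sketch of the kind of configuration this enables, with hypothetical metric values; the binding names mirror the kwargs read in t5x/utils.py below:

import gin
from t5x import checkpoints  # registers SaveBestCheckpointer with gin

# Hypothetical bindings for illustration only.
gin.parse_config("""
SaveBestCheckpointer.metric_name_to_monitor = 'train/accuracy'
SaveBestCheckpointer.metric_mode = 'max'
SaveBestCheckpointer.keep_checkpoints_without_metrics = True
SaveBestCheckpointer.force_keep_period = 1000
""")
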
2 changes: 2 additions & 0 deletions t5x/checkpoints.py
@@ -43,6 +43,7 @@
 import flax
 from flax import serialization
 from flax import traverse_util
+import gin
 import jax
 from jax import monitoring
 import jax.config
@@ -1406,6 +1407,7 @@ def _try_fill_metric_run_and_tag_names(metric_name: str,
 
 
 # TODO(b/216649487): Replace with BestCheckpointManager.
+@gin.configurable
 class SaveBestCheckpointer(Checkpointer):
   """A Checkpointer class that keeps checkpoints based on 'best' metrics.
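
For context, @gin.configurable registers a class with gin so that bindings are injected when the class is instantiated; the defaults declared in __init__ are untouched. A self-contained toy sketch of that behavior (Saver and keep are illustrative names, not t5x code):

import gin

@gin.configurable
class Saver:
  def __init__(self, keep=3):
    self.keep = keep

gin.parse_config('Saver.keep = 10')
assert Saver().keep == 10      # binding injected at instantiation
assert Saver(keep=5).keep == 5  # explicit arguments still win
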
65 changes: 64 additions & 1 deletion t5x/test_utils.py
@@ -16,9 +16,11 @@
 
 import contextlib
 import dataclasses
+import functools
 import itertools
 import math
 import operator
-from typing import Generator, List, Sequence, Tuple
+from typing import Any, Generator, List, Mapping, Sequence, Tuple
 import unittest
 
 import jax
@@ -29,6 +31,7 @@
 import t5.data
 from t5x import adafactor
 from t5x import models
+from t5x import optimizers
 from t5x import partitioning
 from t5x import train_state as train_state_lib
 from t5x.checkpoint_importer import LazyArray
@@ -108,6 +111,66 @@ def make_devices(nx: int,
   return devices
 
 
+def make_train_state_base(
+    *,
+    step: int,
+    params: Mapping[str, Any],
+    param_states: Mapping[str, Any],
+    flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1),
+) -> train_state_lib.TrainState:
+  """Helper to construct a train state for testing."""
+  optimizer = optimizers.Optimizer(
+      flax_optimizer_def,
+      state=optimizers.OptimizerState(  # pytype: disable=wrong-arg-types  # jax-ndarray
+          step=step, param_states=param_states
+      ),
+      target=params,
+  )
+
+  return train_state_lib.FlaxOptimTrainState(optimizer)
+
+
+def make_train_state_replicated(
+    global_input_shape,
+    step=42,
+    dtype=np.float32,
+):
+  """Helper to construct a train state for testing."""
+  bias = np.ones(global_input_shape, dtype=dtype)
+  kernel = np.arange(math.prod(global_input_shape), dtype=dtype).reshape(
+      global_input_shape
+  )
+  train_state = make_train_state_base(
+      step=np.int32(step),
+      params={'bias': bias * 2, 'kernel': kernel * 2},
+      param_states={  # only cast targets (above)
+          'bias': bias.astype(np.float32),
+          'kernel': kernel.astype(np.float32),
+      },
+  )
+  return train_state
+
+
+def make_train_state(
+    global_mesh, global_input_shape, mesh_axes, step=42, dtype=np.float32
+):
+  """Construct a train state for testing."""
+  train_state = make_train_state_replicated(
+      global_input_shape, step=step, dtype=dtype
+  )
+
+  return jax.tree_map(
+      functools.partial(
+          create_sharded_array,
+          global_shape=global_input_shape,
+          global_mesh=global_mesh,
+          mesh_axes=mesh_axes,
+      ),
+      train_state,
+      is_leaf=lambda x: isinstance(x, np.ndarray),
+  )
+
+
 def get_t5_test_model(**config_overrides) -> models.EncoderDecoderModel:
   """Returns a tiny T5 1.1 model to use for testing."""
   tiny_config = network.T5Config(
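
A minimal usage sketch of the new helper, assuming a test with numpy parameters (the SGD optimizer comes from the flax_optimizer_def default above; the step value is arbitrary):

import numpy as np
from t5x import test_utils

state = test_utils.make_train_state_base(
    step=np.int32(7),
    params={'kernel': np.ones((2, 2), np.float32)},
    param_states={'kernel': np.zeros((2, 2), np.float32)},
)
assert int(state.step) == 7
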
16 changes: 15 additions & 1 deletion t5x/utils.py
@@ -1298,7 +1298,21 @@ def _get_extra_kwargs(cfg):
       extra_kwargs = _get_default_args(cfg.checkpointer_cls)
     else:
       if issubclass(cfg.checkpointer_cls, checkpoints.SaveBestCheckpointer):
-        extra_kwargs = _get_default_args(cfg.checkpointer_cls.__init__)
+        save_best_checkpointer = checkpoints.SaveBestCheckpointer(
+            train_state=train_state,
+            checkpoints_dir=model_dir,
+            partitioner=partitioner,
+        )
+        extra_kwargs = {
+            'metric_name_to_monitor': (
+                save_best_checkpointer._metric_name_to_monitor  # pylint: disable=protected-access
+            ),
+            'metric_mode': save_best_checkpointer._metric_mode,  # pylint: disable=protected-access
+            'keep_checkpoints_without_metrics': (
+                save_best_checkpointer._keep_checkpoints_without_metrics  # pylint: disable=protected-access
+            ),
+            'force_keep_period': save_best_checkpointer._force_keep_period,  # pylint: disable=protected-access
+        }
       return extra_kwargs
 
   save_dtype = None
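
The switch away from _get_default_args(cfg.checkpointer_cls.__init__) matters because signature defaults are static, while gin applies overrides only when the class is actually called. A toy illustration of the difference, re-declaring the illustrative Saver class from the earlier sketch:

import gin
import inspect

@gin.configurable
class Saver:
  def __init__(self, keep=3):
    self.keep = keep

gin.parse_config('Saver.keep = 10')
# Inspecting the signature still reports the code default ...
assert inspect.signature(Saver.__init__).parameters['keep'].default == 3
# ... but instantiating reflects the gin binding, which is what the
# new code above captures via the instance's attributes.
assert Saver().keep == 10
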
20 changes: 18 additions & 2 deletions t5x/utils_test.py
@@ -706,10 +706,18 @@ def test_create_orbax_checkpoint_manager(
         restore_dataset=False,
     )
 
+    global_mesh = test_utils.create_global_mesh((4, 2), ("x", "y"))
+    mesh_axes = partitioning.PartitionSpec("x", "y")
+    global_input_shape = (8, 2)
+
+    train_state = test_utils.make_train_state(
+        global_mesh, global_input_shape, mesh_axes
+    )
+
     manager = utils.create_orbax_checkpoint_manager(
         save_cfg=save_cfg,
         restore_cfg=restore_cfg,
-        train_state=mock.Mock(),
+        train_state=train_state,
         partitioner=mock_partitioner,
         ds_iter=mock.Mock(),
         model_dir=directory,
@@ -773,10 +781,18 @@ def test_create_orbax_checkpoint_manager_from_checkpointer(
         restore_dataset=False,
     )
 
+    global_mesh = test_utils.create_global_mesh((4, 2), ("x", "y"))
+    mesh_axes = partitioning.PartitionSpec("x", "y")
+    global_input_shape = (8, 2)
+
+    train_state = test_utils.make_train_state(
+        global_mesh, global_input_shape, mesh_axes
+    )
+
     manager = utils.create_orbax_checkpoint_manager(
         save_cfg=save_cfg,
         restore_cfg=restore_cfg,
-        train_state=mock.Mock(),
+        train_state=train_state,
         partitioner=mock_partitioner,
         ds_iter=mock.Mock(),
         model_dir=directory,
