Report training flops and optimal device time computed by XLA while training.

This is useful for plotting loss/accuracy against total training flops or total device time.

Be aware that if your model uses an XLA while_loop, these costs might be inaccurate, since only the cost of a single step is counted. This is not the case for most of our models, though.
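For reference, the reported numbers come from XLA's cost analysis of the compiled step function. A minimal sketch of how that can be queried from JAX is below; step_fn and its input are illustrative, and the return type (a single dict, or a list with one dict per program on older JAX versions) and the available keys vary across JAX versions and backends:

import jax
import jax.numpy as jnp

def step_fn(x):
  return (x @ x.T).sum()

compiled = jax.jit(step_fn).lower(jnp.ones((256, 256))).compile()
cost = compiled.cost_analysis()  # May be None on some backends.
if cost is not None:
  # Older JAX versions wrap the result in a list, hence the [0] used below.
  cost = cost[0] if isinstance(cost, (list, tuple)) else cost
  print(cost.get('flops'), cost.get('optimal_seconds'))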

PiperOrigin-RevId: 618232607
jpuigcerver authored and copybara-github committed Mar 22, 2024
1 parent 647a51c commit 7d2f7e0
Showing 6 changed files with 95 additions and 28 deletions.
23 changes: 16 additions & 7 deletions vmoe/evaluate/evaluator.py
@@ -179,11 +179,19 @@ def compile_for_dataset(name, params, train_step):
         datasets_element_shape_dtype[name]['labels'],
         datasets_element_shape_dtype[name][VALID_KEY]).compile()
     t1 = time.time()
-    metric_writer.write_scalars(train_step, {f'{name}/compile_secs': t1 - t0})
+    metrics = {f'{name}/compile_secs': t1 - t0}
+    step_flops_per_device, step_seconds_per_device = (
+        utils.get_flops_and_seconds_per_device(eval_step_pjit_ds))
+    if step_flops_per_device is not None:
+      metrics[f'{name}/step_flops_per_device'] = step_flops_per_device
+    if step_seconds_per_device is not None:
+      metrics[f'{name}/step_seconds_per_device'] = step_seconds_per_device
+    metric_writer.write_scalars(train_step, metrics)
     return eval_step_pjit_ds

-  def callback_fn(step: int, t: Optional[float], params: PyTree):
+  def callback_fn(step: int, t: Optional[float], params: PyTree, **kwargs):
     del t  # Unused.
+    metrics = {}
     for name, dataset in datasets.items():
       eval_step_pjit_ds = compile_for_dataset(name, params, step)
       # NOTE: Fold-in the dataset name and/or the train_step to the seed
@@ -200,14 +208,15 @@ def callback_fn(step: int, t: Optional[float], params: PyTree):
           params=params)
       t1 = time.time()
       with jax.spmd_mode('allow_all'):
-        metric_writer.write_scalars(step, {
-            f'{name}/prec@1': eval_state.sum_correct / eval_state.num,
-            f'{name}/loss': eval_state.sum_loss / eval_state.num,
-            f'{name}/duration_secs': t1 - t0,
-        })
+        metrics[f'{name}/prec@1'] = eval_state.sum_correct / eval_state.num
+        metrics[f'{name}/loss'] = eval_state.sum_loss / eval_state.num
+        metrics[f'{name}/duration_secs'] = t1 - t0
       # Reset iterator for the next evaluation.
       dataset.reset()

+    metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
+    metric_writer.write_scalars(step, metrics)
+
   if report_progress is None:
     return callback_fn
   else:
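The **kwargs plus dict-merge added above is what lets the trainer forward its running cost counters into the evaluation metrics while dropping any value XLA could not provide. A tiny illustration of that filter, with made-up values:

metrics = {'dataset1/prec@1': 0.81, 'dataset1/loss': 0.73}
kwargs = {'flops': 1.2e18, 'device_seconds': None}  # device_seconds unavailable.
metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
# metrics == {'dataset1/prec@1': 0.81, 'dataset1/loss': 0.73, 'flops': 1.2e18}

Note that the dict-union operator requires Python 3.9 or newer.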
24 changes: 9 additions & 15 deletions vmoe/evaluate/evaluator_test.py
@@ -125,29 +125,23 @@ def test_evaluate_multiple_datasets(self):
     for step in range(1, 10):
       action(step=step, params={})
     call_args_list = metric_writer.write_scalars.call_args_list
-    self.assertLen(call_args_list, 6)
-    self.assertEqual(call_args_list[0],
-                     mock.call(4, {'dataset1/compile_secs': mock.ANY}))
-    self.assertEqual(call_args_list[1],
-                     mock.call(4,
-                               {'dataset1/duration_secs': mock.ANY,
-                                'dataset1/loss': mock.ANY,
-                                'dataset1/prec@1': mock.ANY}))
-    self.assertEqual(call_args_list[2],
-                     mock.call(4, {'dataset2/compile_secs': mock.ANY}))
-    self.assertEqual(call_args_list[3],
-                     mock.call(4,
-                               {'dataset2/duration_secs': mock.ANY,
-                                'dataset2/loss': mock.ANY,
-                                'dataset2/prec@1': mock.ANY}))
-    self.assertEqual(call_args_list[4],
-                     mock.call(8,
-                               {'dataset1/duration_secs': mock.ANY,
-                                'dataset1/loss': mock.ANY,
-                                'dataset1/prec@1': mock.ANY}))
-    self.assertEqual(call_args_list[5],
-                     mock.call(8,
-                               {'dataset2/duration_secs': mock.ANY,
-                                'dataset2/loss': mock.ANY,
-                                'dataset2/prec@1': mock.ANY}))
+    self.assertLen(call_args_list, 4)
+    # First two calls are during compile.
+    # The arguments depends on the device architecture.
+    self.assertEqual(call_args_list[2],
+                     mock.call(4,
+                               {'dataset1/duration_secs': mock.ANY,
+                                'dataset1/loss': mock.ANY,
+                                'dataset1/prec@1': mock.ANY,
+                                'dataset2/duration_secs': mock.ANY,
+                                'dataset2/loss': mock.ANY,
+                                'dataset2/prec@1': mock.ANY}))
+    self.assertEqual(call_args_list[3],
+                     mock.call(8,
+                               {'dataset1/duration_secs': mock.ANY,
+                                'dataset1/loss': mock.ANY,
+                                'dataset1/prec@1': mock.ANY,
+                                'dataset2/duration_secs': mock.ANY,
+                                'dataset2/loss': mock.ANY,
+                                'dataset2/prec@1': mock.ANY}))

3 changes: 2 additions & 1 deletion vmoe/evaluate/fewshot.py
@@ -163,7 +163,7 @@ def make_fewshot_state_pjit(seed, sub_seed, step):
     rngs = {}
     return FewShotState(rngs=rngs)

-  def callback_fn(step: int, t: Optional[float], variables: PyTree):
+  def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
     del t  # Unused.
     # Two-level dict: first is dataset name, second level is (shot, l2_reg).
     all_results = {}
@@ -213,6 +213,7 @@ def callback_fn(step: int, t: Optional[float], variables: PyTree):
             for sub_seed in range(seeds_per_step)
         ) / seeds_per_step
         metrics[f'{main_task_prefix}/{name}/{shot}shot'] = accuracy
+    metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
     metric_writer.write_scalars(step, metrics)
     self.last_metrics = metrics

41 changes: 36 additions & 5 deletions vmoe/train/trainer.py
@@ -541,6 +541,23 @@ def initialize_train_state_from_checkpoint(
   raise ValueError(f'Unknown initialization method: {name!r}')


+def make_train_cost_fn(compiled_fn) -> Callable[[int], Dict[str, float]]:
+  """Returns a function that computes the total training cost at a given step."""
+  flops_per_device, seconds_per_device = utils.get_flops_and_seconds_per_device(
+      compiled_fn
+  )
+
+  def fn(step):
+    output = {}
+    if flops_per_device is not None:
+      output['flops'] = flops_per_device * step * jax.device_count()
+    if seconds_per_device is not None:
+      output['device_seconds'] = seconds_per_device * step * jax.device_count()
+    return output
+
+  return fn
+
+
 def mixup(
     rng: PRNGKey,
     tree: PyTree,
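The returned function scales the per-device, per-step cost by the step number and by jax.device_count(), so it reports totals accumulated over the whole run. A rough usage sketch with hypothetical numbers (1e12 flops per device per step on 8 devices; train_step_pjit is the compiled train step used in the loop below):

train_cost_fn = make_train_cost_fn(train_step_pjit)
train_cost_fn(1000)
# -> {'flops': 8e+15, 'device_seconds': ...}, or an empty dict if XLA reports neither.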
@@ -785,24 +802,38 @@ def _save_checkpoint(step, ts, it, force=False):
   if init_step == 0 and not tf.io.gfile.exists(os.path.join(workdir, 'ckpt/0')):
     multihost_utils.sync_devices('training:ckpt-first')
     _save_checkpoint(init_step, train_state, tr_iter, force=True)
-  # Explicitly compile train_step here and report the compilation time.
+  # Explicitly compile train_step here.
   t0 = time.time()
   train_step_pjit = train_step_pjit.lower(
       train_state,
       datataset_element_shape_dtype['image'],
       datataset_element_shape_dtype['labels']).compile()
   t1 = time.time()
+  # Report compilation time, and flops and optimal seconds per step and device.
   writer.write_scalars(init_step + 1, {'train/compile_secs': t1 - t0})
+  train_step_flops_per_device, train_step_seconds_per_device = (
+      utils.get_flops_and_seconds_per_device(train_step_pjit))
+  if train_step_flops_per_device:
+    writer.write_scalars(
+        init_step + 1,
+        {'train/step_flops_per_device': train_step_flops_per_device})
+  if train_step_seconds_per_device:
+    writer.write_scalars(
+        init_step + 1,
+        {'train/step_seconds_per_device': train_step_seconds_per_device})
+  train_cost_fn = make_train_cost_fn(train_step_pjit)
   for step, batch in zip(range(init_step + 1, train_steps + 1), tr_iter):
     profile_hook(step)
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
       train_state, metrics = train_step_pjit(train_state, batch['image'],
                                              batch['labels'])
-    progress_hook(
-        step, scalar_metrics={f'train/{k}': v for k, v in metrics.items()})
+    progress_hook(step, scalar_metrics=(
+        train_cost_fn(step) | {f'train/{k}': v for k, v in metrics.items()}
+    ))
     _save_checkpoint(step, train_state, tr_iter)
-    evaluation_hook(step, params=train_state.params)
-    fewshot_hook(step, variables={'params': train_state.params})
+    evaluation_hook(step, params=train_state.params, **train_cost_fn(step))
+    fewshot_hook(step, variables={'params': train_state.params},
+                 **train_cost_fn(step))
   ckpt_manager.wait_until_finished()
   if not tf.io.gfile.exists(os.path.join(workdir, f'ckpt/{train_steps}')):
     multihost_utils.sync_devices('training:ckpt-last')
16 changes: 16 additions & 0 deletions vmoe/train/trainer_test.py
@@ -260,6 +260,22 @@ def test(self, mock_create_optimizer):
     self.assertSetEqual(set(train_state.rngs.keys()), {'foo'})


+class MakeTrainCostFnTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ((None, None), {}),
+      ((3., None), {'flops': 3. * 2 * 42}),
+      ((None, 3.), {'device_seconds': 3. * 2 * 42}),
+      ((5., 3.), {'flops': 5. * 2 * 42, 'device_seconds': 3. * 2 * 42}),
+  )
+  @mock.patch.object(jax, 'device_count', return_value=2)
+  def test(self, flops_and_seconds_per_device, expected, _):
+    with mock.patch.object(trainer.utils, 'get_flops_and_seconds_per_device',
+                           return_value=flops_and_seconds_per_device):
+      fn = trainer.make_train_cost_fn(compiled_fn=mock.Mock())
+      self.assertEqual(fn(42), expected)
+
+
 class MixupTest(parameterized.TestCase):

   @parameterized.named_parameters(
16 changes: 16 additions & 0 deletions vmoe/utils.py
@@ -24,6 +24,22 @@
 PRNGKey = jax.Array


+def get_flops_and_seconds_per_device(
+    compiled_fn
+) -> Tuple[float | None, float | None]:
+  """Returns the FLOPs and optimal seconds per device of a compiled function."""
+  cost_analysis = compiled_fn.cost_analysis()[0]
+  flops_per_device = cost_analysis.get('flops')
+  seconds_per_device = cost_analysis.get('optimal_seconds')
+  # Note: XLA returns negative FLOPs and optimal_seconds for some platforms
+  # (e.g. GPUs).
+  if flops_per_device is not None and flops_per_device <= 0:
+    flops_per_device = None
+  if seconds_per_device is not None and seconds_per_device <= 0:
+    seconds_per_device = None
+  return flops_per_device, seconds_per_device
+
+
 def make_rngs(rng_keys: Tuple[str, ...], seed: int) -> Dict[str, PRNGKey]:
   if not rng_keys:
     return dict()
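With this helper in scope, a typical call site mirrors the trainer change above: lower and compile a jitted function, then ask for its per-device cost. A small sketch (fn and its input are illustrative; on backends where XLA reports missing or negative costs, both values come back as None):

import jax
import jax.numpy as jnp

def fn(x):
  return jnp.tanh(x @ x)

compiled = jax.jit(fn).lower(jnp.ones((64, 64))).compile()
flops, seconds = get_flops_and_seconds_per_device(compiled)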
