Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions scenic/train_lib/tests/test_train_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2024 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for training utility functions in train_lib.train_utils.

This file covers tests for the Chrono context manager.
"""

from unittest import mock

from absl.testing import absltest
from scenic.train_lib import train_utils


class ChronoPausedTest(absltest.TestCase):
"""Tests the Chrono.paused context manager for correct behavior."""

@mock.patch("jax.block_until_ready", autospec=True)
@mock.patch("time.monotonic")
def test_paused_context_manager_waits_executes_the_code_block_and_resumes(
self, mock_monotonic, mock_block_until_ready
):
"""Tests the Chrono.paused context manager in a normal flow."""
chrono = train_utils.Chrono()
before_pause, after_pause, after_resume = 100.0, 101.1, 105.5
mock_monotonic.side_effect = [before_pause, after_pause, after_resume]
wait_for_ops = [mock.MagicMock()] # Dummy operations to await.

with chrono.paused(wait_for=wait_for_ops):
mock_block_until_ready.assert_called_once_with(wait_for_ops)
self.assertEqual(chrono.pause_start, before_pause)

self.assertIsNone(chrono.pause_start) # Should be reset by resume
self.assertEqual(chrono.paused_time, after_pause - before_pause)
self.assertEqual(mock_monotonic.call_count, 3) # init, pause, and resume

@mock.patch("jax.block_until_ready", autospec=True)
@mock.patch("time.monotonic")
def test_paused_context_manager_with_exception_calls_resume(
self, mock_monotonic, mock_block_until_ready
):
"""Tests that Chrono.resume is called even if an exception occurs."""
chrono = train_utils.Chrono()
before_pause, after_pause, after_resume = 100.0, 101.1, 105.5
mock_monotonic.side_effect = [before_pause, after_pause, after_resume]
wait_for_ops = ("dummy_op",)
custom_exception = ValueError("Test exception inside context")

# Disable linting since the assertion against the exception must be done
# within the context manager. The assertions below the context blocks are
# not affected by the exception, despite the highlighting (or dimming).
with self.assertRaises(ValueError) as context: # pylint: disable=g-error-prone-assert-raises
with chrono.paused(wait_for=wait_for_ops):
mock_block_until_ready.assert_called_once_with(wait_for_ops)
self.assertEqual(chrono.pause_start, before_pause)
raise custom_exception
self.assertEqual(context.exception, custom_exception)

self.assertIsNone(chrono.pause_start) # Should be reset by resume
self.assertEqual(chrono.paused_time, after_pause - before_pause)
self.assertEqual(mock_monotonic.call_count, 3) # init, pause, and resume


if __name__ == "__main__":
absltest.main()
29 changes: 26 additions & 3 deletions scenic/train_lib/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
"""Utility functions for Training."""

import collections.abc as collections
import contextlib
import copy
import functools
import os
import re
import time
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
from clu import metric_writers
Expand Down Expand Up @@ -280,7 +281,7 @@ def _initialize_model(rngs):
**dummy_input,
train=False,
debug=False,
**model_kwargs
**model_kwargs,
)
),
'params',
Expand Down Expand Up @@ -486,7 +487,8 @@ def _initialize_model(rngs):
with jax.default_device(jax.local_devices(backend='cpu')[0]):
init_model_state, init_params = flax.core.pop(
flax.core.freeze(nn.init(fn=init_fn, module=model_def)(rngs)),
'params')
'params',
)
# Set bias in the head to low value, such that loss is small initially.
if (
config.get('init_head_bias', None) is not None
Expand Down Expand Up @@ -1261,6 +1263,27 @@ def load(self, ckpt={}): # pylint: disable=dangerous-default-value
self.accum_pause_time = ckpt.get('accum_pause_time', 0.0)
self.accum_examples_seen = ckpt.get('accum_examples_seen', 0)

@contextlib.contextmanager
def paused(self, wait_for: Iterable[Any] = ()):
"""A context manager for temporarily pausing to await arguments.

Example:
with chrono.paused(wait_for=some_jax_operations):
# Operations to perform while chrono is paused
...

Args:
wait_for: An iterable of JAX operations to wait for before pausing.

Yields:
The Chrono object.
"""
self.pause(wait_for=wait_for)
try:
yield self
finally:
self.resume()


def barrier_across_hosts():
"""Ensure all hosts stay up until the end, otherwise the program may hang."""
Expand Down