Skip to content

Commit 2475509

Browse files
author
Scenic Authors
committed
Small refactor and cleanup of trainer.py.
PiperOrigin-RevId: 752362747
1 parent 346dfb5 commit 2475509

File tree

2 files changed

+102
-3
lines changed

2 files changed

+102
-3
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2024 The Scenic Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for training utility functions in train_lib.train_utils.
16+
17+
This file covers tests for the Chrono context manager.
18+
"""
19+
20+
from unittest import mock
21+
22+
from absl.testing import absltest
23+
from scenic.train_lib import train_utils
24+
25+
26+
class ChronoPausedTest(absltest.TestCase):
  """Tests the Chrono.paused context manager for correct behavior.

  Both tests patch `time.monotonic` and `jax.block_until_ready`, so no real
  clock or device is involved. The clock's side effects are installed only
  after `Chrono()` has been constructed, so they apply to calls made from that
  point onwards.
  """

  @mock.patch("jax.block_until_ready", autospec=True)
  @mock.patch("time.monotonic")
  def test_paused_context_manager_waits_executes_the_code_block_and_resumes(
      self, mock_monotonic, mock_block_until_ready
  ):
    """Tests the Chrono.paused context manager in a normal flow."""
    chrono = train_utils.Chrono()
    before_pause, after_pause, after_resume = 100.0, 101.1, 105.5
    mock_monotonic.side_effect = [before_pause, after_pause, after_resume]
    wait_for_ops = [mock.MagicMock()]  # Dummy operations to await.

    with chrono.paused(wait_for=wait_for_ops):
      # On entry the chrono must have waited for the ops and recorded the
      # pause start time.
      mock_block_until_ready.assert_called_once_with(wait_for_ops)
      self.assertEqual(chrono.pause_start, before_pause)

    self.assertIsNone(chrono.pause_start)  # Should be reset by resume
    self.assertEqual(chrono.paused_time, after_pause - before_pause)
    self.assertEqual(mock_monotonic.call_count, 3)  # init, pause, and resume

  @mock.patch("jax.block_until_ready", autospec=True)
  @mock.patch("time.monotonic")
  def test_paused_context_manager_with_exception_calls_resume(
      self, mock_monotonic, mock_block_until_ready
  ):
    """Tests that Chrono.resume is called even if an exception occurs."""
    chrono = train_utils.Chrono()
    before_pause, after_pause, after_resume = 100.0, 101.1, 105.5
    mock_monotonic.side_effect = [before_pause, after_pause, after_resume]
    wait_for_ops = ("dummy_op",)
    custom_exception = ValueError("Test exception inside context")

    # Disable linting since the assertions on the paused state must be done
    # within the context manager, before the exception is raised. The
    # assertions below the context blocks are not affected by the exception,
    # despite the highlighting (or dimming).
    with self.assertRaises(ValueError) as context:  # pylint: disable=g-error-prone-assert-raises
      with chrono.paused(wait_for=wait_for_ops):
        mock_block_until_ready.assert_called_once_with(wait_for_ops)
        self.assertEqual(chrono.pause_start, before_pause)
        raise custom_exception
    # Checked after the assertRaises block: `context.exception` is only set
    # once that block exits. The very same exception object must propagate.
    self.assertIs(context.exception, custom_exception)

    self.assertIsNone(chrono.pause_start)  # Should be reset by resume
    self.assertEqual(chrono.paused_time, after_pause - before_pause)
    self.assertEqual(mock_monotonic.call_count, 3)  # init, pause, and resume
73+
74+
75+
# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()

scenic/train_lib/train_utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
"""Utility functions for Training."""
1616

1717
import collections.abc as collections
18+
import contextlib
1819
import copy
1920
import functools
2021
import os
2122
import re
2223
import time
23-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
24+
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
2425

2526
from absl import logging
2627
from clu import metric_writers
@@ -280,7 +281,7 @@ def _initialize_model(rngs):
280281
**dummy_input,
281282
train=False,
282283
debug=False,
283-
**model_kwargs
284+
**model_kwargs,
284285
)
285286
),
286287
'params',
@@ -486,7 +487,8 @@ def _initialize_model(rngs):
486487
with jax.default_device(jax.local_devices(backend='cpu')[0]):
487488
init_model_state, init_params = flax.core.pop(
488489
flax.core.freeze(nn.init(fn=init_fn, module=model_def)(rngs)),
489-
'params')
490+
'params',
491+
)
490492
# Set bias in the head to low value, such that loss is small initially.
491493
if (
492494
config.get('init_head_bias', None) is not None
@@ -1261,6 +1263,27 @@ def load(self, ckpt={}): # pylint: disable=dangerous-default-value
12611263
self.accum_pause_time = ckpt.get('accum_pause_time', 0.0)
12621264
self.accum_examples_seen = ckpt.get('accum_examples_seen', 0)
12631265

1266+
@contextlib.contextmanager
def paused(self, wait_for: Iterable[Any] = ()):
  """A context manager for temporarily pausing to await arguments.

  Calls `self.pause(wait_for=wait_for)` on entry and `self.resume()` on exit.
  The resume runs in a `finally` clause, so the chrono is resumed even when
  the body of the `with` block raises; the exception then propagates to the
  caller. If `pause` itself raises, `resume` is not called.

  Example:
    with chrono.paused(wait_for=some_jax_operations):
      # Operations to perform while chrono is paused
      ...

  Args:
    wait_for: An iterable of JAX operations to wait for before pausing. It is
      forwarded unchanged to `pause`.

  Yields:
    The Chrono object.
  """
  # Pause outside the `try`: if pausing fails there is nothing to resume.
  self.pause(wait_for=wait_for)
  try:
    yield self
  finally:
    # Always undo the pause, including when the `with` body raised.
    self.resume()
1286+
12641287

12651288
def barrier_across_hosts():
12661289
"""Ensure all hosts stay up until the end, otherwise the program may hang."""

0 commit comments

Comments
 (0)