
Commit 88b033c

afrozenator authored and copybara-github committed
[TRAX] v1.3.3 and Store checkpoint with unreplicated weights/state in Loop.
PiperOrigin-RevId: 323172102
1 parent 023dbce commit 88b033c

2 files changed: +48 −37 lines

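In practical terms, this commit affects anyone driving training through `training.Loop` with checkpointing enabled: the checkpoint written at each checkpoint step now holds a single unreplicated copy of the weights and state, even on multi-accelerator runs. A rough usage sketch; `model`, `train_task`, and `eval_task` are hypothetical stand-ins, and the constructor arguments are assumptions based on the `__init__` signature visible in the diffs below:

import trax
from trax.supervised import training

# Hypothetical setup: `model` stands in for a trax.layers model, and
# `train_task` / `eval_task` for training.TrainTask / training.EvalTask
# instances configured elsewhere.
loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir='/tmp/train_dir',
    # `checkpoint_at` is a callable over the step number, consistent with
    # the `self._checkpoint_at(self.step)` call in the diff below.
    checkpoint_at=lambda step: step % 1000 == 0)

# After this commit, each checkpoint contains one unreplicated copy of the
# weights/state rather than per-device replicas.
loop.run(n_steps=5000)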

setup.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 
 setup(
     name='trax',
-    version='1.3.2',
+    version='1.3.3',
     description='Trax',
     long_description=(
         'Trax helps you understand deep learning. We start with basic maths and'

trax/supervised/training.py

Lines changed: 47 additions & 36 deletions
@@ -159,36 +159,7 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
     # unnecessary, i.e. random_seed was set.
     if random_seed is None and self._n_hosts > 1:
       logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
-
-      if logging.vlog_is_on(1):
-        logging.info(
-            'Input training weights shape: %s',
-            fastmath.nested_map(lambda x: x.shape,
-                                self._model_in_training.weights))
-        logging.info('Input training weights: %s',
-                     self._model_in_training.weights)
-        logging.info('Input training state: %s', self._model_in_training.state)
-        logging.info('Input eval weights: %s', self._eval_model.weights)
-        logging.info('Input eval state: %s', self._eval_model.state)
-
-      (self._model_in_training.weights, self._model_in_training.state,
-       self._eval_model.weights, self._eval_model.state) = self._unreplicate(
-           _make_weights_and_state_same_across_hosts(
-               self._for_n_devices(
-                   (self._model_in_training.weights,
-                    self._model_in_training.state, self._eval_model.weights,
-                    self._eval_model.state))))
-
-      if logging.vlog_is_on(1):
-        logging.info(
-            'Output training weights shape: %s',
-            fastmath.nested_map(lambda x: x.shape,
-                                self._model_in_training.weights))
-        logging.info('Output training weights: %s',
-                     self._model_in_training.weights)
-        logging.info('Output training state: %s', self._model_in_training.state)
-        logging.info('Output eval weights: %s', self._eval_model.weights)
-        logging.info('Output eval state: %s', self._eval_model.state)
+      self._sync_weights_and_state_across_hosts()
 
     self._task.optimizer.tree_init(self._model_in_training.weights)
 
@@ -236,6 +207,39 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
     if self._output_dir is None:
       _log('Will not write evaluation metrics, because output_dir is None.')
 
+  def _sync_weights_and_state_across_hosts(self):
+    """Sync weights and state across all the hosts in the computation."""
+
+    if logging.vlog_is_on(1):
+      logging.debug(
+          'Input training weights shape: %s',
+          fastmath.nested_map(lambda x: x.shape,
+                              self._model_in_training.weights))
+      logging.debug('Input training weights: %s',
+                    self._model_in_training.weights)
+      logging.debug('Input training state: %s', self._model_in_training.state)
+      logging.debug('Input eval weights: %s', self._eval_model.weights)
+      logging.debug('Input eval state: %s', self._eval_model.state)
+
+    (self._model_in_training.weights, self._model_in_training.state,
+     self._eval_model.weights, self._eval_model.state) = self._unreplicate(
+         _make_weights_and_state_same_across_hosts(
+             self._for_n_devices(
+                 (self._model_in_training.weights,
+                  self._model_in_training.state, self._eval_model.weights,
+                  self._eval_model.state))))
+
+    if logging.vlog_is_on(1):
+      logging.debug(
+          'Output training weights shape: %s',
+          fastmath.nested_map(lambda x: x.shape,
+                              self._model_in_training.weights))
+      logging.debug('Output training weights: %s',
+                    self._model_in_training.weights)
+      logging.debug('Output training state: %s', self._model_in_training.state)
+      logging.debug('Output eval weights: %s', self._eval_model.weights)
+      logging.debug('Output eval state: %s', self._eval_model.state)
+
   def run(self, n_steps=1):
     """Runs this training loop for n steps.
 
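The new method centralizes the host-sync logic that previously lived inline in `__init__`. The `_make_weights_and_state_same_across_hosts` helper it wraps is outside this diff; as a rough sketch of the underlying idea only (an assumption, not Trax's code), averaging every leaf of a weight tree with a `pmean` collective leaves all participating devices holding identical values:

import functools

import jax
import jax.numpy as jnp

# Sketch only: average every array in a pytree across devices so that all
# copies end up identical. Assumes pmap-style replication, i.e. every leaf
# carries a leading axis equal to the number of local devices.
@functools.partial(jax.pmap, axis_name='hosts')
def _make_same_across_hosts(tree):
  return jax.tree_util.tree_map(
      lambda x: jax.lax.pmean(x, axis_name='hosts'), tree)

n = jax.local_device_count()
# Give each device a different starting value, then sync.
weights = {'w': jnp.stack([jnp.full((2,), float(i)) for i in range(n)])}
synced = _make_same_across_hosts(weights)
# Every device now holds the mean, so all rows of synced['w'] are equal.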
@@ -280,12 +284,20 @@ def run(self, n_steps=1):
       step_acc += 1
       for metric_name, value in optimizer_metrics.items():
         optimizer_metrics_acc[metric_name] += value
-      if self._checkpoint_at(self.step):
-        self.save_checkpoint(weights, state, slots)
-      if self._eval_at(self.step):
+
+      should_checkpoint = self._checkpoint_at(self.step)
+      should_eval = self._eval_at(self.step)
+      unr_weights, unr_state, unr_slots = None, None, None
+      if should_checkpoint or should_eval:
+        unr_weights, unr_state, unr_slots = self._unreplicate(
+            (weights, state, slots))
+
+      if should_checkpoint:
+        self.save_checkpoint(unr_weights, unr_state, unr_slots)
+      if should_eval:
         elapsed_time = time.time() - start_time
-        self._model_in_training.weights = weights
-        self._model_in_training.state = state
+        self._model_in_training.weights = unr_weights
+        self._model_in_training.state = unr_state
         self._eval_model.weights = self._model.weights
         self._log_training_progress(
             total_loss=loss_acc, n_steps=step_acc, elapsed_time=elapsed_time,
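This hunk is the behavior change the commit title refers to: `save_checkpoint` now receives weights, state, and optimizer slots with the device-replication axis already stripped, and the unreplication happens at most once per step even when a checkpoint and an eval coincide. `_unreplicate` itself is not shown in this diff; a minimal sketch of what such a helper typically does under pmap-style replication (an assumption about its behavior, not Trax's code):

import jax
import jax.numpy as jnp

def unreplicate(tree):
  # Sketch: every leaf is assumed to be replicated with a leading device
  # axis, so keeping index 0 recovers a single host-local copy.
  return jax.tree_util.tree_map(lambda x: x[0], tree)

# Replicate a toy weight tree across all local devices, then strip the
# replication before it would be written to a checkpoint.
weights = {'dense': jnp.ones((3, 3))}
replicated = jax.device_put_replicated(weights, jax.local_devices())
assert replicated['dense'].shape == (jax.local_device_count(), 3, 3)
assert unreplicate(replicated)['dense'].shape == (3, 3)

Without this step, a checkpoint written from an n-device run would store each parameter n times over and fail to load into a model expecting the unreplicated shapes.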
@@ -387,7 +399,6 @@ def _run_one_step(self, weights, state, slots, opt_params):
     if logging.vlog_is_on(1) and ((step & step - 1) == 0):
       # Prints every power of two, if debugging is enabled.
       logging.info('step[%d]', step)
-      # logging.info('batch[%s]', batch)
       logging.info('opt_params[%s]', opt_params)
       logging.info('weights[%s]', weights)
