@@ -159,36 +159,7 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
     # unnecessary, i.e. random_seed was set.
     if random_seed is None and self._n_hosts > 1:
       logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
-
-      if logging.vlog_is_on(1):
-        logging.info(
-            'Input training weights shape: %s',
-            fastmath.nested_map(lambda x: x.shape,
-                                self._model_in_training.weights))
-        logging.info('Input training weights: %s',
-                     self._model_in_training.weights)
-        logging.info('Input training state: %s', self._model_in_training.state)
-        logging.info('Input eval weights: %s', self._eval_model.weights)
-        logging.info('Input eval state: %s', self._eval_model.state)
-
-      (self._model_in_training.weights, self._model_in_training.state,
-       self._eval_model.weights, self._eval_model.state) = self._unreplicate(
-           _make_weights_and_state_same_across_hosts(
-               self._for_n_devices(
-                   (self._model_in_training.weights,
-                    self._model_in_training.state, self._eval_model.weights,
-                    self._eval_model.state))))
-
-      if logging.vlog_is_on(1):
-        logging.info(
-            'Output training weights shape: %s',
-            fastmath.nested_map(lambda x: x.shape,
-                                self._model_in_training.weights))
-        logging.info('Output training weights: %s',
-                     self._model_in_training.weights)
-        logging.info('Output training state: %s', self._model_in_training.state)
-        logging.info('Output eval weights: %s', self._eval_model.weights)
-        logging.info('Output eval state: %s', self._eval_model.state)
+      self._sync_weights_and_state_across_hosts()
 
     self._task.optimizer.tree_init(self._model_in_training.weights)
 
@@ -236,6 +207,39 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
     if self._output_dir is None:
       _log('Will not write evaluation metrics, because output_dir is None.')
 
+  def _sync_weights_and_state_across_hosts(self):
+    """Sync weights and state across all the hosts in the computation."""
+
+    if logging.vlog_is_on(1):
+      logging.debug(
+          'Input training weights shape: %s',
+          fastmath.nested_map(lambda x: x.shape,
+                              self._model_in_training.weights))
+      logging.debug('Input training weights: %s',
+                    self._model_in_training.weights)
+      logging.debug('Input training state: %s', self._model_in_training.state)
+      logging.debug('Input eval weights: %s', self._eval_model.weights)
+      logging.debug('Input eval state: %s', self._eval_model.state)
+
+    (self._model_in_training.weights, self._model_in_training.state,
+     self._eval_model.weights, self._eval_model.state) = self._unreplicate(
+         _make_weights_and_state_same_across_hosts(
+             self._for_n_devices(
+                 (self._model_in_training.weights,
+                  self._model_in_training.state, self._eval_model.weights,
+                  self._eval_model.state))))
+
+    if logging.vlog_is_on(1):
+      logging.debug(
+          'Output training weights shape: %s',
+          fastmath.nested_map(lambda x: x.shape,
+                              self._model_in_training.weights))
+      logging.debug('Output training weights: %s',
+                    self._model_in_training.weights)
+      logging.debug('Output training state: %s', self._model_in_training.state)
+      logging.debug('Output eval weights: %s', self._eval_model.weights)
+      logging.debug('Output eval state: %s', self._eval_model.state)
+
   def run(self, n_steps=1):
     """Runs this training loop for n steps.
 
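Note: the extracted _sync_weights_and_state_across_hosts above still delegates the real work to _make_weights_and_state_same_across_hosts and self._unreplicate. As a rough illustration of the underlying idea only (an assumed standalone sketch in raw JAX, not the Trax helpers), weights can be made identical across hosts by averaging every leaf of the pytree over all devices inside a pmap and then keeping a single copy:

import functools

import jax
import jax.numpy as jnp


def make_same_across_hosts(weights_tree):
  """Sketch: average a pytree of weights over every device on every host."""
  n_total = jax.device_count()        # devices across all hosts
  n_local = jax.local_device_count()  # devices on this host

  @functools.partial(jax.pmap, axis_name='devices')
  def _average(tree):
    # psum over the global 'devices' axis, then divide to get the mean.
    return jax.tree_util.tree_map(
        lambda x: jax.lax.psum(x, 'devices') / n_total, tree)

  # Replicate onto the local devices, average, then keep one copy per host.
  replicated = jax.tree_util.tree_map(
      lambda x: jnp.broadcast_to(x, (n_local,) + x.shape), weights_tree)
  return jax.tree_util.tree_map(lambda x: x[0], _average(replicated))
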
@@ -280,12 +284,20 @@ def run(self, n_steps=1):
       step_acc += 1
       for metric_name, value in optimizer_metrics.items():
         optimizer_metrics_acc[metric_name] += value
-      if self._checkpoint_at(self.step):
-        self.save_checkpoint(weights, state, slots)
-      if self._eval_at(self.step):
+
+      should_checkpoint = self._checkpoint_at(self.step)
+      should_eval = self._eval_at(self.step)
+      unr_weights, unr_state, unr_slots = None, None, None
+      if should_checkpoint or should_eval:
+        unr_weights, unr_state, unr_slots = self._unreplicate(
+            (weights, state, slots))
+
+      if should_checkpoint:
+        self.save_checkpoint(unr_weights, unr_state, unr_slots)
+      if should_eval:
         elapsed_time = time.time() - start_time
-        self._model_in_training.weights = weights
-        self._model_in_training.state = state
+        self._model_in_training.weights = unr_weights
+        self._model_in_training.state = unr_state
         self._eval_model.weights = self._model.weights
         self._log_training_progress(
             total_loss=loss_acc, n_steps=step_acc, elapsed_time=elapsed_time,
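The change above defers self._unreplicate so that the replicated per-device weights, state and slots are gathered only when a checkpoint or an eval is actually due, instead of on every training step. For illustration only (an assumed sketch, not Trax's _unreplicate), "unreplicating" a pytree whose leaves carry a leading device axis holding identical copies can be as simple as keeping index 0 of every leaf:

import jax


def unreplicate(tree):
  """Sketch: take one copy of a pytree replicated across a device axis."""
  return jax.tree_util.tree_map(lambda x: x[0], tree)


# Mirroring the control flow in the hunk above (hypothetical names):
# if should_checkpoint or should_eval:
#   unr_weights, unr_state, unr_slots = unreplicate((weights, state, slots))
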
@@ -387,7 +399,6 @@ def _run_one_step(self, weights, state, slots, opt_params):
     if logging.vlog_is_on(1) and ((step & step - 1) == 0):
       # Prints every power of two, if debugging is enabled.
       logging.info('step[%d]', step)
-      # logging.info('batch[%s]', batch)
       logging.info('opt_params[%s]', opt_params)
       logging.info('weights[%s]', weights)