 import os
 import datetime
 import functools
+import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
@@ -209,7 +210,11 @@ def prepare_sample_eval(features):
 
   def start_training(self):
 
-    pipeline = self.load_checkpoint()
+    pipeline, opt_state, step = self.load_checkpoint()
+    restore_args = {}
+    if opt_state and step:
+      restore_args = {"opt_state": opt_state, "step": step}
+    del opt_state
     if self.config.enable_ssim:
       # Generate a sample before training to compare against generated sample after training.
       pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
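
Reviewer note: the gate above only arms the restore path when the checkpoint carried both pieces of state, assuming `load_checkpoint()` returns `None` for whatever a params-only checkpoint lacks. A worked illustration of the cases (illustrative values only, not code from this PR):

import types

# How the `if opt_state and step:` gate behaves. Note the edge case:
# a checkpoint saved at step 0 (step is falsy) also falls through to a
# fresh-optimizer run.
for opt_state, step in [(None, None), ({"mu": ...}, None), ({"mu": ...}, 500)]:
    restore_args = {"opt_state": opt_state, "step": step} if (opt_state and step) else {}
    print(bool(restore_args))   # False, False, True
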
@@ -228,7 +233,7 @@ def start_training(self):
     pipeline.scheduler_state = scheduler_state
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     # Returns pipeline with trained transformer state
-    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)
 
     if self.config.enable_ssim:
       posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
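
Reviewer note: the `training_loop` signature this call feeds uses `restore_args: dict = {}`, a mutable default that Python evaluates once at definition time. It is harmless here because the default itself is never mutated (the `del restore_args[...]` below only runs when a caller passed a populated dict), but the conventional spelling avoids the trap entirely; a minimal sketch, offered as a style suggestion rather than part of this PR:

def training_loop(self, pipeline, optimizer, learning_rate_scheduler,
                  train_data_iterator, restore_args=None):
    # Bind a fresh dict per call instead of sharing one across calls.
    restore_args = restore_args if restore_args is not None else {}
    ...
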
@@ -280,18 +285,28 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
     if writer:
       writer.add_scalar("learning/eval_loss", final_eval_loss, step)
 
-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
 
     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
-      )
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
+      if restore_args:
+        step = restore_args.get("step", 0)
+        max_logging.log(f"Restoring optimizer and resuming from step {step}")
+        state = state.replace(opt_state=restore_args.get("opt_state"), step=step)
+        del restore_args["opt_state"]
+        del optimizer
       state = jax.tree.map(_to_array, state)
       state_spec = nnx.get_partition_spec(state)
       state = jax.lax.with_sharding_constraint(state, state_spec)
       state_shardings = nnx.get_named_sharding(state, mesh)
+      if jax.process_index() == 0 and restore_args:
+        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
+        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
+        max_logging.log(pretty_string)
+        max_logging.log("------------------------------------------------")
       data_shardings = self.get_data_shardings(mesh)
       eval_data_shardings = self.get_eval_data_shardings(mesh)
 
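
Reviewer note: the restore inside the `with` block relies on `TrainState` being a `flax.struct` dataclass, so `.replace()` returns a new instance instead of mutating in place; dropping the return value would make the restore a no-op, which is why the result must be rebound to `state`. A self-contained illustration (the `Toy` class is hypothetical):

import flax

# flax.struct dataclasses are immutable; .replace() hands back a fresh
# instance with the requested fields swapped.
@flax.struct.dataclass
class Toy:
    step: int

old = Toy(step=0)
new = old.replace(step=500)
assert old.step == 0 and new.step == 500
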
@@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     last_profiling_step = np.clip(
         first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
     )
-    # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
-    start_step = 0
+    start_step = restore_args.get("step", 0)
+    if start_step:
+      max_logging.log(f"Resuming training from step {start_step}")
     per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
     scheduler_state = pipeline.scheduler_state
     example_batch = load_next_batch(train_data_iterator, None, self.config)
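
Reviewer note: restoring `start_step` fast-forwards the step counter, but nothing in this hunk fast-forwards `train_data_iterator`, so a resumed run appears to begin again at the head of the dataset stream. If exactly-once consumption mattered, one option is the standard itertools consume recipe; a sketch, assuming one batch per step (the helper name is hypothetical, not in this PR):

import itertools

# Advance the iterator past the batches a previous run already consumed.
# next(islice(it, n, n), None) is the documented itertools recipe for
# lazily discarding n items without materializing them.
def skip_consumed_batches(train_data_iterator, start_step):
    next(itertools.islice(train_data_iterator, start_step, start_step), None)
    return train_data_iterator
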
@@ -373,7 +389,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       example_batch = next_batch_future.result()
       if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         max_logging.log(f"Saving checkpoint for step {step}")
-        self.save_checkpoint(step, pipeline, state.params)
+        if self.config.save_optimizer:
+          self.save_checkpoint(step, pipeline, state)
+        else:
+          self.save_checkpoint(step, pipeline, state.params)
 
     _metrics_queue.put(None)
     writer_thread.join()
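
Reviewer note: the `save_optimizer` branch trades checkpoint size for resumability. Saving the full `TrainState` keeps the optimizer moments needed to resume training exactly, while `state.params` alone is enough for inference. Rough arithmetic for an Adam-family optimizer, which tracks two extra buffers per parameter (the parameter count below is a placeholder, not this model's actual size):

# Back-of-envelope checkpoint sizes; illustrative numbers only.
n_params = 14e9            # placeholder parameter count
bytes_per_value = 4        # float32
params_only_gb = n_params * bytes_per_value / 1e9
full_state_gb = 3 * params_only_gb   # params + Adam first/second moments
print(f"params-only: {params_only_gb:.0f} GB, full state: {full_state_gb:.0f} GB")
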