
Commit 662d501

Add ability to save optimizer and resume while training (#275)
1 parent ae2a542 commit 662d501

File tree: 5 files changed, +43 -17 lines


src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 11 additions & 6 deletions
@@ -33,6 +33,7 @@ class WanCheckpointer(ABC):
   def __init__(self, config, checkpoint_type):
     self.config = config
     self.checkpoint_type = checkpoint_type
+    self.opt_state = None

     self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
         self.config.checkpoint_dir,
@@ -57,7 +58,6 @@ def load_wan_configs_from_orbax(self, step):
       return None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
     metadatas = self.checkpoint_manager.item_metadata(step)
-
     transformer_metadata = metadatas.wan_state
     abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
     params_restore = ocp.args.PyTreeRestore(
@@ -73,27 +73,32 @@ def load_wan_configs_from_orbax(self, step):
         step=step,
         args=ocp.args.Composite(
             wan_state=params_restore,
-            # wan_state=params_restore_util_way,
             wan_config=ocp.args.JsonRestore(),
         ),
     )
-    return restored_checkpoint
+    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
+    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
+    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
+    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
+    return restored_checkpoint, step

   def load_diffusers_checkpoint(self):
     pipeline = WanPipeline.from_pretrained(self.config)
     return pipeline

   def load_checkpoint(self, step=None):
-    restored_checkpoint = self.load_wan_configs_from_orbax(step)
-
+    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
+    opt_state = None
     if restored_checkpoint:
       max_logging.log("Loading WAN pipeline from checkpoint")
       pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
+      if "opt_state" in restored_checkpoint["wan_state"].keys():
+        opt_state = restored_checkpoint["wan_state"]["opt_state"]
     else:
       max_logging.log("No checkpoint found, loading default pipeline.")
       pipeline = self.load_diffusers_checkpoint()

-    return pipeline
+    return pipeline, opt_state, step

   def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
     """Saves the training state and model configurations."""

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -243,6 +243,7 @@ num_eval_samples: 420

 warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+save_optimizer: False

 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
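
Saving the optimizer is opt-in; to use the new behaviour the flag has to be flipped in the run config. A sketch of the relevant keys (checkpoint_every is an existing key used by the trainer below; the value shown here is only an example):

save_optimizer: True       # also checkpoint the optimizer state so training can resume
checkpoint_every: 1000     # example cadence; -1 disables periodic checkpointing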

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 1 deletion
@@ -131,7 +131,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
   if restored_checkpoint:
-    params = restored_checkpoint["wan_state"]
+    if "params" in restored_checkpoint["wan_state"]:  # if checkpointed with optimizer
+      params = restored_checkpoint["wan_state"]["params"]
+    else:  # if not checkpointed with optimizer
+      params = restored_checkpoint["wan_state"]
   else:
     params = load_wan_transformer(
         config.wan_transformer_pretrained_model_name_or_path,
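
The new branch exists because wan_state now comes in two shapes depending on how the checkpoint was written. A sketch of both layouts with placeholder contents (only the "params" and "opt_state" keys are confirmed by this diff; the full key set of the saved train state is an assumption):

# Written with save_optimizer: True - wan_state is a full train-state tree,
# so the model weights live under "params" and the optimizer under "opt_state".
ckpt_with_optimizer = {"wan_state": {"params": {"...": "..."}, "opt_state": {"...": "..."}}}

# Written without the optimizer (previous behaviour) - wan_state is the params tree itself.
ckpt_params_only = {"wan_state": {"...": "..."}}

def extract_params(restored_checkpoint):
    # Mirrors the selection logic in create_model above.
    wan_state = restored_checkpoint["wan_state"]
    return wan_state["params"] if "params" in wan_state else wan_state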

src/maxdiffusion/pyconfig.py

Lines changed: 0 additions & 2 deletions
@@ -196,8 +196,6 @@ def user_init(raw_keys):

   # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path
   raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
-  if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
-    raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
   if "gs://" in raw_keys["pretrained_model_name_or_path"]:
     raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
   if "gs://" in raw_keys["unet_checkpoint"]:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 27 additions & 8 deletions
@@ -17,6 +17,7 @@
 import os
 import datetime
 import functools
+from pprint import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
@@ -209,7 +210,11 @@ def prepare_sample_eval(features):

   def start_training(self):

-    pipeline = self.load_checkpoint()
+    pipeline, opt_state, step = self.load_checkpoint()
+    restore_args = {}
+    if opt_state and step:
+      restore_args = {"opt_state": opt_state, "step": step}
+      del opt_state
     if self.config.enable_ssim:
       # Generate a sample before training to compare against generated sample after training.
       pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
@@ -228,7 +233,7 @@ def start_training(self):
     pipeline.scheduler_state = scheduler_state
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     # Returns pipeline with trained transformer state
-    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)

     if self.config.enable_ssim:
       posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
@@ -280,18 +285,28 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
     if writer:
       writer.add_scalar("learning/eval_loss", final_eval_loss, step)

-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
-      )
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
+      if restore_args:
+        step = restore_args.get("step", 0)
+        max_logging.log(f"Restoring optimizer and resuming from step {step}")
+        state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
+        del restore_args["opt_state"]
+        del optimizer
       state = jax.tree.map(_to_array, state)
       state_spec = nnx.get_partition_spec(state)
       state = jax.lax.with_sharding_constraint(state, state_spec)
       state_shardings = nnx.get_named_sharding(state, mesh)
+      if jax.process_index() == 0 and restore_args:
+        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
+        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
+        max_logging.log(pretty_string)
+        max_logging.log("------------------------------------------------")
       data_shardings = self.get_data_shardings(mesh)
       eval_data_shardings = self.get_eval_data_shardings(mesh)
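
One detail worth keeping in mind around the restore block above: flax's TrainState is an immutable struct.dataclass, so replace() returns a new state object rather than mutating in place, and the result has to be kept for the restored values to take effect. A self-contained sketch of that pattern with a toy optimizer and params (stand-ins for the WAN transformer state, not this repo's TrainState subclass):

import jax.numpy as jnp
import optax
from flax.training import train_state

# Toy train state standing in for the WAN transformer state.
params = {"w": jnp.zeros((4, 4))}
state = train_state.TrainState.create(apply_fn=lambda *a, **k: None, params=params, tx=optax.adamw(1e-5))

# Pretend these came out of the Orbax checkpoint (restore_args in the trainer).
restored_opt_state = state.opt_state  # placeholder for the checkpointed optimizer pytree
restored_step = 1000                  # placeholder for the checkpointed step

# replace() builds a new TrainState; re-assigning it is what makes the restore stick.
state = state.replace(opt_state=restored_opt_state, step=restored_step)
assert state.step == 1000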

@@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       last_profiling_step = np.clip(
           first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
       )
-      # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
-      start_step = 0
+      if restore_args.get("step", 0):
+        max_logging.log(f"Resuming training from step {step}")
+      start_step = restore_args.get("step", 0)
       per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
       scheduler_state = pipeline.scheduler_state
       example_batch = load_next_batch(train_data_iterator, None, self.config)
@@ -373,7 +389,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       example_batch = next_batch_future.result()
       if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         max_logging.log(f"Saving checkpoint for step {step}")
-        self.save_checkpoint(step, pipeline, state.params)
+        if self.config.save_optimizer:
+          self.save_checkpoint(step, pipeline, state)
+        else:
+          self.save_checkpoint(step, pipeline, state.params)

     _metrics_queue.put(None)
     writer_thread.join()
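
The save side is what produces the two wan_state layouts handled in wan_pipeline.py above. Condensed into a sketch (that the full TrainState round-trips through Orbax as a tree with "params" and "opt_state" subtrees is inferred from the load path, not stated explicitly in this diff):

# save_optimizer: True  -> hand the whole TrainState to save_checkpoint, so the restored
#                          wan_state later contains "params" and "opt_state" subtrees.
# save_optimizer: False -> hand only state.params, so wan_state is the bare params tree.
train_states = state if self.config.save_optimizer else state.params
self.save_checkpoint(step, pipeline, train_states)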
