
Commit 9b44b5e

changes for WAN 2.2
1 parent c9229c3 commit 9b44b5e

File tree

7 files changed: +1406 −51 lines

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
import json

import jax
import numpy as np
from typing import Optional, Tuple
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
from .. import max_logging, max_utils
import orbax.checkpoint as ocp
from etils import epath

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)

    restore_args = {}

    low_state_metadata = metadatas.low_noise_transformer_state
    abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata)
    low_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_low_state,
        )
    )
    restore_args["low_noise_transformer_state"] = low_state_restore

    high_state_metadata = metadatas.high_noise_transformer_state
    abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata)
    high_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_high_state,
        )
    )
    restore_args["high_noise_transformer_state"] = high_state_restore

    restore_args["wan_config"] = ocp.args.JsonRestore()

    max_logging.log("Restoring WAN 2.2 checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(**restore_args),
    )
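    # restored_checkpoint is an Orbax Composite result keyed by the entries
    # requested above: "low_noise_transformer_state",
    # "high_noise_transformer_state" and "wan_config".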
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
      # The opt_state, if present, lives inside the restored transformer train state.
      if "opt_state" in restored_checkpoint["low_noise_transformer_state"]:
        opt_state = restored_checkpoint["low_noise_transformer_state"]["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
    }

    if "low_noise_transformer" in train_states:
      low_noise_state = train_states["low_noise_transformer"]
      items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state)

    if "high_noise_transformer" in train_states:
      high_noise_state = train_states["high_noise_transformer"]
      items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state)

    # Save the checkpoint only if at least one transformer state is present.
    if len(items) > 1:
      self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
      max_logging.log(f"Checkpoint for step {train_step} saved.")


  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      """
      Only save the config entries that are needed and can be serialized to JSON.
      """
      if not hasattr(model_or_config, "config"):
        return None
      source_config = dict(model_or_config.config)

      # 1. configs that can be serialized to JSON
      SAFE_KEYS = [
          "_class_name",
          "_diffusers_version",
          "model_type",
          "patch_size",
          "num_attention_heads",
          "attention_head_dim",
          "in_channels",
          "out_channels",
          "text_dim",
          "freq_dim",
          "ffn_dim",
          "num_layers",
          "cross_attn_norm",
          "qk_norm",
          "eps",
          "image_dim",
          "added_kv_proj_dim",
          "rope_max_seq_len",
          "pos_embed_seq_len",
          "flash_min_seq_length",
          "flash_block_sizes",
          "attention",
          "_use_default_values",
      ]

      # 2. keep the configs that are in the SAFE_KEYS list
      clean_config = {}
      for key in SAFE_KEYS:
        if key in source_config:
          clean_config[key] = source_config[key]

      # 3. handle special dtype and precision values
      if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
        clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

      if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
        clean_config["weights_dtype"] = source_config["weights_dtype"].name

      if "precision" in source_config and hasattr(source_config["precision"], "name"):
        clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

      return clean_config

    items_to_save = {
        # NOTE: legacy WAN 2.1 save path; this appears to assume a single
        # pipeline.transformer rather than the low/high noise pair used by
        # save_checkpoint above.
        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)

    # Create CompositeArgs for Orbax
    save_args = ocp.args.Composite(**items_to_save)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=save_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ save_config_to_gcs: False
 log_period: 100

 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
+model_name: wan2.1

 # Overrides the transformer from pretrained_model_name_or_path
 wan_transformer_pretrained_model_name_or_path: ''
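The base 14B config now carries an explicit model_name defaulting to wan2.1, which suggests callers branch on it to select the WAN 2.2 pipeline. A minimal sketch of such a dispatch follows; the wan2.2 value and the 2.1 module path are assumptions, while the wan_pipeline2_2 path matches the import in the checkpointer above.

# Hypothetical dispatch on the new config key (not part of this diff).
def get_wan_pipeline_cls(config):
  if config.model_name == "wan2.2":
    from maxdiffusion.pipelines.wan.wan_pipeline2_2 import WanPipeline
  else:  # base_wan_14b.yml defaults to wan2.1
    from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline  # path assumed
  return WanPipeline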
