1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """

from abc import ABC
import json
from typing import Optional, Tuple

import jax
import numpy as np
import orbax.checkpoint as ocp
from etils import epath

from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
from .. import max_logging, max_utils

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):
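  """Orbax checkpointer for the WAN 2.2 pipeline with low/high-noise transformers.

  Saves and restores the two transformer train states together with the
  transformer config, and falls back to the pretrained diffusers weights when
  no Orbax checkpoint is available.

  A minimal usage sketch (assuming a parsed MaxDiffusion `config`):

    checkpointer = WanCheckpointer(config, WAN_CHECKPOINT)
    pipeline, opt_state, step = checkpointer.load_checkpoint()
  """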

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
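    """Creates the optimizer and learning-rate schedule from the training config."""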
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
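    """Restores the transformer states and config from the Orbax checkpoint.

    If `step` is None, the latest available step is used. Returns the restored
    composite checkpoint and the step, or (None, None) when no checkpoint exists.
    """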
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)

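    # Build per-item Orbax restore arguments; each transformer state is
    # restored as plain numpy arrays on host (restore_type=np.ndarray).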
    restore_args = {}

    low_state_metadata = metadatas.low_noise_transformer_state
    abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata)
    low_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_low_state,
        )
    )
    restore_args["low_noise_transformer_state"] = low_state_restore

    high_state_metadata = metadatas.high_noise_transformer_state
    abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata)
    high_state_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_high_state,
        )
    )
    restore_args["high_noise_transformer_state"] = high_state_restore

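    # The transformer config was saved as JSON, so restore it as a plain dict.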
    restore_args["wan_config"] = ocp.args.JsonRestore()

    max_logging.log("Restoring WAN 2.2 checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(**restore_args),
    )
    # The restored composite exposes the two transformer states and the config;
    # the optimizer state (if saved) lives inside each transformer state tree.
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored low_noise_transformer_state keys {restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
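    """Loads the default pretrained pipeline via `WanPipeline.from_pretrained`."""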
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
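    """Loads the pipeline from an Orbax checkpoint, falling back to pretrained weights.

    Returns a (pipeline, opt_state, step) tuple; opt_state and step are None
    when they are not found in a checkpoint.
    """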
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
      # The optimizer state is stored inside the saved transformer state PyTrees;
      # the low-noise transformer state is used here.
      if "opt_state" in restored_checkpoint["low_noise_transformer_state"]:
        opt_state = restored_checkpoint["low_noise_transformer_state"]["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
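    # The single saved config is taken from the low-noise transformer; the two
    # transformers are assumed to share the same architecture config.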
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
    }

    if "low_noise_transformer" in train_states:
      low_noise_state = train_states["low_noise_transformer"]
      items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state)

    if "high_noise_transformer" in train_states:
      high_noise_state = train_states["high_noise_transformer"]
      items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state)

    # Save only when at least one transformer state accompanies the config.
    if len(items) > 1:
      self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
      max_logging.log(f"Checkpoint for step {train_step} saved.")

  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
    """Saves the training state and model configurations (original single-transformer layout)."""

    def config_to_json(model_or_config):
      """Only save the config fields that are needed and can be serialized to JSON."""
      if not hasattr(model_or_config, "config"):
        return None
      source_config = dict(model_or_config.config)

      # 1. configs that can be serialized to JSON
      SAFE_KEYS = [
          "_class_name",
          "_diffusers_version",
          "model_type",
          "patch_size",
          "num_attention_heads",
          "attention_head_dim",
          "in_channels",
          "out_channels",
          "text_dim",
          "freq_dim",
          "ffn_dim",
          "num_layers",
          "cross_attn_norm",
          "qk_norm",
          "eps",
          "image_dim",
          "added_kv_proj_dim",
          "rope_max_seq_len",
          "pos_embed_seq_len",
          "flash_min_seq_length",
          "flash_block_sizes",
          "attention",
          "_use_default_values",
      ]

      # 2. keep only the keys that are in the SAFE_KEYS list
      clean_config = {}
      for key in SAFE_KEYS:
        if key in source_config:
          clean_config[key] = source_config[key]

      # 3. handle special dtype and precision objects
      if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
        clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

      if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
        clean_config["weights_dtype"] = source_config["weights_dtype"].name

      if "precision" in source_config and hasattr(source_config["precision"], "name"):
        clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

      return clean_config

    items_to_save = {
        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)

    # Create CompositeArgs for Orbax
    save_args = ocp.args.Composite(**items_to_save)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=save_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")