import torchvision
import torch.distributed as dist

+from huggingface_hub import snapshot_download
from safetensors import safe_open
from tqdm import tqdm
from einops import rearrange
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora


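+# motion module checkpoints available for auto-download from "guoyww/animatediff" on the Hugging Face Hub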
+MOTION_MODULES = [
+    "mm_sd_v14.ckpt",
+    "mm_sd_v15.ckpt",
+    "mm_sd_v15_v2.ckpt",
+    "v3_sd15_mm.ckpt",
+]
+
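+# camera-motion LoRAs, the v3 domain adapter, and SparseCtrl encoders, hosted in the same
+# "guoyww/animatediff" repo (the commented-out motion modules are excluded from this list)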
+ADAPTERS = [
+    # "mm_sd_v14.ckpt",
+    # "mm_sd_v15.ckpt",
+    # "mm_sd_v15_v2.ckpt",
+    # "mm_sdxl_v10_beta.ckpt",
+    "v2_lora_PanLeft.ckpt",
+    "v2_lora_PanRight.ckpt",
+    "v2_lora_RollingAnticlockwise.ckpt",
+    "v2_lora_RollingClockwise.ckpt",
+    "v2_lora_TiltDown.ckpt",
+    "v2_lora_TiltUp.ckpt",
+    "v2_lora_ZoomIn.ckpt",
+    "v2_lora_ZoomOut.ckpt",
+    "v3_sd15_adapter.ckpt",
+    # "v3_sd15_mm.ckpt",
+    "v3_sd15_sparsectrl_rgb.ckpt",
+    "v3_sd15_sparsectrl_scribble.ckpt",
+]
+
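+# personalized text-to-image checkpoints, mirrored in "guoyww/animatediff_t2i_backups"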
+BACKUP_DREAMBOOTH_MODELS = [
+    "realisticVisionV60B1_v51VAE.safetensors",
+    "majicmixRealistic_v4.safetensors",
+    "leosamsFilmgirlUltra_velvia20Lora.safetensors",
+    "toonyou_beta3.safetensors",
+    "majicmixRealistic_v5Preview.safetensors",
+    "rcnzCartoon3d_v10.safetensors",
+    "lyriel_v16.safetensors",
+    "leosamsHelloworldXL_filmGrain20.safetensors",
+    "TUSUN.safetensors",
+]
+
+
def zero_rank_print(s):
    # print only on rank 0, or whenever torch.distributed is not initialized
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)

@@ -33,63 +73,20 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
    imageio.mimsave(path, outputs, fps=fps)


-# DDIM Inversion
-@torch.no_grad()
-def init_prompt(prompt, pipeline):
-    uncond_input = pipeline.tokenizer(
-        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
-    text_input = pipeline.tokenizer(
-        [prompt],
-        padding="max_length",
-        max_length=pipeline.tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
-    context = torch.cat([uncond_embeddings, text_embeddings])
-
-    return context
-
-
-def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
-              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
-    timestep, next_timestep = min(
-        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
-    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
-    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
-    beta_prod_t = 1 - alpha_prod_t
-    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
-    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
-    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
-    return next_sample
-
-
-def get_noise_pred_single(latents, t, context, unet):
-    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
-    return noise_pred
-
-
-@torch.no_grad()
-def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
-    context = init_prompt(prompt, pipeline)
-    uncond_embeddings, cond_embeddings = context.chunk(2)
-    all_latent = [latent]
-    latent = latent.clone().detach()
-    for i in tqdm(range(num_inv_steps)):
-        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
-        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
-        latent = next_step(noise_pred, t, latent, ddim_scheduler)
-        all_latent.append(latent)
-    return all_latent
-
-
-@torch.no_grad()
-def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
-    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
-    return ddim_latents
+def auto_download(local_path, is_dreambooth_lora=False):
+    hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
+    folder, filename = os.path.split(local_path)
+
+    if not os.path.exists(local_path):
+        print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
+
+        if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} does not exist in {hf_repo}"
+        else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} does not exist in {hf_repo}"
+
+        folder = "." if folder == "" else folder
+        os.makedirs(folder, exist_ok=True)
+        # fetch only the one matching file from the repo snapshot
+        snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
+
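+# Usage sketch (hypothetical local paths; only the basename must match one of the lists above):
+#   auto_download("models/Motion_Module/v3_sd15_mm.ckpt")
+#   auto_download("models/DreamBooth_LoRA/toonyou_beta3.safetensors", is_dreambooth_lora=True)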

def load_weights(
    animation_pipeline,
@@ -107,10 +104,16 @@ def load_weights(
    # motion module
    unet_state_dict = {}
    if motion_module_path != "":
+        auto_download(motion_module_path, is_dreambooth_lora=False)
+
        print(f"load motion module from {motion_module_path}")
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
-        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+        # filter parameters: keep only motion-module weights, and skip the positional-encoding
+        # buffers ("pos_encoder.pe") so checkpoints with a different temporal position-encoding
+        # length can still be loaded
+        for name, param in motion_module_state_dict.items():
+            if "motion_modules." not in name: continue
+            if "pos_encoder.pe" in name: continue
+            unet_state_dict.update({name: param})
        unet_state_dict.pop("animatediff_config", "")

        missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
@@ -119,6 +122,8 @@ def load_weights(

    # base model
    if dreambooth_model_path != "":
+        auto_download(dreambooth_model_path, is_dreambooth_lora=True)
+
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            dreambooth_state_dict = {}
@@ -140,6 +145,8 @@ def load_weights(

    # lora layers
    if lora_model_path != "":
+        auto_download(lora_model_path, is_dreambooth_lora=True)
+
        print(f"load lora model from {lora_model_path}")
        assert lora_model_path.endswith(".safetensors")
        lora_state_dict = {}
@@ -152,6 +159,8 @@ def load_weights(

    # domain adapter lora
    if adapter_lora_path != "":
+        auto_download(adapter_lora_path, is_dreambooth_lora=False)
+
        print(f"load domain lora from {adapter_lora_path}")
        domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
        domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
@@ -162,6 +171,9 @@ def load_weights(
    # motion module lora
    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+
+        auto_download(path, is_dreambooth_lora=False)
+
        print(f"load motion LoRA from {path}")
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict