Commit 786a99c

fix model

1 parent: cf80dde

3 files changed (+145, -76 lines)

animatediff/models/motion_module.py (+2, -2)

```diff
@@ -238,7 +238,7 @@ def __init__(
         pe = torch.zeros(1, max_len, d_model)
         pe[0, :, 0::2] = torch.sin(position * div_term)
         pe[0, :, 1::2] = torch.cos(position * div_term)
-        self.register_buffer('pe', pe)
+        self.register_buffer('pe', pe, persistent=False)
 
     def forward(self, x):
         x = x + self.pe[:, :x.size(1)]
@@ -251,7 +251,7 @@ def __init__(
             attention_mode             = None,
             cross_frame_attention_mode = None,
             temporal_position_encoding = False,
-            temporal_position_encoding_max_len = 24,
+            temporal_position_encoding_max_len = 32,
             *args, **kwargs
         ):
         super().__init__(*args, **kwargs)
```
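The two changes interact: raising `temporal_position_encoding_max_len` from 24 to 32 resizes the sinusoidal table `pe`, and marking it `persistent=False` keeps it out of `state_dict()`, so checkpoints no longer hard-code the table's size. A minimal sketch of the `persistent=False` behavior, using a toy module (class and names are illustrative, not from this repo):

```python
import torch
import torch.nn as nn

class ToyPositionalEncoding(nn.Module):
    def __init__(self, d_model=8, max_len=32):
        super().__init__()
        pe = torch.zeros(1, max_len, d_model)
        # persistent=False: the buffer still moves with .to()/.cuda() like any
        # buffer, but it is excluded from state_dict(), so a saved checkpoint
        # can never clash with a different max_len at load time.
        self.register_buffer("pe", pe, persistent=False)

print("pe" in ToyPositionalEncoding().state_dict())  # False
```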

animatediff/models/unet.py (+73, -16)

```diff
@@ -475,16 +475,77 @@ def forward(
         return UNet3DConditionOutput(sample=sample)
 
     @classmethod
-    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
-        if subfolder is not None:
-            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
-        print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
-
-        config_file = os.path.join(pretrained_model_path, 'config.json')
-        if not os.path.isfile(config_file):
-            raise RuntimeError(f"{config_file} does not exist")
-        with open(config_file, "r") as f:
-            config = json.load(f)
+    def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
+        from diffusers import __version__
+        from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
+        from diffusers.modeling_utils import load_state_dict
+        print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        model_file = None
+        if is_safetensors_available():
+            try:
+                model_file = cls._get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=SAFETENSORS_WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+            except:
+                pass
+
+        if model_file is None:
+            model_file = cls._get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+
+        config, unused_kwargs = cls.load_config(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            device_map=device_map,
+            **kwargs,
+        )
+
         config["_class_name"] = cls.__name__
         config["down_block_types"] = [
             "CrossAttnDownBlock3D",
@@ -499,12 +560,8 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
             "CrossAttnUpBlock3D"
         ]
 
-        from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **unet_additional_kwargs)
-        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
-        if not os.path.isfile(model_file):
-            raise RuntimeError(f"{model_file} does not exist")
-        state_dict = torch.load(model_file, map_location="cpu")
+        model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
+        state_dict = load_state_dict(model_file)
 
         m, u = model.load_state_dict(state_dict, strict=False)
         print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
```

animatediff/utils/util.py (+70, -58)

```diff
@@ -7,13 +7,53 @@
 import torchvision
 import torch.distributed as dist
 
+from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from tqdm import tqdm
 from einops import rearrange
 from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
 from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
 
 
+MOTION_MODULES = [
+    "mm_sd_v14.ckpt",
+    "mm_sd_v15.ckpt",
+    "mm_sd_v15_v2.ckpt",
+    "v3_sd15_mm.ckpt",
+]
+
+ADAPTERS = [
+    # "mm_sd_v14.ckpt",
+    # "mm_sd_v15.ckpt",
+    # "mm_sd_v15_v2.ckpt",
+    # "mm_sdxl_v10_beta.ckpt",
+    "v2_lora_PanLeft.ckpt",
+    "v2_lora_PanRight.ckpt",
+    "v2_lora_RollingAnticlockwise.ckpt",
+    "v2_lora_RollingClockwise.ckpt",
+    "v2_lora_TiltDown.ckpt",
+    "v2_lora_TiltUp.ckpt",
+    "v2_lora_ZoomIn.ckpt",
+    "v2_lora_ZoomOut.ckpt",
+    "v3_sd15_adapter.ckpt",
+    # "v3_sd15_mm.ckpt",
+    "v3_sd15_sparsectrl_rgb.ckpt",
+    "v3_sd15_sparsectrl_scribble.ckpt",
+]
+
+BACKUP_DREAMBOOTH_MODELS = [
+    "realisticVisionV60B1_v51VAE.safetensors",
+    "majicmixRealistic_v4.safetensors",
+    "leosamsFilmgirlUltra_velvia20Lora.safetensors",
+    "toonyou_beta3.safetensors",
+    "majicmixRealistic_v5Preview.safetensors",
+    "rcnzCartoon3d_v10.safetensors",
+    "lyriel_v16.safetensors",
+    "leosamsHelloworldXL_filmGrain20.safetensors",
+    "TUSUN.safetensors",
+]
+
+
 def zero_rank_print(s):
     if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
 
```
```diff
@@ -33,63 +73,20 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
     imageio.mimsave(path, outputs, fps=fps)
 
 
-# DDIM Inversion
-@torch.no_grad()
-def init_prompt(prompt, pipeline):
-    uncond_input = pipeline.tokenizer(
-        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
-    text_input = pipeline.tokenizer(
-        [prompt],
-        padding="max_length",
-        max_length=pipeline.tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
-    context = torch.cat([uncond_embeddings, text_embeddings])
-
-    return context
-
-
-def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
-              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
-    timestep, next_timestep = min(
-        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
-    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
-    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
-    beta_prod_t = 1 - alpha_prod_t
-    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
-    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
-    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
-    return next_sample
-
-
-def get_noise_pred_single(latents, t, context, unet):
-    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
-    return noise_pred
-
-
-@torch.no_grad()
-def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
-    context = init_prompt(prompt, pipeline)
-    uncond_embeddings, cond_embeddings = context.chunk(2)
-    all_latent = [latent]
-    latent = latent.clone().detach()
-    for i in tqdm(range(num_inv_steps)):
-        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
-        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
-        latent = next_step(noise_pred, t, latent, ddim_scheduler)
-        all_latent.append(latent)
-    return all_latent
-
-
-@torch.no_grad()
-def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
-    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
-    return ddim_latents
+def auto_download(local_path, is_dreambooth_lora=False):
+    hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
+    folder, filename = os.path.split(local_path)
+
+    if not os.path.exists(local_path):
+        print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
+
+        if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} does not exist in {hf_repo}"
+        else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} does not exist in {hf_repo}"
+
+        folder = "." if folder == "" else folder
+        os.makedirs(folder, exist_ok=True)
+        snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
+
 
 def load_weights(
     animation_pipeline,
```
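The new `auto_download` (which replaces the DDIM-inversion helpers removed above) fetches a missing checkpoint on demand: motion modules and adapters come from `guoyww/animatediff`, Dreambooth/LoRA backups from `guoyww/animatediff_t2i_backups`, and `allow_patterns=[filename]` limits `snapshot_download` to the single requested file. A sketch of how it is meant to be called; the local directory layout is illustrative:

```python
from animatediff.utils.util import auto_download

# No-op if the file already exists; otherwise downloads exactly this one file
# from guoyww/animatediff into models/Motion_Module/ (path is illustrative).
auto_download("models/Motion_Module/v3_sd15_mm.ckpt", is_dreambooth_lora=False)

# Dreambooth/LoRA checkpoints resolve against guoyww/animatediff_t2i_backups.
auto_download("models/DreamBooth_LoRA/toonyou_beta3.safetensors", is_dreambooth_lora=True)
```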
```diff
@@ -107,10 +104,16 @@ def load_weights(
     # motion module
     unet_state_dict = {}
     if motion_module_path != "":
+        auto_download(motion_module_path, is_dreambooth_lora=False)
+
         print(f"load motion module from {motion_module_path}")
         motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
         motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
-        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+        # filter parameters
+        for name, param in motion_module_state_dict.items():
+            if not "motion_modules." in name: continue
+            if "pos_encoder.pe" in name: continue
+            unet_state_dict.update({name: param})
         unet_state_dict.pop("animatediff_config", "")
 
     missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
```
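The new loop also drops `pos_encoder.pe` entries: older motion-module checkpoints saved the positional table (shaped for 24 frames), which the model no longer stores now that the buffer is non-persistent and sized for 32, so those keys would only show up as noise in the unexpected-key report. For reference, a self-contained sketch of the equivalent filter (key names are illustrative):

```python
# Toy state dict standing in for a loaded motion-module checkpoint.
motion_module_state_dict = {
    "down_blocks.0.motion_modules.0.proj.weight": None,     # kept
    "down_blocks.0.motion_modules.0.pos_encoder.pe": None,  # dropped (stale table)
    "conv_in.weight": None,                                 # dropped (not a motion module)
}

# Equivalent to the commit's filtering loop, written as a comprehension.
unet_state_dict = {
    name: param
    for name, param in motion_module_state_dict.items()
    if "motion_modules." in name and "pos_encoder.pe" not in name
}
print(sorted(unet_state_dict))  # only the proj.weight entry survives
```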
```diff
@@ -119,6 +122,8 @@ def load_weights(
 
     # base model
     if dreambooth_model_path != "":
+        auto_download(dreambooth_model_path, is_dreambooth_lora=True)
+
         print(f"load dreambooth model from {dreambooth_model_path}")
         if dreambooth_model_path.endswith(".safetensors"):
             dreambooth_state_dict = {}
@@ -140,6 +145,8 @@ def load_weights(
 
     # lora layers
     if lora_model_path != "":
+        auto_download(lora_model_path, is_dreambooth_lora=True)
+
         print(f"load lora model from {lora_model_path}")
         assert lora_model_path.endswith(".safetensors")
         lora_state_dict = {}
@@ -152,6 +159,8 @@ def load_weights(
 
     # domain adapter lora
     if adapter_lora_path != "":
+        auto_download(adapter_lora_path, is_dreambooth_lora=False)
+
         print(f"load domain lora from {adapter_lora_path}")
         domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
         domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
@@ -162,6 +171,9 @@ def load_weights(
     # motion module lora
     for motion_module_lora_config in motion_module_lora_configs:
         path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+
+        auto_download(path, is_dreambooth_lora=False)
+
         print(f"load motion LoRA from {path}")
         motion_lora_state_dict = torch.load(path, map_location="cpu")
         motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
```

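Taken together, every checkpoint path handled by `load_weights` is now auto-downloaded on first use. A hedged end-to-end sketch (argument names as they appear in the hunks above; the paths and the assumption that the function returns the updated pipeline are illustrative):

```python
from animatediff.utils.util import load_weights

# Each missing file is pulled from the matching Hub repo before loading.
pipeline = load_weights(
    pipeline,  # an AnimationPipeline built beforehand; return value assumed
    motion_module_path="models/Motion_Module/v3_sd15_mm.ckpt",
    dreambooth_model_path="models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors",
    adapter_lora_path="models/Motion_Module/v3_sd15_adapter.ckpt",
)
```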