Basic Hunyuan Video model support.
comfyanonymous committed Dec 17, 2024
1 parent 19ee5d9 commit bda1482
Showing 18 changed files with 413,647 additions and 77 deletions.
13 changes: 10 additions & 3 deletions comfy/diffusers_convert.py
@@ -157,16 +157,23 @@ def convert_unet_state_dict(unet_state_dict):
 ]


-def reshape_weight_for_sd(w):
+def reshape_weight_for_sd(w, conv3d=False):
     # convert HF linear weights to SD conv2d weights
-    return w.reshape(*w.shape, 1, 1)
+    if conv3d:
+        return w.reshape(*w.shape, 1, 1, 1)
+    else:
+        return w.reshape(*w.shape, 1, 1)


 def convert_vae_state_dict(vae_state_dict):
     mapping = {k: k for k in vae_state_dict.keys()}
+    conv3d = False
     for k, v in mapping.items():
         for sd_part, hf_part in vae_conversion_map:
             v = v.replace(hf_part, sd_part)
+        if v.endswith(".conv.weight"):
+            if not conv3d and vae_state_dict[k].ndim == 5:
+                conv3d = True
         mapping[k] = v
     for k, v in mapping.items():
         if "attentions" in k:
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
                 logging.debug(f"Reshaping {k} for SD format")
-                new_state_dict[k] = reshape_weight_for_sd(v)
+                new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
     return new_state_dict


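For reference, a minimal standalone sketch of the updated helper: the reshape logic is copied from the hunk above, while the import, the sample weight size, and the printed shapes are illustrative assumptions, not part of the commit.

import torch

def reshape_weight_for_sd(w, conv3d=False):
    # A 2-D linear weight (out_features, in_features) gains trailing
    # singleton spatial dims so it can be loaded as a conv kernel.
    if conv3d:
        return w.reshape(*w.shape, 1, 1, 1)  # (out, in, 1, 1, 1) for a Conv3d VAE
    else:
        return w.reshape(*w.shape, 1, 1)     # (out, in, 1, 1) for a Conv2d VAE

w = torch.randn(512, 512)                            # hypothetical HF linear attention weight
print(reshape_weight_for_sd(w).shape)                # torch.Size([512, 512, 1, 1])
print(reshape_weight_for_sd(w, conv3d=True).shape)   # torch.Size([512, 512, 1, 1, 1])

The conv3d flag itself is derived once per state dict: if any ".conv.weight" tensor has 5 dimensions, the VAE uses 3D convolutions (a video VAE) and all reshaped attention weights follow suit.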
4 changes: 4 additions & 0 deletions comfy/latent_formats.py
@@ -352,3 +352,7 @@ def __init__(self):
         ]

         self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
+
+class HunyuanVideo(LatentFormat):
+    latent_channels = 16
+    scale_factor = 0.476986
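A simplified sketch of how the new format's numbers would typically be used. The process_in/process_out methods and the sample tensor shape below are assumptions for illustration, not code from comfy/latent_formats.py; only latent_channels and scale_factor come from the hunk above.

import torch

class HunyuanVideoLatentSketch:
    latent_channels = 16
    scale_factor = 0.476986

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        # Scale VAE-encoded latents before the diffusion model sees them.
        return latent * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        # Undo the scaling before handing latents back to the VAE decoder.
        return latent / self.scale_factor

fmt = HunyuanVideoLatentSketch()
x = torch.randn(1, fmt.latent_channels, 4, 32, 32)   # assumed (batch, C, T, H, W) video latent
assert torch.allclose(fmt.process_out(fmt.process_in(x)), x)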
26 changes: 18 additions & 8 deletions comfy/ldm/flux/layers.py
@@ -114,7 +114,7 @@ def forward(self, vec: Tensor) -> tuple:


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -141,6 +141,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             nn.GELU(approximate="tanh"),
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )
+        self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
         img_mod1, img_mod2 = self.img_mod(vec)
@@ -160,13 +161,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        if self.flipped_img_txt:
+            # run actual attention
+            attn = attention(torch.cat((img_q, txt_q), dim=2),
+                             torch.cat((img_k, txt_k), dim=2),
+                             torch.cat((img_v, txt_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
+        else:
+            # run actual attention
+            attn = attention(torch.cat((txt_q, img_q), dim=2),
+                             torch.cat((txt_k, img_k), dim=2),
+                             torch.cat((txt_v, img_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
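The two branches differ only in which modality leads the joint token sequence, which in turn dictates how the attention output is split back into image and text parts. A toy sketch of that bookkeeping follows; unlike the block above, it concatenates plain (batch, tokens, dim) tensors on dim=1 rather than per-head q/k/v tensors on dim=2, and the token counts are made up.

import torch

img_tokens, txt_tokens, dim = 6, 4, 8
img = torch.randn(1, img_tokens, dim)
txt = torch.randn(1, txt_tokens, dim)

# flipped_img_txt=True: image tokens come first in the joint sequence.
joint = torch.cat((img, txt), dim=1)
img_part, txt_part = joint[:, :img.shape[1]], joint[:, img.shape[1]:]
assert torch.equal(img_part, img) and torch.equal(txt_part, txt)

# flipped_img_txt=False (original Flux order): text tokens come first.
joint = torch.cat((txt, img), dim=1)
txt_part, img_part = joint[:, :txt.shape[1]], joint[:, txt.shape[1]:]
assert torch.equal(txt_part, txt) and torch.equal(img_part, img)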
(Diffs for the remaining 15 changed files are not shown.)
