Commit bda1482

Basic Hunyuan Video model support.
1 parent 19ee5d9 commit bda1482

18 files changed (+413647 -77)

comfy/diffusers_convert.py (+10 -3)
@@ -157,16 +157,23 @@ def convert_unet_state_dict(unet_state_dict):
 ]
 
 
-def reshape_weight_for_sd(w):
+def reshape_weight_for_sd(w, conv3d=False):
     # convert HF linear weights to SD conv2d weights
-    return w.reshape(*w.shape, 1, 1)
+    if conv3d:
+        return w.reshape(*w.shape, 1, 1, 1)
+    else:
+        return w.reshape(*w.shape, 1, 1)
 
 
 def convert_vae_state_dict(vae_state_dict):
     mapping = {k: k for k in vae_state_dict.keys()}
+    conv3d = False
     for k, v in mapping.items():
         for sd_part, hf_part in vae_conversion_map:
             v = v.replace(hf_part, sd_part)
+        if v.endswith(".conv.weight"):
+            if not conv3d and vae_state_dict[k].ndim == 5:
+                conv3d = True
         mapping[k] = v
     for k, v in mapping.items():
         if "attentions" in k:
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
                 logging.debug(f"Reshaping {k} for SD format")
-                new_state_dict[k] = reshape_weight_for_sd(v)
+                new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
     return new_state_dict
 
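The conv3d flag is set when a ".conv.weight" tensor is 5-D, i.e. the VAE uses 3D (video) convolutions, and it only changes how many trailing singleton dimensions the reshape appends. A standalone sketch of that reshape, not part of the commit; the 512x512 weight shape is just an illustrative assumption:

import torch

def reshape_weight_for_sd(w, conv3d=False):
    # convert HF linear weights to SD conv2d (or conv3d) weights
    if conv3d:
        return w.reshape(*w.shape, 1, 1, 1)
    return w.reshape(*w.shape, 1, 1)

w = torch.randn(512, 512)                           # HF linear attention weight (assumed shape)
print(reshape_weight_for_sd(w).shape)               # torch.Size([512, 512, 1, 1])    -> 2D conv kernel
print(reshape_weight_for_sd(w, conv3d=True).shape)  # torch.Size([512, 512, 1, 1, 1]) -> 3D conv kernel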

comfy/latent_formats.py (+4)
@@ -352,3 +352,7 @@ def __init__(self):
         ]
 
         self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
+
+class HunyuanVideo(LatentFormat):
+    latent_channels = 16
+    scale_factor = 0.476986
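The new HunyuanVideo format only declares a channel count and a scale factor; LatentFormat subclasses are typically consumed by multiplying latents by scale_factor on the way into the sampler and dividing on the way out. A minimal sketch of that convention (assumed from the usual process_in/process_out pattern, not code added by this commit; the tensor layout is an assumption):

import torch

class HunyuanVideoLatentSketch:
    # mirrors the values added in the diff above
    latent_channels = 16
    scale_factor = 0.476986

    def process_in(self, latent):
        # scale raw VAE latents into the range the diffusion model expects
        return latent * self.scale_factor

    def process_out(self, latent):
        # undo the scaling before handing latents back to the VAE decoder
        return latent / self.scale_factor

fmt = HunyuanVideoLatentSketch()
x = torch.randn(1, fmt.latent_channels, 4, 32, 32)  # (batch, channels, frames, h, w) -- assumed layout
assert torch.allclose(fmt.process_out(fmt.process_in(x)), x)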

comfy/ldm/flux/layers.py (+18 -8)
@@ -114,7 +114,7 @@ def forward(self, vec: Tensor) -> tuple:
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -141,6 +141,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             nn.GELU(approximate="tanh"),
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )
+        self.flipped_img_txt = flipped_img_txt
 
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
         img_mod1, img_mod2 = self.img_mod(vec)
@@ -160,13 +161,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        if self.flipped_img_txt:
+            # run actual attention
+            attn = attention(torch.cat((img_q, txt_q), dim=2),
+                             torch.cat((img_k, txt_k), dim=2),
+                             torch.cat((img_v, txt_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
+        else:
+            # run actual attention
+            attn = attention(torch.cat((txt_q, img_q), dim=2),
+                             torch.cat((txt_k, img_k), dim=2),
+                             torch.cat((txt_v, img_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
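The flipped_img_txt branch only changes which modality comes first in the joint attention sequence (image tokens first on the new path, text tokens first on the original Flux path), so the split of the attention output has to match that order. A toy illustration with plain scaled_dot_product_attention standing in for the attention() helper; the shapes, the joint_attention name, and the standalone flag are assumptions, not the real model code:

import torch
import torch.nn.functional as F

B, H, L_img, L_txt, D = 1, 4, 16, 8, 64
img_q, img_k, img_v = (torch.randn(B, H, L_img, D) for _ in range(3))
txt_q, txt_k, txt_v = (torch.randn(B, H, L_txt, D) for _ in range(3))

def joint_attention(q, k, v):
    # stand-in for the attention() helper: fuse heads -> (batch, seq, heads * head_dim)
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(q.shape[0], -1, q.shape[1] * q.shape[3])

flipped_img_txt = True  # image tokens first in the joint sequence
if flipped_img_txt:
    attn = joint_attention(torch.cat((img_q, txt_q), dim=2),
                           torch.cat((img_k, txt_k), dim=2),
                           torch.cat((img_v, txt_v), dim=2))
    img_attn, txt_attn = attn[:, :L_img], attn[:, L_img:]
else:
    attn = joint_attention(torch.cat((txt_q, img_q), dim=2),
                           torch.cat((txt_k, img_k), dim=2),
                           torch.cat((txt_v, img_v), dim=2))
    txt_attn, img_attn = attn[:, :L_txt], attn[:, L_txt:]

print(img_attn.shape, txt_attn.shape)  # torch.Size([1, 16, 256]) torch.Size([1, 8, 256])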
