Skip to content

Commit 5cbb01b

Browse files
Basic Genmo Mochi video model support.
To use: a "Load CLIP" node with t5xxl and type mochi; a "Load Diffusion Model" node with the mochi dit file; a "Load VAE" node with the mochi vae file; an EmptyMochiLatentVideo node for the latent; and euler + linear_quadratic in the KSampler node.
1 parent c3ffbae commit 5cbb01b

18 files changed

+1677
-24
lines changed

comfy/latent_formats.py

+27
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,30 @@ def process_in(self, latent):
175175

176176
def process_out(self, latent):
177177
return (latent / self.scale_factor) + self.shift_factor
178+
179+
class Mochi(LatentFormat):
    """Latent format for the Genmo Mochi video model.

    Normalizes with per-channel mean/std statistics rather than a single
    scale factor; ``scale_factor`` stays at 1.0 for parity with the other
    formats' interface.
    """
    latent_channels = 12

    def __init__(self):
        self.scale_factor = 1.0
        # Per-channel statistics of the Mochi VAE latent distribution,
        # shaped to broadcast over (batch, channel, time, height, width).
        mean_values = [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
                       -0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
                       0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
                       -0.008729793427399178, -0.011931556316503654, -0.0321993391887285]
        std_values = [0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
                      0.959253732819592, 0.8244560132752793, 0.917259975397747,
                      0.9294154431013696, 1.3720942357788521, 0.881393668867029,
                      0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
        broadcast_shape = (1, self.latent_channels, 1, 1, 1)
        self.latents_mean = torch.tensor(mean_values).view(broadcast_shape)
        self.latents_std = torch.tensor(std_values).view(broadcast_shape)

        self.latent_rgb_factors = None #TODO
        self.taesd_decoder_name = None #TODO

    def _stats_like(self, latent):
        """Return (mean, std) moved to the latent's device and dtype."""
        mean = self.latents_mean.to(latent.device, latent.dtype)
        std = self.latents_std.to(latent.device, latent.dtype)
        return mean, std

    def process_in(self, latent):
        """Normalize a VAE-space latent into model space."""
        mean, std = self._stats_like(latent)
        return (latent - mean) * self.scale_factor / std

    def process_out(self, latent):
        """Invert ``process_in``: map a model-space latent back to VAE space."""
        mean, std = self._stats_like(latent)
        return latent * std / self.scale_factor + mean

comfy/ldm/common_dit.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
1313
except:
1414
rms_norm_torch = None
1515

16-
def rms_norm(x, weight=None, eps=1e-6):
    """RMS-normalize ``x`` over its last dimension.

    Uses the fused torch kernel when it is available and we are not under
    torch.jit tracing/scripting; otherwise falls back to a manual
    rsqrt-of-mean-square implementation. ``weight`` is optional: when given
    it is cast to ``x``'s dtype/device and applied as an elementwise gain.
    """
    fused_ok = rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting())
    if fused_ok:
        if weight is None:
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        gain = comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
        return rms_norm_torch(x, weight.shape, weight=gain, eps=eps)
    # Manual fallback: x / sqrt(mean(x^2) + eps), then optional gain.
    normed = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    if weight is None:
        return normed
    return normed * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

0 commit comments

Comments (0)