From 418eb7062dcbd5c68f869527a7bc34cff55ca87e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 20 Dec 2024 04:38:29 -0500 Subject: [PATCH] Support new LTXV VAE. --- .../vae/causal_video_autoencoder.py | 297 +++++++++++++++--- comfy/sd.py | 8 +- 2 files changed, 260 insertions(+), 45 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 3bd59a76..4d43feb2 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -6,7 +6,9 @@ from typing import Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd from .pixel_norm import PixelNorm - +from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings +import comfy.ops +ops = comfy.ops.disable_weight_init class Encoder(nn.Module): r""" @@ -236,6 +238,7 @@ def __init__( patch_size: int = 1, norm_layer: str = "group_norm", causal: bool = True, + timestep_conditioning: bool = False, ): super().__init__() self.patch_size = patch_size @@ -250,6 +253,8 @@ def __init__( block_params = block_params if isinstance(block_params, dict) else {} if block_name == "res_x_y": output_channel = output_channel * block_params.get("multiplier", 2) + if block_name == "compress_all": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -276,6 +281,19 @@ def __init__( resnet_eps=1e-6, resnet_groups=norm_num_groups, norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) @@ -286,6 +304,8 @@ def __init__( eps=1e-6, groups=norm_num_groups, norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, ) elif block_name == "compress_time": block = DepthToSpaceUpsample( @@ -296,11 +316,13 @@ def __init__( dims=dims, in_channels=input_channel, stride=(1, 2, 2) ) elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 2, 2), residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), ) else: raise ValueError(f"unknown layer: {block_name}") @@ -323,27 +345,75 @@ def __init__( self.gradient_checkpointing = False + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0, operations=ops, + ) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + def forward( + self, + sample: torch.FloatTensor, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" - # assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] sample = self.conv_in(sample, causal=self.causal) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - checkpoint_fn = ( partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) if self.gradient_checkpointing and self.training else lambda x: x ) - sample = sample.to(upscale_dtype) + scaled_timestep = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier for up_block in self.up_blocks: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) sample = self.conv_out(sample, causal=self.causal) @@ -379,12 +449,21 @@ def __init__( resnet_eps: float = 1e-6, resnet_groups: int = 32, norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, ): super().__init__() resnet_groups = ( resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0, operations=ops, + ) + self.res_blocks = nn.ModuleList( [ ResnetBlock3D( @@ -395,25 +474,48 @@ def __init__( groups=resnet_groups, dropout=dropout, norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) for _ in range(num_layers) ] ) def forward( - self, hidden_states: torch.FloatTensor, causal: bool = True + self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + for resnet in self.res_blocks: - hidden_states = resnet(hidden_states, causal=causal) + hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed) return hidden_states class DepthToSpaceUpsample(nn.Module): - def __init__(self, dims, in_channels, stride, residual=False): + def __init__( + self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1 + ): super().__init__() self.stride = stride - self.out_channels = math.prod(stride) * in_channels + self.out_channels = ( + math.prod(stride) * in_channels // out_channels_reduction_factor + ) self.conv = make_conv_nd( dims=dims, in_channels=in_channels, @@ -423,8 +525,9 @@ def __init__(self, dims, in_channels, stride, residual=False): causal=True, ) self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor - def forward(self, x, causal: bool = True): + def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): if self.residual: # Reshape and duplicate the input to match the output shape x_in = rearrange( @@ -434,7 +537,8 @@ def forward(self, x, causal: bool = True): p2=self.stride[1], p3=self.stride[2], ) - x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) if self.stride[0] == 2: x_in = x_in[:, :, 1:, :, :] x = self.conv(x, causal=causal) @@ -451,7 +555,6 @@ def forward(self, x, causal: bool = True): x = x + x_in return x - class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: super().__init__() @@ -486,11 +589,14 @@ def __init__( groups: int = 32, eps: float = 1e-6, norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels + self.inject_noise = inject_noise if norm_layer == "group_norm": self.norm1 = nn.GroupNorm( @@ -513,6 +619,9 @@ def __init__( causal=True, ) + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + if norm_layer == "group_norm": self.norm2 = nn.GroupNorm( num_groups=groups, num_channels=out_channels, eps=eps, affine=True @@ -534,6 +643,9 @@ def __init__( causal=True, ) + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + self.conv_shortcut = ( make_linear_nd( dims=dims, in_channels=in_channels, out_channels=out_channels @@ -548,29 +660,84 @@ def __init__( else nn.Identity() ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + def forward( self, input_tensor: torch.FloatTensor, causal: bool = True, + timestep: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: hidden_states = input_tensor + batch_size = hidden_states.shape[0] hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 hidden_states = self.non_linearity(hidden_states) hidden_states = self.conv1(hidden_states, causal=causal) + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + hidden_states = self.norm2(hidden_states) + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + hidden_states = self.non_linearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, causal=causal) + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + input_tensor = self.norm3(input_tensor) + batch_size = input_tensor.shape[0] + input_tensor = self.conv_shortcut(input_tensor) output_tensor = input_tensor + hidden_states @@ -634,33 +801,71 @@ def normalize(self, x): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): - def __init__(self): + def __init__(self, version=0): super().__init__() - config = { - "_class_name": "CausalVideoAutoencoder", - "dims": 3, - "in_channels": 3, - "out_channels": 3, - "latent_channels": 128, - "blocks": [ - ["res_x", 4], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x_y", 1], - ["res_x", 3], - ["compress_all", 1], - ["res_x", 3], - ["res_x", 4], - ], - "scaling_factor": 1.0, - "norm_layer": "pixel_norm", - "patch_size": 4, - "latent_log_var": "uniform", - "use_quant_conv": False, - "causal_decoder": False, - } + + if version == 0: + config = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + } + else: + config = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "decoder_blocks": [ + ["res_x", {"num_layers": 5, "inject_noise": True}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 6, "inject_noise": True}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 7, "inject_noise": True}], + ["compress_all", {"residual": True, "multiplier": 2}], + ["res_x", {"num_layers": 8, "inject_noise": False}] + ], + "encoder_blocks": [ + ["res_x", {"num_layers": 4}], + ["compress_all", {}], + ["res_x_y", 1], + ["res_x", {"num_layers": 3}], + ["compress_all", {}], + ["res_x_y", 1], + ["res_x", {"num_layers": 3}], + ["compress_all", {}], + ["res_x", {"num_layers": 3}], + ["res_x", {"num_layers": 4}] + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + } double_z = config.get("double_z", True) latent_log_var = config.get( @@ -671,7 +876,7 @@ def __init__(self): dims=config["dims"], in_channels=config.get("in_channels", 3), out_channels=config["latent_channels"], - blocks=config.get("encoder_blocks", config.get("blocks")), + blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))), patch_size=config.get("patch_size", 1), latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), @@ -681,18 +886,22 @@ def __init__(self): dims=config["dims"], in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), - blocks=config.get("decoder_blocks", config.get("blocks")), + blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), ) + self.timestep_conditioning = config.get("timestep_conditioning", False) self.per_channel_statistics = processor() def encode(self, x): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x): - return self.decoder(self.per_channel_statistics.un_normalize(x)) + def decode(self, x, timestep=0.05, noise_scale=0.025): + if self.timestep_conditioning: #TODO: seed + x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) diff --git a/comfy/sd.py b/comfy/sd.py index b5cf296c..dee8e984 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -340,7 +340,13 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv - self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE() + tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"] + version = 0 + if tensor_conv1.shape[0] == 512: + version = 0 + elif tensor_conv1.shape[0] == 1024: + version = 1 + self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version) self.latent_channels = 128 self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)