Skip to content

Commit

Permalink
Make VAEDecodeTiled node work with video VAEs.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 7, 2024
1 parent 5e29e7a commit b49616f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
21 changes: 18 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,25 @@ def decode(self, samples_in):
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples

def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap

if dims == 1:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)

def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
Expand Down
12 changes: 9 additions & 3 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,21 @@ class VAEDecodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"

CATEGORY = "_for_testing"

def decode(self, vae, samples, tile_size):
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
def decode(self, vae, samples, tile_size, overlap):
if tile_size < overlap * 4:
overlap = tile_size // 4
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )

class VAEEncode:
@classmethod
Expand Down

0 comments on commit b49616f

Please sign in to comment.