diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 3f3e7072..2984dbf7 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,14 +1,6 @@ -import numpy as np from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * -from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import ( - do_img2img, - do_sample, - get_unique_embedder_keys_from_conditioner, - perform_save_locally, -) SAVE_PATH = "outputs/demo/txt2img/" @@ -42,7 +34,16 @@ } VERSION2SPECS = { - "SD-XL base": { + "SDXL-base-1.0": { + "H": 1024, + "W": 1024, + "C": 4, + "f": 8, + "is_legacy": False, + "config": "configs/inference/sd_xl_base.yaml", + "ckpt": "checkpoints/sd_xl_base_1.0.safetensors", + }, + "SDXL-base-0.9": { "H": 1024, "W": 1024, "C": 4, @@ -50,9 +51,8 @@ "is_legacy": False, "config": "configs/inference/sd_xl_base.yaml", "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", - "is_guided": True, }, - "sd-2.1": { + "SD-2.1": { "H": 512, "W": 512, "C": 4, @@ -60,9 +60,8 @@ "is_legacy": True, "config": "configs/inference/sd_2_1.yaml", "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", - "is_guided": True, }, - "sd-2.1-768": { + "SD-2.1-768": { "H": 768, "W": 768, "C": 4, @@ -71,7 +70,7 @@ "config": "configs/inference/sd_2_1_768.yaml", "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", }, - "SDXL-Refiner": { + "SDXL-refiner-0.9": { "H": 1024, "W": 1024, "C": 4, @@ -79,7 +78,15 @@ "is_legacy": True, "config": "configs/inference/sd_xl_refiner.yaml", "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", - "is_guided": True, + }, + "SDXL-refiner-1.0": { + "H": 1024, + "W": 1024, + "C": 4, + "f": 8, + "is_legacy": True, + "config": "configs/inference/sd_xl_refiner.yaml", + "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors", }, } @@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"): def run_txt2img( - state, version, version_dict, is_legacy=False, return_latents=False, filter=None + state, + version, + version_dict, + is_legacy=False, + return_latents=False, + filter=None, + stage2strength=None, ): - if version == "SD-XL base": - ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10) - W, H = SD_XL_BASE_RATIOS[ratio] + if version.startswith("SDXL-base"): + W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) else: - H = st.sidebar.number_input( - "H", value=version_dict["H"], min_value=64, max_value=2048 - ) - W = st.sidebar.number_input( - "W", value=version_dict["W"], min_value=64, max_value=2048 - ) + H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048) + W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048) C = version_dict["C"] F = version_dict["f"] @@ -130,16 +138,11 @@ def run_txt2img( prompt=prompt, negative_prompt=negative_prompt, ) - num_rows, num_cols, sampler = init_sampling( - use_identity_guider=not version_dict["is_guided"] - ) - + sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength) num_samples = num_rows * num_cols if st.button("Sample"): st.write(f"**Model I:** {version}") - outputs = st.empty() - st.text("Sampling") out = do_sample( state["model"], sampler, @@ -153,13 +156,16 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) - show_samples(out, outputs) - return out def run_img2img( - state, version_dict, is_legacy=False, return_latents=False, filter=None + state, + version_dict, + is_legacy=False, + return_latents=False, + filter=None, + stage2strength=None, ): img = load_img() if img is None: @@ -175,19 +181,19 @@ def run_img2img( value_dict = init_embedder_options( get_unique_embedder_keys_from_conditioner(state["model"].conditioner), init_dict, + prompt=prompt, + negative_prompt=negative_prompt, ) strength = st.number_input( - "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0 + "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0 ) - num_rows, num_cols, sampler = init_sampling( + sampler, num_rows, num_cols = init_sampling( img2img_strength=strength, - use_identity_guider=not version_dict["is_guided"], + stage2strength=stage2strength, ) num_samples = num_rows * num_cols if st.button("Sample"): - outputs = st.empty() - st.text("Sampling") out = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], @@ -198,7 +204,6 @@ def run_img2img( return_latents=return_latents, filter=filter, ) - show_samples(out, outputs) return out @@ -210,6 +215,7 @@ def apply_refiner( prompt, negative_prompt, filter=None, + finish_denoising=False, ): init_dict = { "orig_width": input.shape[3] * 8, @@ -237,6 +243,7 @@ def apply_refiner( num_samples, skip_encode=True, filter=filter, + add_noise=not finish_denoising, ) return samples @@ -249,20 +256,22 @@ def apply_refiner( mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") - if version == "SD-XL base": - add_pipeline = st.checkbox("Load SDXL-Refiner?", False) + set_lowvram_mode(st.checkbox("Low vram mode", True)) + + if version.startswith("SDXL-base"): + add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: add_pipeline = False - filter = DeepFloydDataFiltering(verbose=False) - seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) seed_everything(seed) save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - state = init_st(version_dict) + state = init_st(version_dict, load_filter=True) + if state["msg"]: + st.info(state["msg"]) model = state["model"] is_legacy = version_dict["is_legacy"] @@ -276,29 +285,34 @@ def apply_refiner( else: negative_prompt = "" # which is unused + stage2strength = None + finish_denoising = False + if add_pipeline: st.write("__________________________") - - version2 = "SDXL-Refiner" + version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"]) st.warning( f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) " ) st.write("**Refiner Options:**") version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2) + state2 = init_st(version_dict2, load_filter=False) + st.info(state2["msg"]) stage2strength = st.number_input( - "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0 + "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0 ) - sampler2 = init_sampling( + sampler2, *_ = init_sampling( key=2, img2img_strength=stage2strength, - use_identity_guider=not version_dict2["is_guided"], - get_num_samples=False, + specify_num_samples=False, ) st.write("__________________________") + finish_denoising = st.checkbox("Finish denoising with refiner.", True) + if not finish_denoising: + stage2strength = None if mode == "txt2img": out = run_txt2img( @@ -307,7 +321,8 @@ def apply_refiner( version_dict, is_legacy=is_legacy, return_latents=add_pipeline, - filter=filter, + filter=state.get("filter"), + stage2strength=stage2strength, ) elif mode == "img2img": out = run_img2img( @@ -315,7 +330,8 @@ def apply_refiner( version_dict, is_legacy=is_legacy, return_latents=add_pipeline, - filter=filter, + filter=state.get("filter"), + stage2strength=stage2strength, ) else: raise ValueError(f"unknown mode {mode}") @@ -326,7 +342,6 @@ def apply_refiner( samples_z = None if add_pipeline and samples_z is not None: - outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( samples_z, @@ -335,9 +350,9 @@ def apply_refiner( samples_z.shape[0], prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", - filter=filter, + filter=state.get("filter"), + finish_denoising=finish_denoising, ) - show_samples(samples, outputs) if save_locally and samples is not None: perform_save_locally(save_path, samples) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 4b752a7a..82b7fb9c 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,12 +1,20 @@ +import math import os +from typing import List, Union + +import numpy as np import streamlit as st import torch -from PIL import Image from einops import rearrange, repeat -from omegaconf import OmegaConf +from imwatermark import WatermarkEncoder +from omegaconf import ListConfig, OmegaConf +from PIL import Image +from safetensors.torch import load_file as load_safetensors +from torch import autocast from torchvision import transforms +from torchvision.utils import make_grid - +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, DPMPP2SAncestralSampler, @@ -15,29 +23,140 @@ HeunEDMSampler, LinearMultistepSampler, ) -from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark -from sgm.util import load_model_from_config +from sgm.util import append_dims, instantiate_from_config + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor): + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, C, H, W) in range [0, 1] + + Returns: + same as input but watermarked + """ + # watermarking libary expects input as cv2 BGR format + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy( + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + return image + + +# A fixed 48-bit message that was choosen at random +# WATERMARK_MESSAGE = 0xB3EC907BB19E +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watemark = WatermarkEmbedder(WATERMARK_BITS) @st.cache_resource() -def init_st(version_dict, load_ckpt=True): +def init_st(version_dict, load_ckpt=True, load_filter=True): state = dict() if not "model" in state: config = version_dict["config"] ckpt = version_dict["ckpt"] config = OmegaConf.load(config) - model = load_model_from_config(config, ckpt if load_ckpt else None) - model = model.to("cuda") - model.conditioner.half() - model.model.half() + model, msg = load_model_from_config(config, ckpt if load_ckpt else None) + state["msg"] = msg state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config + if load_filter: + state["filter"] = DeepFloydDataFiltering(verbose=False) return state +def load_model(model): + model.cuda() + + +lowvram_mode = False + + +def set_lowvram_mode(mode): + global lowvram_mode + lowvram_mode = mode + + +def initial_model_load(model): + global lowvram_mode + if lowvram_mode: + model.model.half() + else: + model.cuda() + return model + + +def unload_model(model): + global lowvram_mode + if lowvram_mode: + model.cpu() + torch.cuda.empty_cache() + + +def load_model_from_config(config, ckpt=None, verbose=True): + model = instantiate_from_config(config.model) + + if ckpt is not None: + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + global_step = pl_sd["global_step"] + st.info(f"loaded ckpt from global step {global_step}") + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + msg = None + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + else: + msg = None + + model = initial_model_load(model) + model.eval() + return model, msg + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): # Hardcoded demo settings; might undergo some changes in the future @@ -81,23 +200,24 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): value_dict["negative_aesthetic_score"] = 2.5 if key == "target_size_as_tuple": - target_width = st.number_input( - "target_width", - value=init_dict["target_width"], - min_value=16, - ) - target_height = st.number_input( - "target_height", - value=init_dict["target_height"], - min_value=16, - ) - - value_dict["target_width"] = target_width - value_dict["target_height"] = target_height + value_dict["target_width"] = init_dict["target_width"] + value_dict["target_height"] = init_dict["target_height"] return value_dict +def perform_save_locally(save_path, samples): + os.makedirs(os.path.join(save_path), exist_ok=True) + base_count = len(os.listdir(os.path.join(save_path))) + samples = embed_watemark(samples) + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + Image.fromarray(sample.astype(np.uint8)).save( + os.path.join(save_path, f"{base_count:09}.png") + ) + base_count += 1 + + def init_save_locally(_dir, init_value: bool = False): save_locally = st.sidebar.checkbox("Save images locally", value=init_value) if save_locally: @@ -108,12 +228,58 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path -def show_samples(samples, outputs): - if isinstance(samples, tuple): - samples, _ = samples - grid = embed_watermark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) +class Img2ImgDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 1.0): + self.discretization = discretization + self.strength = strength + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] + print("prune index:", max(int(self.strength * len(sigmas)), 1)) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + +class Txt2NoisyDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 0.0, original_steps=None): + self.discretization = discretization + self.strength = strength + self.original_steps = original_steps + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + if self.original_steps is None: + steps = len(sigmas) + else: + steps = self.original_steps + 1 + prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) + sigmas = sigmas[prune_index:] + print("prune index:", prune_index) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas def get_guider(key): @@ -158,16 +324,19 @@ def get_guider(key): def init_sampling( - key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True + key=1, + img2img_strength=1.0, + specify_num_samples=True, + stage2strength=None, ): - if get_num_samples: - num_rows = 1 + num_rows, num_cols = 1, 1 + if specify_num_samples: num_cols = st.number_input( f"num cols #{key}", value=2, min_value=1, max_value=10 ) steps = st.sidebar.number_input( - f"steps #{key}", value=50, min_value=1, max_value=1000 + f"steps #{key}", value=40, min_value=1, max_value=1000 ) sampler = st.sidebar.selectbox( f"Sampler #{key}", @@ -201,9 +370,11 @@ def init_sampling( sampler.discretization = Img2ImgDiscretizationWrapper( sampler.discretization, strength=img2img_strength ) - if get_num_samples: - return num_rows, num_cols, sampler - return sampler + if stage2strength is not None: + sampler.discretization = Txt2NoisyDiscretizationWrapper( + sampler.discretization, strength=stage2strength, original_steps=steps + ) + return sampler, num_rows, num_cols def get_discretization(discretization, key=1): @@ -336,3 +507,238 @@ def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) return init_image + + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: List = None, + batch2model_input: List = None, + return_latents=False, + filter=None, +): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + if batch2model_input is None: + batch2model_input = [] + + st.text("Sampling") + + outputs = st.empty() + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + num_samples = [num_samples] + load_model(model.conditioner) + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + unload_model(model.conditioner) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to("cuda") + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + load_model(model.denoiser) + load_model(model.model) + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + unload_model(model.model) + unload_model(model.denoiser) + + load_model(model.first_stage_model) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + unload_model(model.first_stage_model) + + if filter is not None: + samples = filter(samples) + + grid = torch.stack([samples]) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) + + if return_latents: + return samples, samples_z + return samples + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +@torch.no_grad() +def do_img2img( + img, + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=[], + additional_kwargs={}, + offset_noise_level: int = 0.0, + return_latents=False, + skip_encode=False, + filter=None, + add_noise=True, +): + st.text("Sampling") + + outputs = st.empty() + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + load_model(model.conditioner) + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + unload_model(model.conditioner) + for k in c: + c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) + + for k in additional_kwargs: + c[k] = uc[k] = additional_kwargs[k] + if skip_encode: + z = img + else: + load_model(model.first_stage_model) + z = model.encode_first_stage(img) + unload_model(model.first_stage_model) + + noise = torch.randn_like(z) + + sigmas = sampler.discretization(sampler.num_steps).cuda() + sigma = sigmas[0] + + st.info(f"all sigmas: {sigmas}") + st.info(f"noising sigma: {sigma}") + if offset_noise_level > 0.0: + noise = noise + offset_noise_level * append_dims( + torch.randn(z.shape[0], device=z.device), z.ndim + ) + if add_noise: + noised_z = z + noise * append_dims(sigma, z.ndim).cuda() + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + else: + noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) + + def denoiser(x, sigma, c): + return model.denoiser(model.model, x, sigma, c) + + load_model(model.denoiser) + load_model(model.model) + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + unload_model(model.model) + unload_model(model.denoiser) + + load_model(model.first_stage_model) + samples_x = model.decode_first_stage(samples_z) + unload_model(model.first_stage_model) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + grid = embed_watemark(torch.stack([samples])) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) + if return_latents: + return samples, samples_z + return samples