diff --git a/.github/workflows/CODEOWNERS b/.github/workflows/CODEOWNERS new file mode 100644 index 00000000..ae8a7a26 --- /dev/null +++ b/.github/workflows/CODEOWNERS @@ -0,0 +1 @@ +.github @Stability-AI/infrastructure \ No newline at end of file diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 80823b44..ab652601 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -1,5 +1,5 @@ name: Run black -on: [push, pull_request] +on: [pull_request] jobs: lint: diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml index ffbeff46..8aabe376 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/test-build.yaml @@ -2,6 +2,7 @@ name: Build package on: push: + branches: [ main ] pull_request: jobs: diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml new file mode 100644 index 00000000..9687d7ea --- /dev/null +++ b/.github/workflows/test-inference.yml @@ -0,0 +1,34 @@ +name: Test inference + +on: + pull_request: + push: + branches: + - main + +jobs: + test: + name: "Test inference" + # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment + if: github.repository == 'stability-ai/generative-models' + runs-on: [self-hosted, slurm, g40] + steps: + - uses: actions/checkout@v3 + - name: "Symlink checkpoints" + run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints + - name: "Setup python" + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: "Install Hatch" + run: pip install hatch + - name: "Run inference tests" + run: hatch run ci:test-inference --junit-xml test-results.xml + - name: Surface failing tests + if: always() + uses: pmeier/pytest-results-action@main + with: + path: test-results.xml + summary: true + display-options: fEX + fail-on-empty: true diff --git a/pyproject.toml b/pyproject.toml index 3b790a8a..93d00c8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,3 +32,17 @@ include = [ [tool.hatch.build.targets.wheel.force-include] "./configs" = "sgm/configs" + +[tool.hatch.envs.ci] +skip-install = false + +dependencies = [ + "pytest" +] + +[tool.hatch.envs.ci.scripts] +test-inference = [ + "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", + "pip install -r requirements/pt2.txt", + "pytest -v tests/inference/test_inference.py {args}", +] \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..d79bd9b2 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') \ No newline at end of file diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 87d80155..3f3e7072 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,7 +1,14 @@ +import numpy as np from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import ( + do_img2img, + do_sample, + get_unique_embedder_keys_from_conditioner, + perform_save_locally, +) SAVE_PATH = "outputs/demo/txt2img/" @@ -131,6 +138,8 @@ def run_txt2img( if st.button("Sample"): st.write(f"**Model I:** {version}") + outputs = st.empty() + st.text("Sampling") out = do_sample( state["model"], sampler, @@ -144,6 +153,8 @@ def run_txt2img( return_latents=return_latents, filter=filter, ) + show_samples(out, outputs) + return out @@ -175,6 +186,8 @@ def run_img2img( num_samples = num_rows * num_cols if st.button("Sample"): + outputs = st.empty() + st.text("Sampling") out = do_img2img( repeat(img, "1 ... -> n ...", n=num_samples), state["model"], @@ -185,6 +198,7 @@ def run_img2img( return_latents=return_latents, filter=filter, ) + show_samples(out, outputs) return out @@ -249,8 +263,6 @@ def apply_refiner( save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) state = init_st(version_dict) - if state["msg"]: - st.info(state["msg"]) model = state["model"] is_legacy = version_dict["is_legacy"] @@ -275,7 +287,6 @@ def apply_refiner( version_dict2 = VERSION2SPECS[version2] state2 = init_st(version_dict2) - st.info(state2["msg"]) stage2strength = st.number_input( "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0 @@ -315,6 +326,7 @@ def apply_refiner( samples_z = None if add_pipeline and samples_z is not None: + outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( samples_z, @@ -325,6 +337,7 @@ def apply_refiner( negative_prompt=negative_prompt if is_legacy else "", filter=filter, ) + show_samples(samples, outputs) if save_locally and samples is not None: perform_save_locally(save_path, samples) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 2cf165b6..4b752a7a 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,18 +1,11 @@ -import math import os -from typing import List, Union - -import numpy as np import streamlit as st import torch -from einops import rearrange, repeat -from imwatermark import WatermarkEncoder -from omegaconf import ListConfig, OmegaConf from PIL import Image -from safetensors.torch import load_file as load_safetensors -from torch import autocast +from einops import rearrange, repeat +from omegaconf import OmegaConf from torchvision import transforms -from torchvision.utils import make_grid + from sgm.modules.diffusionmodules.sampling import ( DPMPP2MSampler, @@ -22,52 +15,8 @@ HeunEDMSampler, LinearMultistepSampler, ) -from sgm.util import append_dims, instantiate_from_config - - -class WatermarkEmbedder: - def __init__(self, watermark): - self.watermark = watermark - self.num_bits = len(WATERMARK_BITS) - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def __call__(self, image: torch.Tensor): - """ - Adds a predefined watermark to the input image - - Args: - image: ([N,] B, C, H, W) in range [0, 1] - - Returns: - same as input but watermarked - """ - # watermarking libary expects input as cv2 BGR format - squeeze = len(image.shape) == 4 - if squeeze: - image = image[None, ...] - n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] - # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] - for k in range(image_np.shape[0]): - image_np[k] = self.encoder.encode(image_np[k], "dwtDct") - image = torch.from_numpy( - rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) - ).to(image.device) - image = torch.clamp(image / 255, min=0.0, max=1.0) - if squeeze: - image = image[0] - return image - - -# A fixed 48-bit message that was choosen at random -# WATERMARK_MESSAGE = 0xB3EC907BB19E -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watemark = WatermarkEmbedder(WATERMARK_BITS) +from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark +from sgm.util import load_model_from_config @st.cache_resource() @@ -78,54 +27,17 @@ def init_st(version_dict, load_ckpt=True): ckpt = version_dict["ckpt"] config = OmegaConf.load(config) - model, msg = load_model_from_config(config, ckpt if load_ckpt else None) + model = load_model_from_config(config, ckpt if load_ckpt else None) + model = model.to("cuda") + model.conditioner.half() + model.model.half() - state["msg"] = msg state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config return state -def load_model_from_config(config, ckpt=None, verbose=True): - model = instantiate_from_config(config.model) - - if ckpt is not None: - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - global_step = pl_sd["global_step"] - st.info(f"loaded ckpt from global step {global_step}") - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - msg = None - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - else: - msg = None - - model.cuda() - model.eval() - return model, msg - - -def get_unique_embedder_keys_from_conditioner(conditioner): - return list(set([x.input_key for x in conditioner.embedders])) - - def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): # Hardcoded demo settings; might undergo some changes in the future @@ -186,18 +98,6 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): return value_dict -def perform_save_locally(save_path, samples): - os.makedirs(os.path.join(save_path), exist_ok=True) - base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watemark(samples) - for sample in samples: - sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") - Image.fromarray(sample.astype(np.uint8)).save( - os.path.join(save_path, f"{base_count:09}.png") - ) - base_count += 1 - - def init_save_locally(_dir, init_value: bool = False): save_locally = st.sidebar.checkbox("Save images locally", value=init_value) if save_locally: @@ -208,28 +108,12 @@ def init_save_locally(_dir, init_value: bool = False): return save_locally, save_path -class Img2ImgDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 1.0): - self.discretization = discretization - self.strength = strength - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] - print("prune index:", max(int(self.strength * len(sigmas)), 1)) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas +def show_samples(samples, outputs): + if isinstance(samples, tuple): + samples, _ = samples + grid = embed_watermark(torch.stack([samples])) + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + outputs.image(grid.cpu().numpy()) def get_guider(key): @@ -452,214 +336,3 @@ def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) return init_image - - -def do_sample( - model, - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings: List = None, - batch2model_input: List = None, - return_latents=False, - filter=None, -): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - if batch2model_input is None: - batch2model_input = [] - - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - if not k == "crossattn": - c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) - ) - - additional_model_inputs = {} - for k in batch2model_input: - additional_model_inputs[k] = batch[k] - - shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to("cuda") - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = torch.stack([samples]) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - - if return_latents: - return samples, samples_z - return samples - - -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): - # Hardcoded demo setups; might undergo some changes in the future - - batch = {} - batch_uc = {} - - for key in keys: - if key == "txt": - batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - elif key == "original_size_as_tuple": - batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) - .to(device) - .repeat(*N, 1) - ) - elif key == "crop_coords_top_left": - batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) - .to(device) - .repeat(*N, 1) - ) - elif key == "aesthetic_score": - batch["aesthetic_score"] = ( - torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) - ) - batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) - ) - - elif key == "target_size_as_tuple": - batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]) - .to(device) - .repeat(*N, 1) - ) - else: - batch[key] = value_dict[key] - - for key in batch.keys(): - if key not in batch_uc and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - -@torch.no_grad() -def do_img2img( - img, - model, - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=[], - additional_kwargs={}, - offset_noise_level: int = 0.0, - return_latents=False, - skip_encode=False, - filter=None, -): - st.text("Sampling") - - outputs = st.empty() - precision_scope = autocast - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) - - for k in additional_kwargs: - c[k] = uc[k] = additional_kwargs[k] - if skip_encode: - z = img - else: - z = model.encode_first_stage(img) - noise = torch.randn_like(z) - sigmas = sampler.discretization(sampler.num_steps) - sigma = sigmas[0] - - st.info(f"all sigmas: {sigmas}") - st.info(f"noising sigma: {sigma}") - - if offset_noise_level > 0.0: - noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim - ) - noised_z = z + noise * append_dims(sigma, z.ndim) - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. - - def denoiser(x, sigma, c): - return model.denoiser(model.model, x, sigma, c) - - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - grid = embed_watemark(torch.stack([samples])) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") - outputs.image(grid.cpu().numpy()) - if return_latents: - return samples, samples_z - return samples diff --git a/sgm/inference/api.py b/sgm/inference/api.py new file mode 100644 index 00000000..0635d112 --- /dev/null +++ b/sgm/inference/api.py @@ -0,0 +1,388 @@ +from dataclasses import dataclass, asdict +from enum import Enum +from omegaconf import OmegaConf +import pathlib +from sgm.inference.helpers import ( + do_sample, + do_img2img, + Img2ImgDiscretizationWrapper, +) +from sgm.modules.diffusionmodules.sampling import ( + EulerEDMSampler, + HeunEDMSampler, + EulerAncestralSampler, + DPMPP2SAncestralSampler, + DPMPP2MSampler, + LinearMultistepSampler, +) +from sgm.util import load_model_from_config +from typing import Optional + + +class ModelArchitecture(str, Enum): + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SDXL_V1_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + + +class Sampler(str, Enum): + EULER_EDM = "EulerEDMSampler" + HEUN_EDM = "HeunEDMSampler" + EULER_ANCESTRAL = "EulerAncestralSampler" + DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" + DPMPP2M = "DPMPP2MSampler" + LINEAR_MULTISTEP = "LinearMultistepSampler" + + +class Discretization(str, Enum): + LEGACY_DDPM = "LegacyDDPMDiscretization" + EDM = "EDMDiscretization" + + +class Guider(str, Enum): + VANILLA = "VanillaCFG" + IDENTITY = "IdentityGuider" + + +class Thresholder(str, Enum): + NONE = "None" + + +@dataclass +class SamplingParams: + width: int = 1024 + height: int = 1024 + steps: int = 50 + sampler: Sampler = Sampler.DPMPP2M + discretization: Discretization = Discretization.LEGACY_DDPM + guider: Guider = Guider.VANILLA + thresholder: Thresholder = Thresholder.NONE + scale: float = 6.0 + aesthetic_score: float = 5.0 + negative_aesthetic_score: float = 5.0 + img2img_strength: float = 1.0 + orig_width: int = 1024 + orig_height: int = 1024 + crop_coords_top: int = 0 + crop_coords_left: int = 0 + sigma_min: float = 0.0292 + sigma_max: float = 14.6146 + rho: float = 3.0 + s_churn: float = 0.0 + s_tmin: float = 0.0 + s_tmax: float = 999.0 + s_noise: float = 1.0 + eta: float = 1.0 + order: int = 4 + + +@dataclass +class SamplingSpec: + width: int + height: int + channels: int + factor: int + is_legacy: bool + config: str + ckpt: str + is_guided: bool + + +model_specs = { + ModelArchitecture.SD_2_1: SamplingSpec( + height=512, + width=512, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1.yaml", + ckpt="v2-1_512-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SD_2_1_768: SamplingSpec( + height=768, + width=768, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1_768.yaml", + ckpt="v2-1_768-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_1.0-metadata.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_1.0-metadata.safetensors", + is_guided=True, + ), +} + + +class SamplingPipeline: + def __init__( + self, + model_id: ModelArchitecture, + model_path="checkpoints", + config_path="configs/inference", + device="cuda", + use_fp16=True, + ) -> None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.model_id = model_id + self.specs = model_specs[self.model_id] + self.config = str(pathlib.Path(config_path, self.specs.config)) + self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + self.device = device + self.model = self._load_model(device=device, use_fp16=use_fp16) + + def _load_model(self, device="cuda", use_fp16=True): + config = OmegaConf.load(self.config) + model = load_model_from_config(config, self.ckpt) + if model is None: + raise ValueError(f"Model {self.model_id} could not be loaded") + model.to(device) + if use_fp16: + model.conditioner.half() + model.model.half() + return model + + def text_to_image( + self, + params: SamplingParams, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = params.width + value_dict["target_height"] = params.height + return do_sample( + self.model, + sampler, + value_dict, + samples, + params.height, + params.width, + self.specs.channels, + self.specs.factor, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def image_to_image( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + + if params.img2img_strength < 1.0: + sampler.discretization = Img2ImgDiscretizationWrapper( + sampler.discretization, + strength=params.img2img_strength, + ) + height, width = image.shape[2], image.shape[3] + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = width + value_dict["target_height"] = height + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def refiner( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: Optional[str] = None, + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = { + "orig_width": image.shape[3] * 8, + "orig_height": image.shape[2] * 8, + "target_width": image.shape[3] * 8, + "target_height": image.shape[2] * 8, + "prompt": prompt, + "negative_prompt": negative_prompt, + "crop_coords_top": 0, + "crop_coords_left": 0, + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + } + + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + skip_encode=True, + return_latents=return_latents, + filter=None, + ) + + +def get_guider_config(params: SamplingParams): + if params.guider == Guider.IDENTITY: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif params.guider == Guider.VANILLA: + scale = params.scale + + thresholder = params.thresholder + + if thresholder == Thresholder.NONE: + dyn_thresh_config = { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, + } + else: + raise NotImplementedError + return guider_config + + +def get_discretization_config(params: SamplingParams): + if params.discretization == Discretization.LEGACY_DDPM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif params.discretization == Discretization.EDM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": params.sigma_min, + "sigma_max": params.sigma_max, + "rho": params.rho, + }, + } + else: + raise ValueError(f"unknown discretization {params.discretization}") + return discretization_config + + +def get_sampler_config(params: SamplingParams): + discretization_config = get_discretization_config(params) + guider_config = get_guider_config(params) + sampler = None + if params.sampler == Sampler.EULER_EDM: + return EulerEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.HEUN_EDM: + return HeunEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.EULER_ANCESTRAL: + return EulerAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2S_ANCESTRAL: + return DPMPP2SAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2M: + return DPMPP2MSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=True, + ) + if params.sampler == Sampler.LINEAR_MULTISTEP: + return LinearMultistepSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=params.order, + verbose=True, + ) + + raise ValueError(f"unknown sampler {params.sampler}!") diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py new file mode 100644 index 00000000..1c653708 --- /dev/null +++ b/sgm/inference/helpers.py @@ -0,0 +1,305 @@ +import os +from typing import Union, List, Optional + +import math +import numpy as np +import torch +from PIL import Image +from einops import rearrange +from imwatermark import WatermarkEncoder +from omegaconf import ListConfig +from torch import autocast + +from sgm.util import append_dims + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor): + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, C, H, W) in range [0, 1] + + Returns: + same as input but watermarked + """ + # watermarking libary expects input as cv2 BGR format + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy( + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + return image + + +# A fixed 48-bit message that was choosen at random +# WATERMARK_MESSAGE = 0xB3EC907BB19E +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list({x.input_key for x in conditioner.embedders}) + + +def perform_save_locally(save_path, samples): + os.makedirs(os.path.join(save_path), exist_ok=True) + base_count = len(os.listdir(os.path.join(save_path))) + samples = embed_watermark(samples) + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + Image.fromarray(sample.astype(np.uint8)).save( + os.path.join(save_path, f"{base_count:09}.png") + ) + base_count += 1 + + +class Img2ImgDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 1.0): + self.discretization = discretization + self.strength = strength + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] + print("prune index:", max(int(self.strength * len(sigmas)), 1)) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: Optional[List] = None, + batch2model_input: Optional[List] = None, + return_latents=False, + filter=None, + device="cuda", +): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + if batch2model_input is None: + batch2model_input = [] + + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + num_samples = [num_samples] + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to(device) + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def get_input_image_tensor(image: Image.Image, device="cuda"): + w, h = image.size + print(f"loaded input image of size ({w}, {h})") + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 + image = image.resize((width, height)) + image_array = np.array(image.convert("RGB")) + image_array = image_array[None].transpose(0, 3, 1, 2) + image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 + return image_tensor.to(device) + + +def do_img2img( + img, + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=[], + additional_kwargs={}, + offset_noise_level: float = 0.0, + return_latents=False, + skip_encode=False, + filter=None, + device="cuda", +): + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + + for k in additional_kwargs: + c[k] = uc[k] = additional_kwargs[k] + if skip_encode: + z = img + else: + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) + sigma = sigmas[0].to(z.device) + + if offset_noise_level > 0.0: + noise = noise + offset_noise_level * append_dims( + torch.randn(z.shape[0], device=z.device), z.ndim + ) + noised_z = z + noise * append_dims(sigma, z.ndim) + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + + def denoiser(x, sigma, c): + return model.denoiser(model.model, x, sigma, c) + + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 00000000..2b2af11e --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,111 @@ +import numpy +from PIL import Image +import pytest +from pytest import fixture +import torch +from typing import Tuple + +from sgm.inference.api import ( + model_specs, + SamplingParams, + SamplingPipeline, + Sampler, + ModelArchitecture, +) +import sgm.inference.helpers as helpers + + +@pytest.mark.inference +class TestInference: + @fixture(scope="class", params=model_specs.keys()) + def pipeline(self, request) -> SamplingPipeline: + pipeline = SamplingPipeline(request.param) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + @fixture( + scope="class", + params=[ + [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER], + [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER], + ], + ids=["SDXL_V1", "SDXL_V0_9"], + ) + def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]: + base_pipeline = SamplingPipeline(request.param[0]) + refiner_pipeline = SamplingPipeline(request.param[1]) + yield base_pipeline, refiner_pipeline + del base_pipeline + del refiner_pipeline + torch.cuda.empty_cache() + + def create_init_image(self, h, w): + image_array = numpy.random.rand(h, w, 3) * 255 + image = Image.fromarray(image_array.astype("uint8")).convert("RGB") + return helpers.get_input_image_tensor(image) + + @pytest.mark.parametrize("sampler_enum", Sampler) + def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum): + output = pipeline.text_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + ) + + assert output is not None + + @pytest.mark.parametrize("sampler_enum", Sampler) + def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): + output = pipeline.image_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=self.create_init_image(pipeline.specs.height, pipeline.specs.width), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + ) + assert output is not None + + @pytest.mark.parametrize("sampler_enum", Sampler) + @pytest.mark.parametrize( + "use_init_image", [True, False], ids=["img2img", "txt2img"] + ) + def test_sdxl_with_refiner( + self, + sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], + sampler_enum, + use_init_image, + ): + base_pipeline, refiner_pipeline = sdxl_pipelines + if use_init_image: + output = base_pipeline.image_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=self.create_init_image( + base_pipeline.specs.height, base_pipeline.specs.width + ), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + return_latents=True, + ) + else: + output = base_pipeline.text_to_image( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + return_latents=True, + ) + + assert isinstance(output, (tuple, list)) + samples, samples_z = output + assert samples is not None + assert samples_z is not None + refiner_pipeline.refiner( + params=SamplingParams(sampler=sampler_enum.value, steps=10), + image=samples_z, + prompt="A professional photograph of an astronaut riding a pig", + negative_prompt="", + samples=1, + )