diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py index c82918240..ddfe31687 100644 --- a/optimizedSD/optimized_txt2img.py +++ b/optimizedSD/optimized_txt2img.py @@ -1,4 +1,6 @@ -import argparse, os, re +import argparse +import os +import re import torch import numpy as np from random import randint @@ -33,15 +35,48 @@ def load_model_from_config(ckpt, verbose=False): return sd +def vectorize_prompt(modelCS, batch_size, prompt): + empty_result = modelCS.get_learned_conditioning(batch_size * [""]) + result = torch.zeros_like(empty_result) + subprompts, weights = split_weighted_subprompts(prompt) + weights_sum = sum(weights) + cntr = 0 + for i, subprompt in enumerate(subprompts): + cntr += 1 + result = torch.add(result, + modelCS.get_learned_conditioning(batch_size + * [subprompt]), + alpha=weights[i] / weights_sum) + if cntr == 0: + result = empty_result + return result + + config = "optimizedSD/v1-inference.yaml" DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt" parser = argparse.ArgumentParser() parser.add_argument( - "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render" + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" +) +parser.add_argument( + "--nprompt", + type=str, + default="", + help="negative prompt to render" +) +parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" ) -parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples") parser.add_argument( "--skip_grid", action="store_true", @@ -58,7 +93,6 @@ def load_model_from_config(ckpt, verbose=False): default=50, help="number of ddim sampling steps", ) - parser.add_argument( "--fixed_code", action="store_true", @@ -147,7 +181,7 @@ def load_model_from_config(ckpt, verbose=False): help="Reduces inference time on the expense of 1GB VRAM", ) parser.add_argument( - "--precision", + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], @@ -164,7 +198,7 @@ def load_model_from_config(ckpt, verbose=False): "--sampler", type=str, help="sampler", - choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], + choices=["ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], default="plms", ) parser.add_argument( @@ -185,7 +219,7 @@ def load_model_from_config(ckpt, verbose=False): seed_everything(opt.seed) # Logging -logger(vars(opt), log_csv = "logs/txt2img_logs.csv") +logger(vars(opt), log_csv="logs/txt2img_logs.csv") sd = load_model_from_config(f"{opt.ckpt}") li, lo = [], [] @@ -258,12 +292,12 @@ def load_model_from_config(ckpt, verbose=False): seeds = "" with torch.no_grad(): - all_samples = list() for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): - sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150] + if prompts[0] == "": + sample_path = os.path.join(outpath, "empty_prompt") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) @@ -271,22 +305,13 @@ def load_model_from_config(ckpt, verbose=False): modelCS.to(opt.device) uc = None if opt.scale != 1.0: - uc = modelCS.get_learned_conditioning(batch_size * [""]) + uc = vectorize_prompt(modelCS, batch_size, opt.nprompt) if isinstance(prompts, tuple): prompts = list(prompts) + c = vectorize_prompt(modelCS, batch_size, prompts[0]) - subprompts, weights = split_weighted_subprompts(prompts[0]) - if len(subprompts) > 1: - c = torch.zeros_like(uc) - totalWeight = sum(weights) - # normalize each "sub prompt" and add it - for i in range(len(subprompts)): - weight = weights[i] - # if not skip_normalize: - weight = weight / totalWeight - c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) - else: - c = modelCS.get_learned_conditioning(prompts) + # vectorize_prompt() (see above) replaces the existing code. + # vectorize_prompt() by https://github.com/consciencia/stable-diffusion shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f] @@ -306,7 +331,7 @@ def load_model_from_config(ckpt, verbose=False): unconditional_conditioning=uc, eta=opt.ddim_eta, x_T=start_code, - sampler = opt.sampler, + sampler=opt.sampler, ) modelFS.to(opt.device)