Implement support for negative prompts using --nprompts #220

Open
wants to merge 1 commit into base: main
71 changes: 48 additions & 23 deletions optimizedSD/optimized_txt2img.py
@@ -1,4 +1,6 @@
-import argparse, os, re
+import argparse
+import os
+import re
 import torch
 import numpy as np
 from random import randint
@@ -33,15 +35,48 @@ def load_model_from_config(ckpt, verbose=False):
     return sd


+def vectorize_prompt(modelCS, batch_size, prompt):
+    empty_result = modelCS.get_learned_conditioning(batch_size * [""])
+    result = torch.zeros_like(empty_result)
+    subprompts, weights = split_weighted_subprompts(prompt)
+    weights_sum = sum(weights)
+    cntr = 0
+    for i, subprompt in enumerate(subprompts):
+        cntr += 1
+        result = torch.add(result,
+                           modelCS.get_learned_conditioning(batch_size
+                                                            * [subprompt]),
+                           alpha=weights[i] / weights_sum)
+    if cntr == 0:
+        result = empty_result
+    return result
+
+
 config = "optimizedSD/v1-inference.yaml"
 DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt"

 parser = argparse.ArgumentParser()

 parser.add_argument(
-    "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
+    "--prompt",
+    type=str,
+    nargs="?",
+    default="a painting of a virus monster playing guitar",
+    help="the prompt to render"
 )
-parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
+parser.add_argument(
+    "--nprompt",
+    type=str,
+    default="",
+    help="negative prompt to render"
+)
+parser.add_argument(
+    "--outdir",
+    type=str,
+    nargs="?",
+    help="dir to write results to",
+    default="outputs/txt2img-samples"
+)
 parser.add_argument(
     "--skip_grid",
     action="store_true",
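As a sanity check on the new helper: vectorize_prompt() returns a weight-normalized blend of the embeddings of each subprompt, falling back to the empty-prompt embedding when no subprompts parse. A minimal sketch, assuming split_weighted_subprompts follows this repo's "text:weight" syntax; the prompt and names below are illustrative only:

    # Illustrative only: assumes "castle:1 fog:0.5" parses into
    # (["castle", "fog"], [1.0, 0.5]).
    c = vectorize_prompt(modelCS, batch_size, "castle:1 fog:0.5")

    # Roughly equivalent, written out by hand (weights normalized to sum to 1):
    e1 = modelCS.get_learned_conditioning(batch_size * ["castle"])
    e2 = modelCS.get_learned_conditioning(batch_size * ["fog"])
    c_by_hand = (1.0 / 1.5) * e1 + (0.5 / 1.5) * e2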
@@ -58,7 +93,6 @@ def load_model_from_config(ckpt, verbose=False):
     default=50,
     help="number of ddim sampling steps",
 )
-
 parser.add_argument(
     "--fixed_code",
     action="store_true",
@@ -147,7 +181,7 @@ def load_model_from_config(ckpt, verbose=False):
     help="Reduces inference time on the expense of 1GB VRAM",
 )
 parser.add_argument(
-    "--precision", 
+    "--precision",
     type=str,
     help="evaluate at this precision",
     choices=["full", "autocast"],
@@ -164,7 +198,7 @@ def load_model_from_config(ckpt, verbose=False):
     "--sampler",
     type=str,
     help="sampler",
-    choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
+    choices=["ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
     default="plms",
 )
 parser.add_argument(
@@ -185,7 +219,7 @@ def load_model_from_config(ckpt, verbose=False):
 seed_everything(opt.seed)

 # Logging
-logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
+logger(vars(opt), log_csv="logs/txt2img_logs.csv")

 sd = load_model_from_config(f"{opt.ckpt}")
 li, lo = [], []
@@ -258,35 +292,26 @@ def load_model_from_config(ckpt, verbose=False):

 seeds = ""
 with torch.no_grad():
-
     all_samples = list()
     for n in trange(opt.n_iter, desc="Sampling"):
         for prompts in tqdm(data, desc="data"):
-
             sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
+            if prompts[0] == "":
+                sample_path = os.path.join(outpath, "empty_prompt")
             os.makedirs(sample_path, exist_ok=True)
             base_count = len(os.listdir(sample_path))

             with precision_scope("cuda"):
                 modelCS.to(opt.device)
                 uc = None
                 if opt.scale != 1.0:
-                    uc = modelCS.get_learned_conditioning(batch_size * [""])
+                    uc = vectorize_prompt(modelCS, batch_size, opt.nprompt)
                 if isinstance(prompts, tuple):
                     prompts = list(prompts)
+                c = vectorize_prompt(modelCS, batch_size, prompts[0])

-                subprompts, weights = split_weighted_subprompts(prompts[0])
-                if len(subprompts) > 1:
-                    c = torch.zeros_like(uc)
-                    totalWeight = sum(weights)
-                    # normalize each "sub prompt" and add it
-                    for i in range(len(subprompts)):
-                        weight = weights[i]
-                        # if not skip_normalize:
-                        weight = weight / totalWeight
-                        c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
-                else:
-                    c = modelCS.get_learned_conditioning(prompts)
+                # vectorize_prompt() (see above) replaces the existing code.
+                # vectorize_prompt() by https://github.com/consciencia/stable-diffusion

             shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]

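The one-line change to uc above is what actually implements the negative prompt: uc is the unconditional embedding that classifier-free guidance steers away from, so building it from opt.nprompt instead of the empty string pushes sampling away from the negative prompt. Conceptually, the sampler mixes the two noise predictions as sketched below (standard classifier-free guidance; in this script the mixing happens inside the sampler call in the next hunk):

    # eps_uc: noise prediction conditioned on uc (the negative prompt)
    # eps_c:  noise prediction conditioned on c (the positive prompt)
    # opt.scale > 1 steers the result toward c and away from uc
    eps = eps_uc + opt.scale * (eps_c - eps_uc)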
@@ -306,7 +331,7 @@ def load_model_from_config(ckpt, verbose=False):
                     unconditional_conditioning=uc,
                     eta=opt.ddim_eta,
                     x_T=start_code,
-                    sampler = opt.sampler,
+                    sampler=opt.sampler,
                 )

                 modelFS.to(opt.device)
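With the patch applied, an invocation using the new flag would look like this (prompt text and option values are illustrative; every flag shown appears in the diff above):

    python optimizedSD/optimized_txt2img.py \
        --prompt "a castle on a hill, oil painting" \
        --nprompt "blurry, low quality, watermark" \
        --sampler euler_a \
        --ddim_steps 50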