Skip to content

Commit

Permalink
fix normalized prompt when a variation is generated
Browse files Browse the repository at this point in the history
- The seed printed needs to be the one generated prior to the
  initial noising operation. To do this, I added a new "first_seed"
  argument to the image callback in dream.py.
- Closes #641
  • Loading branch information
lstein committed Sep 21, 2022
1 parent 0632a3a commit b93f04e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 2 additions & 1 deletion ldm/dream/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,

results = []
seed = seed if seed else self.new_seed()
first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(self.model.device.type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
Expand All @@ -71,7 +72,7 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image = make_image(x_T)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed)
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results

Expand Down
4 changes: 2 additions & 2 deletions ldm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,10 +765,10 @@ def _load_model_from_config(self, config, weights):
m, u = model.load_state_dict(sd, strict=False)

if self.precision == 'float16':
print('Using faster float16 precision')
print('>> Using faster float16 precision')
model.to(torch.float16)
else:
print('Using more accurate float32 precision')
print('>> Using more accurate float32 precision')

model.to(self.device)
model.eval()
Expand Down
7 changes: 4 additions & 3 deletions scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,12 @@ def main_loop(gen, opt, infile):
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `opt.grid`
prior_variations = opt.with_variations or []
first_seed = opt.seed

def image_writer(image, seed, upscaled=False):
def image_writer(image, seed, upscaled=False, first_seed=None):
# note the seed is the seed of the current image
# the first_seed is the original seed that noise is added to
# when the -v switch is used to generate variations
path = None
nonlocal first_seed
nonlocal prior_variations
if opt.grid:
grid_images[seed] = image
Expand Down

0 comments on commit b93f04e

Please sign in to comment.