
Commit c0d27b1

xl-sr committed: improvements to gen_images/gen_cond_samplesheet. Added multi-modal truncation
1 parent ec732c2 commit c0d27b1

File tree: 5 files changed, +122 −134 lines

README.md (+7 −10)

@@ -102,7 +102,7 @@ If you have enough compute, a good tactic is to train several stages in parallel
 
 To generate samples and interpolation videos, run
 ```
-python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 \
+python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 1 \
 --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
 ```
 and
@@ -112,27 +112,24 @@ python gen_video.py --output=lerp.mp4 --trunc=0.7 --seeds=0-31 --grid=4x2 \
 ```
 For class-conditional models, you can pass the class index via ```--class```; an index-to-label dictionary for ImageNet can be found [here](https://github.com/xl-sr/stylegan_xl_release/blob/main/media/imagenet_idx2labels.txt).
 
-Generating large sample sheets:
+Generate a conditional sample sheet:
 ```
-# unconditional model
-python gen_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
---network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl \
---samples-per-class 128
-
-# conditional model
 python gen_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
 --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
---max-classes 100 --samples-per-class 4 --classes-per-row 5
+--samples-per-class 4 --classes 0-32 --grid-width 32
 ```
 
+For the ImageNet models, we enable multi-modal truncation (as proposed by [self-distilled
+GAN](https://self-distilled-stylegan.github.io/)). To switch from uni- to multi-modal truncation, pass ```--centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy```.
+
 We provide the following pretrained models (pass the url as `PATH_TO_NETWORK_PKL`):
 
 |Dataset| Res | FID | PATH
 :--- | ---: | ---: | :---
 ImageNet| 16<sup>2</sup> |0.74| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet16.pkl`</sub><br>
 ImageNet| 32<sup>2</sup> |1.11| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet32.pkl`</sub><br>
 ImageNet| 64<sup>2</sup> |1.52| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet64.pkl`</sub><br>
-ImageNet| 128<sup>2</sup> |1.82| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl`</sub><br>
+ImageNet| 128<sup>2</sup> |1.77| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl`</sub><br>
 CIFAR10 | 32<sup>2</sup> |1.85| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/cifar10.pkl`</sub><br>
 FFHQ | 256<sup>2</sup> |2.19| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl`</sub><br>
 Pokemon | 256<sup>2</sup> |23.97| <sub>`https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl`</sub><br>
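
As a rough illustration of what multi-modal truncation does (a minimal sketch; `multimodal_truncate` is a hypothetical helper name, the actual implementation lives in torch_utils/gen_utils.py below, and a centroids array of shape [num_centroids, w_dim] is assumed): instead of pulling every sample toward the single global mean w_avg, each w is truncated toward its nearest precomputed centroid in W space.

```python
import torch

def multimodal_truncate(w: torch.Tensor, w_centroids: torch.Tensor, psi: float) -> torch.Tensor:
    # w: [batch, num_ws, w_dim] dlatents from the mapping network
    # w_centroids: [num_centroids, w_dim] cluster centers precomputed in W space
    dist = torch.norm(w_centroids[None] - w[:, :1], dim=2, p=2)  # [batch, num_centroids]
    w_anchor = w_centroids.index_select(0, dist.argmin(dim=1))   # nearest centroid per sample
    w_anchor = w_anchor.unsqueeze(1)                             # [batch, 1, w_dim], broadcasts over num_ws
    return w_anchor + (w - w_anchor) * psi                       # truncate toward the centroid, not the global mean
```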

gen_class_samplesheet.py (+70)

@@ -0,0 +1,70 @@
+import os
+from pathlib import Path
+import PIL.Image
+from typing import List
+import click
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import legacy
+import dnnlib
+from training.training_loop import save_image_grid
+from torch_utils import gen_utils
+from gen_images import parse_range
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--trunc', 'truncation_psi', help='Truncation psi', type=float, default=1, show_default=True)
+@click.option('--seed', help='Random seed', type=int, default=42)
+@click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation')
+@click.option('--classes', type=parse_range, help='List of classes (e.g., \'0,1,4-6\')', required=True)
+@click.option('--samples-per-class', help='Samples per class.', type=int, default=4)
+@click.option('--grid-width', help='Total width of image grid', type=int, default=32)
+@click.option('--batch-gpu', help='Samples per pass, adapt to fit on GPU', type=int, default=32)
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
+def generate_samplesheet(
+    network_pkl: str,
+    truncation_psi: float,
+    seed: int,
+    centroids_path: str,
+    classes: List[int],
+    samples_per_class: int,
+    batch_gpu: int,
+    grid_width: int,
+    outdir: str,
+    desc: str,
+):
+    print('Loading networks from "%s"...' % network_pkl)
+    device = torch.device('cuda')
+    with dnnlib.util.open_url(network_pkl) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].to(device).requires_grad_(False)
+
+    # setup
+    os.makedirs(outdir, exist_ok=True)
+    desc_full = f'{Path(network_pkl).stem}_trunc_{truncation_psi}'
+    if desc is not None: desc_full += f'-{desc}'
+    run_dir = Path(gen_utils.make_run_dir(outdir, desc_full))
+
+    print('Generating latents.')
+    ws = []
+    for class_idx in tqdm(classes):
+        w = gen_utils.get_w_from_seed(G, samples_per_class, device, truncation_psi, seed=seed,
+                                      centroids_path=centroids_path, class_idx=class_idx)
+        ws.append(w)
+    ws = torch.cat(ws)
+
+    print('Generating samples.')
+    images = []
+    for w in tqdm(ws.split(batch_gpu)):
+        img = gen_utils.w_to_img(G, w, to_np=True)
+        images.append(img)
+
+    # adjust grid width to prevent folding within the same class, then save to disk
+    grid_width = grid_width - grid_width % samples_per_class
+    images = gen_utils.create_image_grid(np.concatenate(images), grid_size=(grid_width, None))
+    PIL.Image.fromarray(images, 'RGB').save(run_dir / 'sheet.png')
+
+if __name__ == "__main__":
+    generate_samplesheet()
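
A plausible invocation of the new script, mirroring the README command above (illustrative only; output lands in a run directory created by gen_utils.make_run_dir):

```
python gen_class_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --samples-per-class 4 --classes 0-32 --grid-width 32
```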

gen_images.py (+11 −30)

@@ -19,6 +19,7 @@
 import torch
 
 import legacy
+from torch_utils import gen_utils
 
 #----------------------------------------------------------------------------
 
@@ -71,7 +72,9 @@ def make_transform(translate: Tuple[float,float], angle: float):
 @click.command()
 @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
 @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
+@click.option('--batch-sz', type=int, help='Batch size per sample', default=1)
 @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation')
 @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
 @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
 @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@@ -80,49 +83,26 @@ def make_transform(translate: Tuple[float,float], angle: float):
 def generate_images(
     network_pkl: str,
     seeds: List[int],
+    batch_sz: int,
     truncation_psi: float,
+    centroids_path: str,
     noise_mode: str,
     outdir: str,
     translate: Tuple[float,float],
     rotate: float,
     class_idx: Optional[int]
 ):
-    """Generate images using pretrained network pickle.
-
-    Examples:
-
-    \b
-    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
-    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
-        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
-
-    \b
-    # Generate uncurated images with truncation using the MetFaces-U dataset
-    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
-        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
-    """
-
     print('Loading networks from "%s"...' % network_pkl)
     device = torch.device('cuda')
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+        G = legacy.load_network_pkl(f)['G_ema']
+        G = G.eval().requires_grad_(False).to(device)
 
     os.makedirs(outdir, exist_ok=True)
 
-    # Labels.
-    label = torch.zeros([1, G.c_dim], device=device)
-    if G.c_dim != 0:
-        if class_idx is None:
-            raise click.ClickException('Must specify class label with --class when using a conditional network')
-        label[:, class_idx] = 1
-    else:
-        if class_idx is not None:
-            print ('warn: --class=lbl ignored when running on an unconditional network')
-
     # Generate images.
     for seed_idx, seed in enumerate(seeds):
         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
-        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
 
         # Construct an inverse rotation/translation matrix and pass to the generator. The
         # generator expects this matrix as an inverse to avoid potentially failing numerical
@@ -132,9 +112,10 @@ def generate_images(
         m = np.linalg.inv(m)
         G.synthesis.input.transform.copy_(torch.from_numpy(m))
 
-        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
-        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+        w = gen_utils.get_w_from_seed(G, batch_sz, device, truncation_psi, seed=seed,
+                                      centroids_path=centroids_path, class_idx=class_idx)
+        img = gen_utils.w_to_img(G, w, to_np=True)
+        PIL.Image.fromarray(gen_utils.create_image_grid(img), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
 
 
 #----------------------------------------------------------------------------
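
With the new flags, a batched and multi-modally truncated call might look like this (a sketch: the class index 282 is an arbitrary ImageNet label for illustration; the centroids URL is the one from the README). Since `--batch-sz` values above 1 feed a batch of dlatents through `w_to_img`, the saved file per seed is then a grid via `gen_utils.create_image_grid` rather than a single image:

```
python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 4 --class 282 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy
```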

gen_samplesheet.py (−85)

This file was deleted.

torch_utils/gen_utils.py (+34 −9)

@@ -9,6 +9,8 @@
 import click
 import numpy as np
 import torch
+import torch.nn.functional as F
+import dnnlib
 
 
 # ----------------------------------------------------------------------------
@@ -413,29 +415,52 @@ def w_to_img(G, dlatents: Union[List[torch.Tensor], torch.Tensor], noise_mode: s
     assert isinstance(dlatents, torch.Tensor), f'dlatents should be a torch.Tensor!: "{type(dlatents)}"'
     if len(dlatents.shape) == 2:
         dlatents = dlatents.unsqueeze(0)  # An individual dlatent => [1, G.mapping.num_ws, G.mapping.w_dim]
+
     synth_image = G.synthesis(dlatents, noise_mode=noise_mode)
     synth_image = (synth_image + 1) * 255/2  # [-1.0, 1.0] -> [0.0, 255.0]
     if to_np:
         synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()  # NCWH => NWHC
     return synth_image
 
 
-def get_w_from_seed(G, device: torch.device, seed: int, truncation_psi: float, class_idx: Optional[int]) -> torch.Tensor:
-    """Get the dlatent from a random seed, using the truncation trick (this could be optional)"""
+def get_w_from_seed(G, batch_sz: int, device: torch.device, truncation_psi: float, seed: Optional[int], centroids_path: Optional[str], class_idx: Optional[int]) -> torch.Tensor:
+    """Get a batch of dlatents from a random seed, using the truncation trick (this could be optional)"""
 
-    label = torch.zeros([1, G.c_dim], device=device)
     if G.c_dim != 0:
+        # sample random labels if no class idx is given
         if class_idx is None:
-            raise click.ClickException('Must specify class label via --class when using a conditional network')
-        w_avg = G.mapping.w_avg[class_idx]
-        label[:, class_idx] = 1
+            class_indices = np.random.RandomState(seed).randint(low=0, high=G.c_dim, size=(batch_sz))
+            class_indices = torch.from_numpy(class_indices).to(device)
+            w_avg = G.mapping.w_avg.index_select(0, class_indices)
+        else:
+            w_avg = G.mapping.w_avg[class_idx].unsqueeze(0).repeat(batch_sz, 1)
+            class_indices = torch.full((batch_sz,), class_idx).to(device)
+
+        labels = F.one_hot(class_indices, G.c_dim)
+
     else:
-        w_avg = G.mapping.w_avg
+        w_avg = G.mapping.w_avg.unsqueeze(0)
+        labels = None
         if class_idx is not None:
             print('Warning: --class is ignored when running an unconditional network')
 
-    z = np.random.RandomState(seed).randn(1, G.z_dim)
-    w = G.mapping(torch.from_numpy(z).to(device), label)
+    z = np.random.RandomState(seed).randn(batch_sz, G.z_dim)
+    z = torch.from_numpy(z).to(device)
+    w = G.mapping(z, labels)
+
+    # multimodal truncation
+    if centroids_path is not None:
+        with dnnlib.util.open_url(centroids_path, verbose=False) as f:
+            w_centroids = np.load(f)
+        w_centroids = torch.from_numpy(w_centroids).to(device)
+        w_centroids = w_centroids[None].repeat(batch_sz, 1, 1)
+
+        # measure distances
+        dist = torch.norm(w_centroids - w[:, :1], dim=2, p=2)
+        w_avg = w_centroids[0].index_select(0, dist.argmin(1))
+
+    w_avg = w_avg.unsqueeze(1).repeat(1, G.mapping.num_ws, 1)
     w = w_avg + (w - w_avg) * truncation_psi
 
     return w
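
The final line is the standard truncation trick: psi = 1 leaves the sample untouched, psi = 0 collapses it to the class- or centroid-specific anchor. A minimal sanity check of those two endpoints (illustrative values only, not the repo's API):

```python
import torch

w_avg = torch.zeros(1, 1, 512)  # stand-in anchor, broadcasts over num_ws
w = torch.randn(1, 32, 512)     # fake dlatent with num_ws = 32, w_dim = 512

assert torch.equal(w_avg + (w - w_avg) * 1.0, w)                   # psi = 1: unchanged
assert torch.equal(w_avg + (w - w_avg) * 0.0, w_avg.expand_as(w))  # psi = 0: collapses to anchor
```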
