19
19
import torch
20
20
21
21
import legacy
22
+ from torch_utils import gen_utils
22
23
23
24
#----------------------------------------------------------------------------
24
25
@@ -71,7 +72,9 @@ def make_transform(translate: Tuple[float,float], angle: float):
71
72
@click .command ()
72
73
@click .option ('--network' , 'network_pkl' , help = 'Network pickle filename' , required = True )
73
74
@click .option ('--seeds' , type = parse_range , help = 'List of random seeds (e.g., \' 0,1,4-6\' )' , required = True )
75
+ @click .option ('--batch-sz' , type = int , help = 'Batch size per sample' , default = 1 )
74
76
@click .option ('--trunc' , 'truncation_psi' , type = float , help = 'Truncation psi' , default = 1 , show_default = True )
77
+ @click .option ('--centroids-path' , type = str , help = 'Pass path to precomputed centroids to enable multimodal truncation' )
75
78
@click .option ('--class' , 'class_idx' , type = int , help = 'Class label (unconditional if not specified)' )
76
79
@click .option ('--noise-mode' , help = 'Noise mode' , type = click .Choice (['const' , 'random' , 'none' ]), default = 'const' , show_default = True )
77
80
@click .option ('--translate' , help = 'Translate XY-coordinate (e.g. \' 0.3,1\' )' , type = parse_vec2 , default = '0,0' , show_default = True , metavar = 'VEC2' )
@@ -80,49 +83,26 @@ def make_transform(translate: Tuple[float,float], angle: float):
80
83
def generate_images (
81
84
network_pkl : str ,
82
85
seeds : List [int ],
86
+ batch_sz : int ,
83
87
truncation_psi : float ,
88
+ centroids_path : str ,
84
89
noise_mode : str ,
85
90
outdir : str ,
86
91
translate : Tuple [float ,float ],
87
92
rotate : float ,
88
93
class_idx : Optional [int ]
89
94
):
90
- """Generate images using pretrained network pickle.
91
-
92
- Examples:
93
-
94
- \b
95
- # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
96
- python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
97
- --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
98
-
99
- \b
100
- # Generate uncurated images with truncation using the MetFaces-U dataset
101
- python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
102
- --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
103
- """
104
-
105
95
print ('Loading networks from "%s"...' % network_pkl )
106
96
device = torch .device ('cuda' )
107
97
with dnnlib .util .open_url (network_pkl ) as f :
108
- G = legacy .load_network_pkl (f )['G_ema' ].to (device ) # type: ignore
98
+ G = legacy .load_network_pkl (f )['G_ema' ]
99
+ G = G .eval ().requires_grad_ (False ).to (device )
109
100
110
101
os .makedirs (outdir , exist_ok = True )
111
102
112
- # Labels.
113
- label = torch .zeros ([1 , G .c_dim ], device = device )
114
- if G .c_dim != 0 :
115
- if class_idx is None :
116
- raise click .ClickException ('Must specify class label with --class when using a conditional network' )
117
- label [:, class_idx ] = 1
118
- else :
119
- if class_idx is not None :
120
- print ('warn: --class=lbl ignored when running on an unconditional network' )
121
-
122
103
# Generate images.
123
104
for seed_idx , seed in enumerate (seeds ):
124
105
print ('Generating image for seed %d (%d/%d) ...' % (seed , seed_idx , len (seeds )))
125
- z = torch .from_numpy (np .random .RandomState (seed ).randn (1 , G .z_dim )).to (device )
126
106
127
107
# Construct an inverse rotation/translation matrix and pass to the generator. The
128
108
# generator expects this matrix as an inverse to avoid potentially failing numerical
@@ -132,9 +112,10 @@ def generate_images(
132
112
m = np .linalg .inv (m )
133
113
G .synthesis .input .transform .copy_ (torch .from_numpy (m ))
134
114
135
- img = G (z , label , truncation_psi = truncation_psi , noise_mode = noise_mode )
136
- img = (img .permute (0 , 2 , 3 , 1 ) * 127.5 + 128 ).clamp (0 , 255 ).to (torch .uint8 )
137
- PIL .Image .fromarray (img [0 ].cpu ().numpy (), 'RGB' ).save (f'{ outdir } /seed{ seed :04d} .png' )
115
+ w = gen_utils .get_w_from_seed (G , batch_sz , device , truncation_psi , seed = seed ,
116
+ centroids_path = centroids_path , class_idx = class_idx )
117
+ img = gen_utils .w_to_img (G , w , to_np = True )
118
+ PIL .Image .fromarray (gen_utils .create_image_grid (img ), 'RGB' ).save (f'{ outdir } /seed{ seed :04d} .png' )
138
119
139
120
140
121
#----------------------------------------------------------------------------
0 commit comments