
Commit 4241ff9

added inversion + PTI
1 parent f9be58e commit 4241ff9

File tree: 4 files changed (+352, -6 lines)


README.md (+16, -6)
@@ -34,12 +34,6 @@ If you find our code or paper useful, please cite
 - StyleGAN-XL + CLIP (Implemented by CasualGANPapers)  -  [![StyleGAN-XL + CLIP](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CasualGANPapers/unconditional-StyleGANXL-CLIP/blob/main/StyleganXL%2BCLIP.ipynb)
 - StyleGAN-XL + CLIP (Modified by Katherine Crowson to optimize in W+ space)  -  [![StyleGAN-XL + CLIP](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ZEnJE-EUnh-aCXJbu0kVhi8_Qdi2BV-S)

-## ToDos
-- [x] Initial code release
-- [x] Add pretrained models (ImageNet{16,32,64,128,256,512,1024}, FFHQ{256,512,1024}, Pokemon{256,512,1024})
-- [x] Add StyleMC for editing
-- [ ] Add PTI for inversion
-
 ## Requirements ##
 - 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
 - CUDA toolkit 11.1 or later.
@@ -134,6 +128,22 @@ GAN](https://self-distilled-stylegan.github.io/)). We generated 600k find 10k cl
 :--- | :---: | :---:
 <img src="media/no_truncation.png"> | <img src="media/unimodal_truncation.png">| <img src="media/multimodal_truncation.png">

+## Image Inversion ##
+<p align="center">
+<img src="media/inversion.gif" width="60%">
+</p>
+
+To invert a given image via latent optimization, and optionally use our reimplementation of [Pivotal Tuning Inversion](https://arxiv.org/abs/2106.05744), run
+
+```
+python run_inversion.py --outdir=inversion_out \
+  --target media/jay.png \
+  --inv-steps 1000 --run-pti --pti-steps 350 \
+  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl
+```
+
+Provide an image via `target`; it is automatically resized and center-cropped to match the generator network. For ImageNet models you do not need to provide a class: we infer the class of a given sample via a pretrained classifier.
+
 ## Image Editing ##
 <img src="media/editing_banner.png">

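The script writes `target.png`, `proj.png`, an optional `proj.mp4`, the optimized latent `projected_w.npz`, and the (optionally PTI-tuned) generator `G.pkl` into `--outdir`. Below is a minimal sketch, not part of the commit, of reloading these outputs to re-synthesize the inverted image; it assumes `--outdir=inversion_out` from the command above, a CUDA device, and that it is run from the repository root so the pickled generator's modules can be imported.

```
# Minimal sketch: reload the outputs written by run_inversion.py and
# re-synthesize the inverted image. Paths assume --outdir=inversion_out.
import dill
import numpy as np
import PIL.Image
import torch

device = torch.device('cuda')

# Generator saved by the script (fine-tuned if --run-pti was used).
with open('inversion_out/G.pkl', 'rb') as f:
    G = dill.load(f)['G_ema'].to(device)

# Optimized latent, saved with shape [1, 1, w_dim].
w = np.load('inversion_out/projected_w.npz')['w']
w = torch.from_numpy(w).to(device)

with torch.no_grad():
    img = G.synthesis(w[0].repeat(1, G.num_ws, 1))  # broadcast w to all layers
img = (img + 1) * (255 / 2)                          # [-1, 1] -> [0, 255]
img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
PIL.Image.fromarray(img, 'RGB').save('resynth.png')
```
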
media/inversion.gif (7.29 MB, binary)

media/jay.png (18.7 KB, binary)

run_inversion.py (+336)
@@ -0,0 +1,336 @@
"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import dill
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

from tqdm import trange
import dnnlib
import legacy
from metrics import metric_utils
import timm

from training.diffaug import DiffAugment
from pg_modules.blocks import Interpolate


def get_morphed_w_code(new_w_code, fixed_w, regularizer_alpha=30):
    interpolation_direction = new_w_code - fixed_w
    interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
    direction_to_move = regularizer_alpha * interpolation_direction / interpolation_direction_norm
    result_w = fixed_w + direction_to_move
    return result_w


def space_regularizer_loss(
    G_pti,
    G_original,
    w_batch,
    vgg16,
    num_of_sampled_latents=1,
    lpips_lambda=10,
):

    z_samples = np.random.randn(num_of_sampled_latents, G_original.z_dim)
    z_samples = torch.from_numpy(z_samples).to(w_batch.device)

    if not G_original.c_dim:
        c_samples = None
    else:
        c_samples = F.one_hot(torch.randint(G_original.c_dim, (num_of_sampled_latents,)), G_original.c_dim)
        c_samples = c_samples.to(w_batch.device)

    w_samples = G_original.mapping(z_samples, c_samples, truncation_psi=0.5)
    territory_indicator_ws = [get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples]

    for w_code in territory_indicator_ws:
        new_img = G_pti.synthesis(w_code, noise_mode='none', force_fp32=True)
        with torch.no_grad():
            old_img = G_original.synthesis(w_code, noise_mode='none', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if new_img.shape[-1] > 256:
            new_img = F.interpolate(new_img, size=(256, 256), mode='area')
            old_img = F.interpolate(old_img, size=(256, 256), mode='area')

        new_feat = vgg16(new_img, resize_images=False, return_lpips=True)
        old_feat = vgg16(old_img, resize_images=False, return_lpips=True)
        lpips_loss = lpips_lambda * (old_feat - new_feat).square().sum()

    return lpips_loss / len(territory_indicator_ws)


def pivotal_tuning(
    G,
    w_pivot,
    target,
    device: torch.device,
    num_steps=350,
    learning_rate = 3e-4,
    noise_mode="const",
    verbose = False,
):
    G_original = copy.deepcopy(G).eval().requires_grad_(False).to(device)
    G_pti = copy.deepcopy(G).train().requires_grad_(True).to(device)
    w_pivot.requires_grad_(False)

    # Load VGG16 feature detector.
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, device=device)

    # l2 criterion
    l2_criterion = torch.nn.MSELoss(reduction='mean')

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # initialize optimizer
    optimizer = torch.optim.Adam(G_pti.parameters(), lr=learning_rate)

    # run optimization loop
    all_images = []
    for step in range(num_steps):
        # Synth images from opt_w.
        synth_images = G_pti.synthesis(w_pivot[0].repeat(1, G.num_ws, 1), noise_mode=noise_mode)

        # track images
        synth_images = (synth_images + 1) * (255/2)
        synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        all_images.append(synth_images_np)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # LPIPS loss
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        lpips_loss = (target_features - synth_features).square().sum()

        # MSE loss
        mse_loss = l2_criterion(target_images, synth_images)

        # space regularizer
        reg_loss = space_regularizer_loss(G_pti, G_original, w_pivot, vgg16)

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss = mse_loss + lpips_loss + reg_loss
        loss.backward()
        optimizer.step()

        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ loss: {float(loss):<5.2f}] '
        msg += f'[ lpips: {float(lpips_loss):<5.2f}] '
        msg += f'[ mse: {float(mse_loss):<5.2f}]'
        msg += f'[ reg: {float(reg_loss):<5.2f}]'
        if verbose: print(msg)

    return all_images, G_pti


def project(
    G,
    target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps = 1000,
    w_avg_samples = 10000,
    initial_learning_rate = 0.1,
    lr_rampdown_length = 0.25,
    lr_rampup_length = 0.05,
    verbose = False,
    device: torch.device,
    noise_mode="const",
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)  # type: ignore

    # Compute w stats.
    print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = torch.from_numpy(np.random.RandomState(123).randn(w_avg_samples, G.z_dim)).to(device)

    # get class probas by classifier
    if not G.c_dim:
        c_samples = None
    else:
        classifier = timm.create_model('deit_base_distilled_patch16_224', pretrained=True).eval().to(device)
        cls_target = F.interpolate((target.to(device).to(torch.float32) / 127.5 - 1)[None], 224)
        logits = classifier(cls_target).softmax(1)
        classes = torch.multinomial(logits, w_avg_samples, replacement=True).squeeze()
        print(f'Main class: {logits.argmax(1).item()}, confidence: {logits.max().item():.4f}')
        c_samples = np.zeros([w_avg_samples, G.c_dim], dtype=np.float32)
        for i, c in enumerate(classes):
            c_samples[i, c] = 1
        c_samples = torch.from_numpy(c_samples).to(device)

    w_samples = G.mapping(z_samples, c_samples)  # [N, L, C]

    # get empirical w_avg
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)                 # [1, 1, C]

    # Load VGG16 feature detector.
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, device=device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # initialize optimizer
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True)  # pylint: disable=not-callable
    optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)

    # run optimization loop
    all_images = []
    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        synth_images = G.synthesis(w_opt[0].repeat(1, G.num_ws, 1), noise_mode=noise_mode)

        # track images
        synth_images = (synth_images + 1) * (255/2)
        synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        all_images.append(synth_images_np)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        lpips_loss = (target_features - synth_features).square().sum()

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss = lpips_loss
        loss.backward()
        optimizer.step()
        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ loss: {float(loss):<5.2f}] '
        if verbose: print(msg)

    return all_images, w_opt.detach()[0]


@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--seed', help='Random seed', type=int, default=42, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--inv-steps', help='Number of inversion steps', type=int, default=1000, show_default=True)
@click.option('--w-init', help='path to initial latent', type=str, default='', show_default=True)
@click.option('--run-pti', help='run pivotal tuning', is_flag=True)
@click.option('--pti-steps', help='Number of pti steps', type=int, default=350, show_default=True)
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    inv_steps: int,
    w_init: str,
    run_pti: bool,
    pti_steps: int,
):
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].to(device)  # type: ignore

    # Load target image.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Latent optimization
    start_time = perf_counter()
    all_images = []
    if not w_init:
        print('Running Latent Optimization...')
        all_images, projected_w = project(
            G,
            target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),  # pylint: disable=not-callable
            num_steps=inv_steps,
            device=device,
            verbose=True,
            noise_mode='const',
        )
        print(f'Elapsed time: {(perf_counter()-start_time):.1f} s')
    else:
        projected_w = torch.from_numpy(np.load(w_init)['w'])[0].to(device)

    start_time = perf_counter()

    # Run PTI
    if run_pti:
        print('Running Pivotal Tuning Inversion...')
        gen_images, G = pivotal_tuning(
            G,
            projected_w,
            target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),
            device=device,
            num_steps=pti_steps,
            verbose=True,
        )
        all_images += gen_images
        print(f'Elapsed time: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=60, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for synth_image in all_images:
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    synth_image = G.synthesis(projected_w.repeat(1, G.num_ws, 1))
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')

    # save latents
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # Save Generator weights
    snapshot_data = {'G': G, 'G_ema': G}
    with open(f"{outdir}/G.pkl", 'wb') as f:
        dill.dump(snapshot_data, f)

#----------------------------------------------------------------------------


if __name__ == "__main__":
    run_projection()  # pylint: disable=no-value-for-parameter
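
Note that `--w-init` accepts a previously saved `projected_w.npz` (loaded above via `np.load(w_init)['w']`), in which case latent optimization is skipped and only pivotal tuning runs. A hypothetical follow-up invocation reusing the latent from the earlier run (the output directory name is chosen here only for illustration):

```
python run_inversion.py --outdir=inversion_out_pti \
  --target media/jay.png \
  --w-init inversion_out/projected_w.npz \
  --run-pti --pti-steps 350 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl
```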

0 commit comments