2 changes: 2 additions & 0 deletions hps.py
@@ -118,4 +118,6 @@ def add_imle_arguments(parser):
parser.add_argument('--ppl_save_name', type=str, default='ppl')
parser.add_argument("--fid_factor", type=int, default=5, help="number of the samples for calculating FID")
parser.add_argument("--fid_freq", type=int, default=5, help="frequency of calculating fid")
parser.add_argument("--num_steps", type=int, default=4000, help="frequency of calculating fid")
parser.add_argument("--pool_staleness", type=int, default=3, help="frequency of calculating fid")
return parser
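For reference, a minimal sketch (assumed usage, not taken from this PR) of how the two new options are picked up by the existing argparse flow and where they are consumed:

# Sketch only (assumed usage): parse the two new IMLE options and note their consumers.
import argparse
from hps import add_imle_arguments

parser = add_imle_arguments(argparse.ArgumentParser())
H, _ = parser.parse_known_args(['--num_steps', '4000', '--pool_staleness', '3'])

# --num_steps feeds DecayLR(tmax=...) in train.py; --pool_staleness is read as
# H.pool_staleness in Sampler.imle_sample_force to decide when a pool entry is too stale.
print(H.num_steps, H.pool_staleness)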
2 changes: 1 addition & 1 deletion mapping_network.py
@@ -62,7 +62,7 @@ class MappingNetowrk(nn.Module):
def __init__(self, code_dim=512, n_mlp=8):
super().__init__()

layers = [PixelNorm()]
layers = []
for i in range(n_mlp):
layers.append(EqualLinear(code_dim, code_dim))
layers.append(nn.LeakyReLU(0.2))
141 changes: 44 additions & 97 deletions sampler.py
@@ -1,5 +1,4 @@
from curses import update_lines_cols
from math import comb
import time

import numpy as np
@@ -23,6 +22,7 @@ def __init__(self, H, sz, preprocess_fn):
self.latent_lr = H.latent_lr
self.entire_ds = torch.arange(sz)
self.selected_latents = torch.empty([sz, H.latent_dim], dtype=torch.float32)
self.selected_indices = torch.empty([sz], dtype=torch.int64)
self.selected_latents_tmp = torch.empty([sz, H.latent_dim], dtype=torch.float32)

blocks = parse_layer_string(H.dec_blocks)
@@ -62,6 +62,7 @@ def __init__(self, H, sz, preprocess_fn):
self.dataset_proj = torch.empty([sz, sum(dims)], dtype=torch.float32)
self.pool_samples_proj = torch.empty([self.pool_size, sum(dims)], dtype=torch.float32)
self.snoise_pool_samples_proj = torch.empty([sz * H.snoise_factor, sum(dims)], dtype=torch.float32)
self.pool_last_updated = torch.zeros([self.pool_size], dtype=torch.int64)

def get_projected(self, inp, permute=True):
if permute:
@@ -139,93 +140,52 @@ def calc_dists_existing(self, dataset_tensor, gen, dists=None, latents=None, to_
dists[batch_slice] = torch.squeeze(dist)
return dists

def imle_sample(self, dataset, gen, factor=None):
if factor is None:
factor = self.H.imle_factor
imle_pool_size = int(len(dataset) * factor)
t1 = time.time()
self.selected_dists_tmp[:] = self.selected_dists[:]
for i in range(imle_pool_size // self.H.imle_db_size):
self.temp_latent_rnds.normal_()
for j in range(len(self.res)):
self.snoise_tmp[j].normal_()
for j in range(self.H.imle_db_size // self.H.imle_batch):
batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
cur_latents = self.temp_latent_rnds[batch_slice]
cur_snoise = [x[batch_slice] for x in self.snoise_tmp]
with torch.no_grad():
self.temp_samples[batch_slice] = gen(cur_latents, cur_snoise)
self.temp_samples_proj[batch_slice] = self.get_projected(self.temp_samples[batch_slice], False)

if not gen.module.dci_db:
device_count = torch.cuda.device_count()
gen.module.dci_db = MDCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices,
num_simp_indices=self.H.num_simp_indices, devices=[i for i in range(device_count)], ts=device_count)

# gen.module.dci_db = DCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices,
# num_simp_indices=self.H.num_simp_indices)
gen.module.dci_db.add(self.temp_samples_proj)

t0 = time.time()
for ind, y in enumerate(DataLoader(dataset, batch_size=self.H.imle_batch)):
# t2 = time.time()
_, target = self.preprocess_fn(y)
x = self.dataset_proj[ind * self.H.imle_batch:ind * self.H.imle_batch + target.shape[0]]
cur_batch_data_flat = x.float()
nearest_indices, _ = gen.module.dci_db.query(cur_batch_data_flat, num_neighbours=1)
nearest_indices = nearest_indices.long()[:, 0]

batch_slice = slice(ind * self.H.imle_batch, ind * self.H.imle_batch + x.size()[0])
actual_selected_dists = self.calc_loss(target.permute(0, 3, 1, 2),
self.temp_samples[nearest_indices].cuda(), use_mean=False)
# actual_selected_dists = torch.squeeze(actual_selected_dists)

to_update = torch.nonzero(actual_selected_dists < self.selected_dists[batch_slice], as_tuple=False)
to_update = torch.squeeze(to_update)
self.selected_dists[ind * self.H.imle_batch + to_update] = actual_selected_dists[to_update].clone()
self.selected_latents[ind * self.H.imle_batch + to_update] = self.temp_latent_rnds[nearest_indices[to_update]].clone()
for k in range(len(self.res)):
self.selected_snoise[k][ind * self.H.imle_batch + to_update] = self.snoise_tmp[k][nearest_indices[to_update]].clone()

del cur_batch_data_flat

gen.module.dci_db.clear()

# adding perturbation
changed = torch.sum(self.selected_dists_tmp != self.selected_dists).item()
print("Samples and NN are calculated, time: {}, mean: {} # changed: {}, {}%".format(time.time() - t1,
self.selected_dists.mean(),
changed, (changed / len(
dataset)) * 100))

def resample_pool(self, gen, ds):
# self.init_projection(ds)
self.pool_latents.normal_()
for i in range(len(self.res)):
self.snoise_pool[i].normal_()

for j in range(self.pool_size // self.H.imle_batch):
batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
def resample_pool(self, gen, to_update, rnd=True):
self.pool_last_updated[to_update] = 0
if rnd:
# in-place normal_() on an advanced-indexed tensor modifies a copy, so assign fresh noise instead
self.pool_latents[to_update] = torch.randn_like(self.pool_latents[to_update])
for i in range(len(self.res)):
self.snoise_pool[i][to_update] = torch.randn_like(self.snoise_pool[i][to_update])

for j in range(to_update.shape[0] // self.H.imle_batch + 1):
sl = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
batch_slice = to_update[sl]
if batch_slice.shape[0] == 0:
continue

cur_latents = self.pool_latents[batch_slice]
cur_snosie = [s[batch_slice] for s in self.snoise_pool]
with torch.no_grad():
self.pool_samples_proj[batch_slice] = self.get_projected(gen(cur_latents, cur_snosie), False)
self.pool_samples_proj[batch_slice] = self.get_projected(gen(cur_latents, cur_snosie), False).cpu()
# self.get_projected(gen(cur_latents, cur_snosie), False)


def imle_sample_force(self, dataset, gen, to_update=None):
self.pool_last_updated += 1
if to_update is None:
to_update = self.entire_ds
# resample all pool
self.resample_pool(gen, torch.arange(self.pool_size))
if to_update.shape[0] == 0:
return

pool_acceptable_stal = self.H.pool_staleness
# resample those that are too old
pool_old_indices = torch.where(self.pool_last_updated > pool_acceptable_stal)[0]
self.resample_pool(gen, pool_old_indices, rnd=False)

t1 = time.time()
print(torch.any(self.sample_pool_usage[to_update]), torch.any(self.sample_pool_usage))
if torch.any(self.sample_pool_usage[to_update]):
self.resample_pool(gen, dataset)
self.sample_pool_usage[:] = False
print(f'resampling took {time.time() - t1}')
# if torch.any(self.sample_pool_usage[to_update]):
# self.resample_pool(gen, dataset)
# self.sample_pool_usage[:] = False
# print(f'resampling took {time.time() - t1}')
# to_update_indices = self.selected_indices[to_update]
# self.resample_pool(gen, to_update_indices)

self.selected_dists_tmp[:] = np.inf
self.sample_pool_usage[to_update] = True


with torch.no_grad():
for i in range(self.pool_size // self.H.imle_db_size):
@@ -245,12 +205,15 @@ def imle_sample_force(self, dataset, gen, to_update=None):
indices = to_update[batch_slice]
x = self.dataset_proj[indices]
nearest_indices, dci_dists = gen.module.dci_db.query(x.float(), num_neighbours=1)
nearest_indices = nearest_indices.long()[:, 0]
nearest_indices = nearest_indices.long()[:, 0].cpu()
dci_dists = dci_dists[:, 0]

need_update = dci_dists < self.selected_dists_tmp[indices]
global_need_update = indices[need_update]

real_nearest_indices = nearest_indices[need_update] + pool_slice.start
self.selected_indices[global_need_update] = real_nearest_indices.clone()

self.selected_dists_tmp[global_need_update] = dci_dists[need_update].clone()
self.selected_latents_tmp[global_need_update] = pool_latents[nearest_indices[need_update]].clone() + self.H.imle_perturb_coef * torch.randn((need_update.sum(), self.H.latent_dim))
for j in range(len(self.res)):
@@ -262,29 +225,13 @@ def imle_sample_force(self, dataset, gen, to_update=None):
print("NN calculated for {} out of {} - {}".format((i + 1) * self.H.imle_db_size, self.pool_size, time.time() - t0))


if self.H.latent_epoch > 0:
for param in gen.parameters():
param.requires_grad = False
updatable_latents = self.selected_latents_tmp[to_update].clone().requires_grad_(True)
latent_optimizer = AdamW([updatable_latents], lr=self.latent_lr)
comb_dataset = ZippedDataset(TensorDataset(dataset[to_update]), TensorDataset(updatable_latents))

for gd_epoch in range(self.H.latent_epoch):
losses = []
for cur, _ in DataLoader(comb_dataset, batch_size=self.H.n_batch):
x = cur[0]
latents = cur[1][0]
_, target = self.preprocess_fn(x)
gen.zero_grad()
px_z = gen(latents) # TODO fix this
loss = self.calc_loss(px_z, target.permute(0, 3, 1, 2))
loss.backward()
latent_optimizer.step()
updatable_latents.grad.zero_()

losses.append(loss.detach())
print('avg loss', gd_epoch, sum(losses) / len(losses))
self.selected_latents[to_update] = updatable_latents.detach().clone()
self.selected_latents[to_update] = self.selected_latents_tmp[to_update].clone()
# self.pool_latents[self.selected_indices[to_update]].normal_()
# for i in range(len(self.res)):
# self.snoise_pool[i][self.selected_indices[to_update]].normal_()

to_update_indices = self.selected_indices[to_update].cuda()
self.resample_pool(gen, to_update_indices)

if self.H.latent_epoch > 0:
for param in gen.parameters():
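Taken together, the sampler.py changes implement a stale-pool refresh: pool_last_updated tracks how many force-resample rounds ago each pool entry was regenerated, entries older than --pool_staleness get their projections recomputed, and entries matched as nearest neighbours get brand-new latents via resample_pool. A self-contained sketch of that bookkeeping, with illustrative names (refresh, matched) standing in for the Sampler methods:

# Illustrative sketch of the pool-staleness bookkeeping added in sampler.py (names are hypothetical).
import torch

pool_size, latent_dim, staleness = 8, 4, 3           # staleness plays the role of H.pool_staleness
pool_latents = torch.randn(pool_size, latent_dim)
pool_last_updated = torch.zeros(pool_size, dtype=torch.int64)

def refresh(indices, rnd=True):
    # Reset the age counter; optionally draw fresh latents for these slots.
    # The real resample_pool also re-runs the generator and caches projections.
    pool_last_updated[indices] = 0
    if rnd:
        pool_latents[indices] = torch.randn(indices.shape[0], latent_dim)

for _ in range(5):                                    # one iteration per imle_sample_force call
    pool_last_updated += 1                            # every round ages the whole pool
    too_old = torch.where(pool_last_updated > staleness)[0]
    refresh(too_old, rnd=False)                       # stale entries: refresh projections only
    matched = torch.randint(pool_size, (2,))          # stand-in for the DCI nearest-neighbour hits
    refresh(matched, rnd=True)                        # matched entries get brand-new latents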
30 changes: 30 additions & 0 deletions test.sh
@@ -0,0 +1,30 @@
#!/bin/bash


name=100-shot-panda
change=0.0
factor=20
force=10
lr=0.00001
stal=10
wand_name="$name-chg-${change}-fac-${factor}-frc-${force}-lr-${lr}-stl-${stal}"

save_dir=/home/mehranag/scratch/saved_models/vdimle-reproduce/$wand_name
data_root=/home/mehranag/projects/rrg-keli/data/few-shot-images/${name}
restore_latent_path=/home/mehranag/scratch/saved_models/archived/4-ada3-2048-50p-full/test/latent/0-latest.npy
restore_path=/home/mehranag/scratch/saved_models/panda-naive/test/iter-450000-

#cd dciknn_cuda
#python setup.py install
#cd ..
cp /home/mehranag/inception-2015-12-05.pt /tmp

ssh -D 9050 -q -C -N narval1 &
python train.py --hps fewshot --save_dir $save_dir --data_root $data_root --lpips_coef 1 --l2_coef 0.1 \
--change_threshold 1 --change_coef $change --force_factor $factor --imle_db_size 5000 --imle_staleness $stal \
--imle_force_resample $force --latent_epoch 0 --latent_lr 0.0 --imle_factor 0 --lr $lr --n_batch 4 \
--proj_dim 800 --imle_batch 20 --iters_per_save 1000 --iters_per_images 500 --image_size 256 \
--proj_proportion 1 --latent_dim 1024 --iters_per_ckpt 5000 \
--dec_blocks '1x4,4m1,4x4,8m4,8x4,16m8,16x3,32m16,32x2,64m32,64x2,128m64,128x2,256m128' \
--max_hierarchy 256 --image_size 256 --use_wandb 1 --wandb_project $name-keep --wandb_name $wand_name --wandb_mode offline \
--fid_freq 10 --fid_factor 5
47 changes: 41 additions & 6 deletions train.py
@@ -25,6 +25,7 @@
from visual.utils import (generate_and_save, generate_for_NN,
generate_images_initial,
get_sample_for_visualization)
from torch.optim.lr_scheduler import LambdaLR


def training_step_imle(H, n, targets, latents, snoise, imle, ema_imle, optimizer, loss_fn):
@@ -41,6 +42,33 @@ def training_step_imle(H, n, targets, latents, snoise, imle, ema_imle, optimizer
stats.update(skipped_updates=0, iter_time=time.time() - t0, grad_norm=0)
return stats

class DecayLR:
def __init__(self, tmax=100000, staleness=10):
self.tmax = int(tmax)
self.staleness = staleness
assert self.tmax > 0
self.lr_step = (0 - 1) / self.tmax

def step(self, step):
per = step % self.staleness
lr = 1 + self.lr_step * step
lr = lr + ((0 - 1) / self.staleness) * per
lr = max(1e-6, min(1.0, lr))
return lr

def get_lrschedule(args, optimizer):
# if args.lr_schedule:
# scheduler = DecayLR(tmax=args.num_steps)
# lr_scheduler = LambdaLR(optimizer, lambda x: scheduler.step(x))
# else:
# lr_scheduler = LambdaLR(optimizer, lambda x: 1.0)
# return lr_scheduler
scheduler = DecayLR(tmax=args.num_steps, staleness=args.imle_staleness)
return LambdaLR(optimizer, lambda x: scheduler.step(x))
# return LambdaLR(optimizer, lambda x: 1.0)




def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, logprint):
subset_len = len(data_train)
@@ -51,6 +79,8 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
break

optimizer, scheduler, _, iterate, _ = load_opt(H, imle, logprint)
# lr_scheduler = get_lrschedule(H, optimizer)
lr_scheduler = scheduler

stats = []
H.ema_rate = torch.as_tensor(H.ema_rate)
@@ -89,6 +119,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
in_threshold = torch.logical_and(dists_in_threshold, updated_enough)
all_conditions = torch.logical_or(in_threshold, updated_too_much)
to_update = torch.nonzero(all_conditions, as_tuple=False).squeeze(1)
change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)

if epoch == 0:
if os.path.isfile(str(H.restore_latent_path)):
@@ -109,20 +140,21 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
change_thresholds[:] = threshold[:]
print('loaded thresholds', torch.mean(change_thresholds))
else:
to_update = sampler.entire_ds
to_update = None


change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)

print(to_update)
sampler.imle_sample_force(split_x_tensor, imle, to_update)

last_updated[to_update] = 0
times_updated[to_update] = times_updated[to_update] + 1
if to_update is not None:
last_updated[to_update] = 0
times_updated[to_update] = times_updated[to_update] + 1

save_latents_latest(H, split_ind, sampler.selected_latents)
save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest')

if to_update.shape[0] >= H.num_images_visualize:
if to_update is not None and to_update.shape[0] >= H.num_images_visualize:
latents = sampler.selected_latents[to_update[:H.num_images_visualize]]
with torch.no_grad():
generate_for_NN(sampler, split_x_tensor[to_update[:H.num_images_visualize]], latents,
@@ -165,6 +197,8 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
save_latents(H, iterate, split_ind, change_thresholds, name='threshold')
save_snoise(H, iterate, sampler.selected_snoise)

# lr_scheduler.step()

cur_dists = torch.empty([subset_len], dtype=torch.float32).cuda()
cur_dists[:] = sampler.calc_dists_existing(split_x_tensor, imle, dists=cur_dists)
torch.save(cur_dists, f'{H.save_dir}/latent/dists-{epoch}.npy')
@@ -174,6 +208,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
'std_loss': torch.std(cur_dists).item(),
'max_loss': torch.max(cur_dists).item(),
'min_loss': torch.min(cur_dists).item(),
'epoch': epoch,
}

if epoch % H.fid_freq == 0:
@@ -191,7 +226,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
metrics['best_fid'] = best_fid


logprint(model=H.desc, type='train_loss', epoch=epoch, step=iterate, **metrics)
logprint(model=H.desc, type='train_loss', step=iterate, **metrics)

if H.use_wandb:
wandb.log(metrics, step=iterate)
Expand Down