diff --git a/hps.py b/hps.py
index da7a595..bc7999f 100644
--- a/hps.py
+++ b/hps.py
@@ -118,4 +118,6 @@ def add_imle_arguments(parser):
     parser.add_argument('--ppl_save_name', type=str, default='ppl')
     parser.add_argument("--fid_factor", type=int, default=5, help="number of the samples for calculating FID")
     parser.add_argument("--fid_freq", type=int, default=5, help="frequency of calculating fid")
+    parser.add_argument("--num_steps", type=int, default=4000, help="number of training steps for the lr decay schedule")
+    parser.add_argument("--pool_staleness", type=int, default=3, help="max number of sampling rounds before a pool sample is refreshed")
     return parser
diff --git a/mapping_network.py b/mapping_network.py
index 21d6480..851ebe4 100644
--- a/mapping_network.py
+++ b/mapping_network.py
@@ -62,7 +62,7 @@ class MappingNetowrk(nn.Module):
     def __init__(self, code_dim=512, n_mlp=8):
         super().__init__()

-        layers = [PixelNorm()]
+        layers = []
         for i in range(n_mlp):
             layers.append(EqualLinear(code_dim, code_dim))
             layers.append(nn.LeakyReLU(0.2))
diff --git a/sampler.py b/sampler.py
index f4bb40b..9185b64 100644
--- a/sampler.py
+++ b/sampler.py
@@ -1,5 +1,4 @@
 from curses import update_lines_cols
-from math import comb
 import time

 import numpy as np
@@ -23,6 +22,7 @@ def __init__(self, H, sz, preprocess_fn):
         self.latent_lr = H.latent_lr
         self.entire_ds = torch.arange(sz)
         self.selected_latents = torch.empty([sz, H.latent_dim], dtype=torch.float32)
+        self.selected_indices = torch.empty([sz], dtype=torch.int64)
         self.selected_latents_tmp = torch.empty([sz, H.latent_dim], dtype=torch.float32)

         blocks = parse_layer_string(H.dec_blocks)
@@ -62,6 +62,7 @@ def __init__(self, H, sz, preprocess_fn):
         self.dataset_proj = torch.empty([sz, sum(dims)], dtype=torch.float32)
         self.pool_samples_proj = torch.empty([self.pool_size, sum(dims)], dtype=torch.float32)
         self.snoise_pool_samples_proj = torch.empty([sz * H.snoise_factor, sum(dims)], dtype=torch.float32)
+        self.pool_last_updated = torch.zeros([self.pool_size], dtype=torch.int64)

     def get_projected(self, inp, permute=True):
         if permute:
@@ -139,93 +140,52 @@ def calc_dists_existing(self, dataset_tensor, gen, dists=None, latents=None, to_
             dists[batch_slice] = torch.squeeze(dist)
         return dists

-    def imle_sample(self, dataset, gen, factor=None):
-        if factor is None:
-            factor = self.H.imle_factor
-        imle_pool_size = int(len(dataset) * factor)
-        t1 = time.time()
-        self.selected_dists_tmp[:] = self.selected_dists[:]
-        for i in range(imle_pool_size // self.H.imle_db_size):
-            self.temp_latent_rnds.normal_()
-            for j in range(len(self.res)):
-                self.snoise_tmp[j].normal_()
-            for j in range(self.H.imle_db_size // self.H.imle_batch):
-                batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
-                cur_latents = self.temp_latent_rnds[batch_slice]
-                cur_snoise = [x[batch_slice] for x in self.snoise_tmp]
-                with torch.no_grad():
-                    self.temp_samples[batch_slice] = gen(cur_latents, cur_snoise)
-                    self.temp_samples_proj[batch_slice] = self.get_projected(self.temp_samples[batch_slice], False)
-
-            if not gen.module.dci_db:
-                device_count = torch.cuda.device_count()
-                gen.module.dci_db = MDCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices,
-                                         num_simp_indices=self.H.num_simp_indices, devices=[i for i in range(device_count)], ts=device_count)
-
-            # gen.module.dci_db = DCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices,
-            #                         num_simp_indices=self.H.num_simp_indices)
-            gen.module.dci_db.add(self.temp_samples_proj)
-
-            t0 = time.time()
-            for ind, y in enumerate(DataLoader(dataset, batch_size=self.H.imle_batch)):
-                # t2 = time.time()
-                _, target = self.preprocess_fn(y)
-                x = self.dataset_proj[ind * self.H.imle_batch:ind * self.H.imle_batch + target.shape[0]]
-                cur_batch_data_flat = x.float()
-                nearest_indices, _ = gen.module.dci_db.query(cur_batch_data_flat, num_neighbours=1)
-                nearest_indices = nearest_indices.long()[:, 0]
-
-                batch_slice = slice(ind * self.H.imle_batch, ind * self.H.imle_batch + x.size()[0])
-                actual_selected_dists = self.calc_loss(target.permute(0, 3, 1, 2),
-                                                       self.temp_samples[nearest_indices].cuda(), use_mean=False)
-                # actual_selected_dists = torch.squeeze(actual_selected_dists)
-
-                to_update = torch.nonzero(actual_selected_dists < self.selected_dists[batch_slice], as_tuple=False)
-                to_update = torch.squeeze(to_update)
-                self.selected_dists[ind * self.H.imle_batch + to_update] = actual_selected_dists[to_update].clone()
-                self.selected_latents[ind * self.H.imle_batch + to_update] = self.temp_latent_rnds[nearest_indices[to_update]].clone()
-                for k in range(len(self.res)):
-                    self.selected_snoise[k][ind * self.H.imle_batch + to_update] = self.snoise_tmp[k][nearest_indices[to_update]].clone()
-
-                del cur_batch_data_flat
-
-            gen.module.dci_db.clear()
-
-        # adding perturbation
-        changed = torch.sum(self.selected_dists_tmp != self.selected_dists).item()
-        print("Samples and NN are calculated, time: {}, mean: {} # changed: {}, {}%".format(time.time() - t1,
-                                                                                            self.selected_dists.mean(),
-                                                                                            changed, (changed / len(
-                                                                                                dataset)) * 100))
-
-    def resample_pool(self, gen, ds):
-        # self.init_projection(ds)
-        self.pool_latents.normal_()
-        for i in range(len(self.res)):
-            self.snoise_pool[i].normal_()
-
-        for j in range(self.pool_size // self.H.imle_batch):
-            batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
+    def resample_pool(self, gen, to_update, rnd=True):
+        self.pool_last_updated[to_update] = 0
+        if rnd:
+            # in-place normal_() on an index-selected tensor only fills a copy, so write the fresh noise back explicitly
+            self.pool_latents[to_update] = torch.randn_like(self.pool_latents[to_update])
+            for i in range(len(self.res)):
+                self.snoise_pool[i][to_update] = torch.randn_like(self.snoise_pool[i][to_update])
+
+        for j in range(to_update.shape[0] // self.H.imle_batch + 1):
+            sl = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch)
+            batch_slice = to_update[sl]
+            if batch_slice.shape[0] == 0:
+                continue
+
             cur_latents = self.pool_latents[batch_slice]
             cur_snosie = [s[batch_slice] for s in self.snoise_pool]
             with torch.no_grad():
-                self.pool_samples_proj[batch_slice] = self.get_projected(gen(cur_latents, cur_snosie), False)
+                self.pool_samples_proj[batch_slice] = self.get_projected(gen(cur_latents, cur_snosie), False).cpu()
+                # self.get_projected(gen(cur_latents, cur_snosie), False)
+
     def imle_sample_force(self, dataset, gen, to_update=None):
+        self.pool_last_updated += 1
         if to_update is None:
             to_update = self.entire_ds
+            # resample all pool
+            self.resample_pool(gen, torch.arange(self.pool_size))

         if to_update.shape[0] == 0:
             return
+
+        pool_acceptable_stal = self.H.pool_staleness
+        # resample those that are too old
+        pool_old_indices = torch.where(self.pool_last_updated > pool_acceptable_stal)[0]
+        self.resample_pool(gen, pool_old_indices, rnd=False)

         t1 = time.time()
         print(torch.any(self.sample_pool_usage[to_update]), torch.any(self.sample_pool_usage))
-        if torch.any(self.sample_pool_usage[to_update]):
-            self.resample_pool(gen, dataset)
-            self.sample_pool_usage[:] = False
-            print(f'resampling took {time.time() - t1}')
+        # if torch.any(self.sample_pool_usage[to_update]):
+        #     self.resample_pool(gen, dataset)
+        #     self.sample_pool_usage[:] = False
+        #     print(f'resampling took {time.time() - t1}')
+        # to_update_indices = self.selected_indices[to_update]
+        # self.resample_pool(gen, to_update_indices)

         self.selected_dists_tmp[:] = np.inf
         self.sample_pool_usage[to_update] = True
+
         with torch.no_grad():
             for i in range(self.pool_size // self.H.imle_db_size):
@@ -245,12 +205,15 @@ def imle_sample_force(self, dataset, gen, to_update=None):
                     indices = to_update[batch_slice]
                     x = self.dataset_proj[indices]
                     nearest_indices, dci_dists = gen.module.dci_db.query(x.float(), num_neighbours=1)
-                    nearest_indices = nearest_indices.long()[:, 0]
+                    nearest_indices = nearest_indices.long()[:, 0].cpu()
                     dci_dists = dci_dists[:, 0]
                     need_update = dci_dists < self.selected_dists_tmp[indices]
                     global_need_update = indices[need_update]
+                    real_nearest_indices = nearest_indices[need_update] + pool_slice.start
+                    self.selected_indices[global_need_update] = real_nearest_indices.clone()
+
                     self.selected_dists_tmp[global_need_update] = dci_dists[need_update].clone()
                     self.selected_latents_tmp[global_need_update] = pool_latents[nearest_indices[need_update]].clone() + self.H.imle_perturb_coef * torch.randn((need_update.sum(), self.H.latent_dim))
                     for j in range(len(self.res)):
@@ -262,29 +225,13 @@ def imle_sample_force(self, dataset, gen, to_update=None):
             print("NN calculated for {} out of {} - {}".format((i + 1) * self.H.imle_db_size, self.pool_size,
                                                                time.time() - t0))

-        if self.H.latent_epoch > 0:
-            for param in gen.parameters():
-                param.requires_grad = False
-        updatable_latents = self.selected_latents_tmp[to_update].clone().requires_grad_(True)
-        latent_optimizer = AdamW([updatable_latents], lr=self.latent_lr)
-        comb_dataset = ZippedDataset(TensorDataset(dataset[to_update]), TensorDataset(updatable_latents))
-
-        for gd_epoch in range(self.H.latent_epoch):
-            losses = []
-            for cur, _ in DataLoader(comb_dataset, batch_size=self.H.n_batch):
-                x = cur[0]
-                latents = cur[1][0]
-                _, target = self.preprocess_fn(x)
-                gen.zero_grad()
-                px_z = gen(latents)  # TODO fix this
-                loss = self.calc_loss(px_z, target.permute(0, 3, 1, 2))
-                loss.backward()
-                latent_optimizer.step()
-                updatable_latents.grad.zero_()
-
-                losses.append(loss.detach())
-            print('avg loss', gd_epoch, sum(losses) / len(losses))
-        self.selected_latents[to_update] = updatable_latents.detach().clone()
+        self.selected_latents[to_update] = self.selected_latents_tmp[to_update].clone()
+        # self.pool_latents[self.selected_indices[to_update]].normal_()
+        # for i in range(len(self.res)):
+        #     self.snoise_pool[i][self.selected_indices[to_update]].normal_()
+
+        to_update_indices = self.selected_indices[to_update].cuda()
+        self.resample_pool(gen, to_update_indices)

         if self.H.latent_epoch > 0:
             for param in gen.parameters():
diff --git a/test.sh b/test.sh
new file mode 100755
index 0000000..1c84e40
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+
+name=100-shot-panda
+change=0.0
+factor=20
+force=10
+lr=0.00001
+stal=10
+wand_name="$name-chg-${change}-fac-${factor}-frc-${force}-lr-${lr}-stl-${stal}"
+
+save_dir=/home/mehranag/scratch/saved_models/vdimle-reproduce/$wand_name
+data_root=/home/mehranag/projects/rrg-keli/data/few-shot-images/${name}
+restore_latent_path=/home/mehranag/scratch/saved_models/archived/4-ada3-2048-50p-full/test/latent/0-latest.npy
+restore_path=/home/mehranag/scratch/saved_models/panda-naive/test/iter-450000-
+
+#cd dciknn_cuda
+#python setup.py install
+#cd ..
+cp /home/mehranag/inception-2015-12-05.pt /tmp
+
+ssh -D 9050 -q -C -N narval1 &
+python train.py --hps fewshot --save_dir $save_dir --data_root $data_root --lpips_coef 1 --l2_coef 0.1 \
+    --change_threshold 1 --change_coef $change --force_factor $factor --imle_db_size 5000 --imle_staleness $stal \
+    --imle_force_resample $force --latent_epoch 0 --latent_lr 0.0 --imle_factor 0 --lr $lr --n_batch 4 \
+    --proj_dim 800 --imle_batch 20 --iters_per_save 1000 --iters_per_images 500 --image_size 256 \
+    --proj_proportion 1 --latent_dim 1024 --iters_per_ckpt 5000 \
+    --dec_blocks '1x4,4m1,4x4,8m4,8x4,16m8,16x3,32m16,32x2,64m32,64x2,128m64,128x2,256m128' \
+    --max_hierarchy 256 --image_size 256 --use_wandb 1 --wandb_project $name-keep --wandb_name $wand_name --wandb_mode offline \
+    --fid_freq 10 --fid_factor 5
diff --git a/train.py b/train.py
index f8c7a22..8abba77 100644
--- a/train.py
+++ b/train.py
@@ -25,6 +25,7 @@ from visual.utils import (generate_and_save, generate_for_NN, generate_images_initial,
                           get_sample_for_visualization)
+from torch.optim.lr_scheduler import LambdaLR


 def training_step_imle(H, n, targets, latents, snoise, imle, ema_imle, optimizer, loss_fn):
@@ -41,6 +42,33 @@ def training_step_imle(H, n, targets, latents, snoise, imle, ema_imle, optimizer
     stats.update(skipped_updates=0, iter_time=time.time() - t0, grad_norm=0)
     return stats
+class DecayLR:
+    def __init__(self, tmax=100000, staleness=10):
+        self.tmax = int(tmax)
+        self.staleness = staleness
+        assert self.tmax > 0
+        self.lr_step = (0 - 1) / self.tmax
+
+    def step(self, step):
+        per = step % self.staleness
+        lr = 1 + self.lr_step * step
+        lr = lr + ((0 - 1) / self.staleness) * per
+        lr = max(1e-6, min(1.0, lr))
+        return lr
+
+def get_lrschedule(args, optimizer):
+    # if args.lr_schedule:
+    #     scheduler = DecayLR(tmax=args.num_steps)
+    #     lr_scheduler = LambdaLR(optimizer, lambda x: scheduler.step(x))
+    # else:
+    #     lr_scheduler = LambdaLR(optimizer, lambda x: 1.0)
+    # return lr_scheduler
+    scheduler = DecayLR(tmax=args.num_steps, staleness=args.imle_staleness)
+    return LambdaLR(optimizer, lambda x: scheduler.step(x))
+    # return LambdaLR(optimizer, lambda x: 1.0)
+
+
+

 def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, logprint):
     subset_len = len(data_train)
@@ -51,6 +79,8 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
             break
     optimizer, scheduler, _, iterate, _ = load_opt(H, imle, logprint)
+    # lr_scheduler = get_lrschedule(H, optimizer)
+    lr_scheduler = scheduler

     stats = []
     H.ema_rate = torch.as_tensor(H.ema_rate)
@@ -89,6 +119,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
             in_threshold = torch.logical_and(dists_in_threshold, updated_enough)
             all_conditions = torch.logical_or(in_threshold, updated_too_much)
             to_update = torch.nonzero(all_conditions, as_tuple=False).squeeze(1)
+            change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)

             if epoch == 0:
                 if os.path.isfile(str(H.restore_latent_path)):
@@ -109,20 +140,21 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
                     change_thresholds[:] = threshold[:]
                     print('loaded thresholds', torch.mean(change_thresholds))
                 else:
-                    to_update = sampler.entire_ds
+                    to_update = None

-            change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)
+            print(to_update)

             sampler.imle_sample_force(split_x_tensor, imle, to_update)
-            last_updated[to_update] = 0
-            times_updated[to_update] = times_updated[to_update] + 1
+            if to_update is not None:
+                last_updated[to_update] = 0
+                times_updated[to_update] = times_updated[to_update] + 1

             save_latents_latest(H, split_ind, sampler.selected_latents)
             save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest')

-            if to_update.shape[0] >= H.num_images_visualize:
+            if to_update is not None and to_update.shape[0] >= H.num_images_visualize:
                 latents = sampler.selected_latents[to_update[:H.num_images_visualize]]
                 with torch.no_grad():
                     generate_for_NN(sampler, split_x_tensor[to_update[:H.num_images_visualize]], latents,
@@ -165,6 +197,8 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
                 save_latents(H, iterate, split_ind, change_thresholds, name='threshold')
                 save_snoise(H, iterate, sampler.selected_snoise)

+        # lr_scheduler.step()
+
         cur_dists = torch.empty([subset_len], dtype=torch.float32).cuda()
         cur_dists[:] = sampler.calc_dists_existing(split_x_tensor, imle, dists=cur_dists)
         torch.save(cur_dists, f'{H.save_dir}/latent/dists-{epoch}.npy')
@@ -174,6 +208,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo
            'std_loss': torch.std(cur_dists).item(),
            'max_loss': torch.max(cur_dists).item(),
            'min_loss': torch.min(cur_dists).item(),
+           'epoch': epoch,
         }

         if epoch % H.fid_freq == 0:
@@ -191,7 +226,7 @@ def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, lo

         metrics['best_fid'] = best_fid

-        logprint(model=H.desc, type='train_loss', epoch=epoch, step=iterate, **metrics)
+        logprint(model=H.desc, type='train_loss', step=iterate, **metrics)

         if H.use_wandb:
             wandb.log(metrics, step=iterate)
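For reference, a minimal sketch (not part of the patch) of how the DecayLR multiplier added in train.py would plug into torch.optim.lr_scheduler.LambdaLR once the commented-out get_lrschedule / lr_scheduler.step() calls are enabled. The parameter and optimizer below are hypothetical stand-ins; DecayLR comes from this diff, and the tmax/staleness values mirror --num_steps and --imle_staleness from test.sh.

import torch
from torch.optim.lr_scheduler import LambdaLR
from train import DecayLR  # assumes train.py from this patch is importable

param = torch.nn.Parameter(torch.zeros(1))      # hypothetical stand-in for model parameters
optimizer = torch.optim.Adam([param], lr=1e-5)  # base lr as in test.sh (--lr)

decay = DecayLR(tmax=4000, staleness=10)        # mirrors --num_steps / --imle_staleness
lr_scheduler = LambdaLR(optimizer, lambda s: decay.step(s))

for s in range(25):
    optimizer.step()
    lr_scheduler.step()
    # decay.step(s) returns a multiplier on the base lr: it falls linearly from 1
    # toward 0 over tmax steps, minus a sawtooth of (s % staleness) / staleness
    # that resets every `staleness` steps, clamped to [1e-6, 1.0].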