From 93701fd0e3d49d3e230ec3c74b3c20be9b103ab6 Mon Sep 17 00:00:00 2001 From: Scott Hawley Date: Mon, 30 May 2022 22:03:58 +0000 Subject: [PATCH 1/3] added newline to end --- defaults.ini | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/defaults.ini b/defaults.ini index 4d0fd3f..9af1180 100644 --- a/defaults.ini +++ b/defaults.ini @@ -1,4 +1,3 @@ - [DEFAULTS] #name of the run @@ -53,4 +52,4 @@ codebook_size = 1024 num_quantizers = 8 # If true training data is kept in RAM -cache_training_data = False \ No newline at end of file +cache_training_data = False From 5ca1901dcf3ec73aa21f3df937aab5d6df0fcf06 Mon Sep 17 00:00:00 2001 From: Scott Hawley Date: Fri, 3 Jun 2022 19:43:06 +0000 Subject: [PATCH 2/3] scott's simple autoencoder --- decoders/decoders.py | 536 +++++++++++++++++++++++++++++++++++++++++++ defaults.ini | 10 +- run_zmc.py | 475 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1019 insertions(+), 2 deletions(-) create mode 100644 decoders/decoders.py create mode 100755 run_zmc.py diff --git a/decoders/decoders.py b/decoders/decoders.py new file mode 100644 index 0000000..0abcd7c --- /dev/null +++ b/decoders/decoders.py @@ -0,0 +1,536 @@ +## Modified from https://github.com/wesbz/SoundStream/blob/main/net.py +from xml.etree.ElementPath import prepare_predicate, prepare_star +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import cached_conv as cc +#import torch.nn.utils.weight_norm as wn +import torch.nn.utils.weight_norm as weight_norm + +import librosa as li +import torch.fft as fft +from einops import rearrange + +# TODO: Remove the RAVE code I'm not using anymore -SH + +MAX_BATCH_SIZE = 64 + + +def WNConv1d(*args, **kwargs): + return weight_norm(cc.Conv1d(*args, **kwargs)) + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(cc.ConvTranspose1d(*args, **kwargs)) + +class CachedPadding1d(nn.Module): + """ + Cached Padding implementation, replace zero padding with the end of + the previous tensor. 
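+    Intended for streaming / chunked inference: each forward call prepends the
+    last `padding` samples cached from the previous call, so consecutive chunks
+    can be convolved without zero-padding seams at the chunk boundaries.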
+ """ + def __init__(self, padding, crop=False): + super().__init__() + self.initialized = 0 + self.padding = padding + self.crop = crop + + @torch.jit.unused + @torch.no_grad() + def init_cache(self, x): + b, c, _ = x.shape + self.register_buffer( + "pad", + torch.zeros(MAX_BATCH_SIZE, c, self.padding).to(x)) + self.initialized += 1 + + def forward(self, x): + if not self.initialized: + self.init_cache(x) + + if self.padding: + x = torch.cat([self.pad[:x.shape[0]], x.clone()], -1) + self.pad[:x.shape[0]].copy_(x.clone()[..., -self.padding:]) + + if self.crop: + x = x.clone()[..., :-self.padding] + + return x + + +class AlignBranches(nn.Module): + def __init__(self, *branches, delays=None, cumulative_delay=0, stride=1): + super().__init__() + self.branches = nn.ModuleList(branches) + + if delays is None: + delays = list(map(lambda x: x.cumulative_delay, self.branches)) + + max_delay = max(delays) + + self.paddings = nn.ModuleList([ + CachedPadding1d(p, crop=True) + for p in map(lambda f: max_delay - f, delays) + ]) + + self.cumulative_delay = int(cumulative_delay * stride) + max_delay + + def forward(self, x): + outs = [] + print("q x.size() = ",x.size()) + for branch, pad in zip(self.branches, self.paddings): + print("branch, pad = ",branch, pad) + delayed_x = pad(x) + bd = branch(delayed_x) + print("delayed_x.size(), bd.size() = ",delayed_x.size(), bd.size()) + outs.append(bd) + return outs + + + +def mod_sigmoid(x): + return 2 * torch.sigmoid(x)**2.3 + 1e-7 + + +def amp_to_impulse_response(amp, target_size): + """ + transforms frequecny amps to ir on the last dimension + """ + amp = torch.stack([amp, torch.zeros_like(amp)], -1) + amp = torch.view_as_complex(amp) + amp = fft.irfft(amp.clone()) + + filter_size = amp.shape[-1] + + amp = torch.roll(amp, filter_size // 2, -1) + win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device) + + amp = amp * win + + amp = nn.functional.pad( + amp, + (0, int(target_size) - int(filter_size)), + ) + amp = torch.roll(amp, -filter_size // 2, -1) + + return amp + + + +def fft_convolve(signal, kernel): + """ + convolves signal by kernel on the last dimension + """ + signal = nn.functional.pad(signal, (0, signal.shape[-1])) + kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0)) + + output = fft.irfft(fft.rfft(signal.clone()) * fft.rfft(kernel.clone())) + output = output[..., output.shape[-1] // 2:] + + return output + + +def multiscale_stft(signal, scales, overlap): + """ + Compute a stft on several scales, with a constant overlap value. + Parameters + ---------- + signal: torch.Tensor + input signal to process ( B X C X T ) + + scales: list + scales to use + overlap: float + overlap between windows ( 0 - 1 ) + """ + signal = rearrange(signal, "b c t -> (b c) t") + stfts = [] + for s in scales: + S = torch.stft( + signal, + s, + int(s * (1 - overlap)), + s, + torch.hann_window(s).to(signal), + True, + normalized=True, + return_complex=True, + ).abs() + stfts.append(S) + return stfts + + +class Loudness(nn.Module): + + def __init__(self, sr, block_size, n_fft=2048): + super().__init__() + self.sr = sr + self.block_size = block_size + self.n_fft = n_fft + + f = np.linspace(0, sr / 2, n_fft // 2 + 1) + 1e-7 + a_weight = li.A_weighting(f).reshape(-1, 1) + + self.register_buffer("a_weight", torch.from_numpy(a_weight).float()) + self.register_buffer("window", torch.hann_window(self.n_fft)) + + def forward(self, x): + x = x.clone()[:,0,:] # mono loundness ? 
+ print("x.squeeze(1).size() = ", x.squeeze(1).size()) + x = torch.stft( + x.squeeze(1), + self.n_fft, + self.block_size, + self.n_fft, + center=True, + window=self.window, + return_complex=True, + ).abs() + x = torch.log(x + 1e-7) + self.a_weight + return torch.mean(x, 1, keepdim=True) + + + +class RAVEResidual(nn.Module): + def __init__(self, module, cumulative_delay=0): + super().__init__() + additional_delay = module.cumulative_delay + self.aligned = cc.AlignBranches( + module, + nn.Identity(), + delays=[additional_delay, 0], + ) + self.cumulative_delay = additional_delay + cumulative_delay + + def forward(self, x): + x_net, x_res = self.aligned(x) + return x_net + x_res + + +class RAVEResidualStack(nn.Module): + def __init__(self, + dim, + kernel_size, + padding_mode, + cumulative_delay=0, + bias=False): + super().__init__() + net = [] + + res_cum_delay = 0 + # SEQUENTIAL RESIDUALS + for i in range(3): + # RESIDUAL BLOCK + seq = [nn.LeakyReLU(.2)] + seq.append( + WNConv1d( + dim, + dim, + kernel_size, + padding=cc.get_padding( + kernel_size, + dilation=3**i, + mode=padding_mode, + ), + dilation=3**i, + bias=bias, + )) + + seq.append(nn.LeakyReLU(.2)) + seq.append( + WNConv1d( + dim, + dim, + kernel_size, + padding=cc.get_padding(kernel_size, mode=padding_mode), + bias=bias, + cumulative_delay=seq[-2].cumulative_delay, + )) + + res_net = cc.CachedSequential(*seq) + + net.append(RAVEResidual(res_net, cumulative_delay=res_cum_delay)) + res_cum_delay = net[-1].cumulative_delay + + self.net = cc.CachedSequential(*net) + self.cumulative_delay = self.net.cumulative_delay + cumulative_delay + + def forward(self, x): + return self.net(x) + +# RAVE bits grabbed from IRCAM-RAVE, https://github.com/acids-ircam/RAVE/blob/master/rave/model.py +class RAVEUpsampleLayer(nn.Module): + def __init__(self, + in_dim, + out_dim, + ratio, + padding_mode, + cumulative_delay=0, + bias=False): + super().__init__() + net = [nn.LeakyReLU(.2)] + if ratio > 1: + net.append( + WNConvTranspose1d( + in_dim, + out_dim, + 2 * ratio, + stride=ratio, + padding=ratio // 2, + bias=bias, + )) + else: + net.append( + WNConv1d( + in_dim, + out_dim, + 3, + padding=cc.get_padding(3, mode=padding_mode), + bias=bias, + )) + + self.net = cc.CachedSequential(*net) + self.cumulative_delay = self.net.cumulative_delay + cumulative_delay * ratio + + def forward(self, x): + return self.net(x) + + +class RAVENoiseGenerator(nn.Module): + def __init__(self, in_size, data_size, ratios, noise_bands, padding_mode): + super().__init__() + net = [] + channels = [in_size] * len(ratios) + [data_size * noise_bands] + cum_delay = 0 + for i, r in enumerate(ratios): + net.append( + cc.Conv1d( + channels[i], + channels[i + 1], + 3, + padding=cc.get_padding(3, r, mode=padding_mode), + stride=r, + cumulative_delay=cum_delay, + )) + cum_delay = net[-1].cumulative_delay + if i != len(ratios) - 1: + net.append(nn.LeakyReLU(.2)) + + self.net = cc.CachedSequential(*net) + self.data_size = data_size + self.cumulative_delay = self.net.cumulative_delay * int( + np.prod(ratios)) + + self.register_buffer( + "target_size", + torch.tensor(np.prod(ratios)).long(), + ) + + def forward(self, x): + amp = mod_sigmoid(self.net(x) - 5) + amp = amp.permute(0, 2, 1) + amp = amp.reshape(amp.shape[0], amp.shape[1], self.data_size, -1) + + ir = amp_to_impulse_response(amp, self.target_size) + noise = torch.rand_like(ir) * 2 - 1 + + noise = fft_convolve(noise.clone(), ir).permute(0, 2, 1, 3) + noise = noise.reshape(noise.shape[0], noise.shape[1], -1) + return noise + + + +class 
RAVEGenerator(nn.Module): + def __init__(self, + latent_size, + capacity, + data_size, + ratios, + loud_stride, + use_noise, + noise_ratios, + noise_bands, + padding_mode, + bias=False): + super().__init__() + net = [ + WNConv1d( + latent_size, + 2**len(ratios) * capacity, + 7, + padding=cc.get_padding(7, mode=padding_mode), + bias=bias, + ) + ] + + for i, r in enumerate(ratios): + in_dim = 2**(len(ratios) - i) * capacity + out_dim = 2**(len(ratios) - i - 1) * capacity + print("i, r, in_dim, out_dim = ",i, r, in_dim, out_dim) + + net.append( + RAVEUpsampleLayer( + in_dim, + out_dim, + r, + padding_mode, + cumulative_delay=net[-1].cumulative_delay, + )) + net.append( + RAVEResidualStack( + out_dim, + 3, + padding_mode, + cumulative_delay=net[-1].cumulative_delay, + )) + + self.net = cc.CachedSequential(*net) + + wave_gen = WNConv1d( + out_dim, + data_size, + 7, + padding=cc.get_padding(7, mode=padding_mode), + bias=bias, + ) + + loud_gen = WNConv1d( + out_dim, + 1, + 2 * loud_stride + 1, + stride=loud_stride, + padding=cc.get_padding(2 * loud_stride + 1, + loud_stride, + mode=padding_mode), + bias=bias, + ) + + branches = [wave_gen, loud_gen] + + if use_noise: + noise_gen = RAVENoiseGenerator( + out_dim, + data_size, + noise_ratios, + noise_bands, + padding_mode=padding_mode, + ) + branches.append(noise_gen) + + self.synth = AlignBranches( + *branches, + cumulative_delay=self.net.cumulative_delay, + ) + + self.use_noise = use_noise + self.loud_stride = loud_stride + self.cumulative_delay = self.synth.cumulative_delay + + def forward(self, x, add_noise: bool = True): + print("\n RAVEGEnerator: in x.size() = ",x.size()) + x = self.net(x) + print(" new x.size() = ",x.size()) + + """ + if self.use_noise: + waveform, loudness, noise = self.synth(x) + else: + waveform, loudness = self.synth(x) + noise = torch.zeros_like(waveform) + """ + waveform, loudness = x.clone()[:,0:1,:], x.clone()[:,2,:] + + print("1 waveform.size() = ",waveform.size()) + print("1 loudness.size() = ",loudness.size()) + + loudness = loudness.repeat_interleave(self.loud_stride) + loudness = loudness.reshape(x.shape[0], 1, -1) + + waveform = torch.tanh(waveform.clone()) * mod_sigmoid(loudness) + print("2 waveform.size() = ",waveform.size()) + + #if add_noise: + # waveform = waveform + noise + + waveform = waveform.clone()[:,:,0:32768] #truncate + + return waveform + + +def GenBlock(input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_block=False): + if not final_block: + return nn.Sequential( + nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + nn.BatchNorm1d(output_channels), + nn.ReLU() + ) + else: # Final block + return nn.Sequential( + nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + #nn.Tanh() # save tanh for end of loop, to rescale it + ) + +class Upscale_new(nn.Module): + def __init__(self, inc, outc, ksize, scale=2, final_block=False, add_noise=True): + super().__init__() + self.gb = GenBlock(inc, outc, final_block=final_block, stride=scale) + self.conv_same_size1 = nn.Conv1d(outc,outc,ksize,stride=1,padding=1) + self.conv_same_size2 = nn.Conv1d(outc,outc,ksize,stride=1,padding=1) + + self.add_noise = add_noise + self.act = nn.ReLU() + + def forward(self, x): + x = self.gb(x) + if self.add_noise: # some way of letting the network better match Zach's crazy Splice dataset + noise = torch.rand_like(x) * 2 - 1 + morph = self.act(self.conv_same_size1(x * noise)) # let x serve as the switch to allow more 
or less noise + x = self.conv_same_size2(x + morph) + return x + +""" +class Upscale_old(nn.Module): + def __init__(self, inc, outc, ksize, scale=2): + super().__init__() + #self.upsize = nn.ConvTranspose1d(inc, outc, ksize, stride=scale) # can cause checkerboaring + self.upsize = nn.Upsample(scale_factor=scale) + self.conv1 = nn.Conv1d(inc,outc,ksize,stride=1,padding=1) + self.act = nn.LeakyReLU(0.2) + self.conv2 = nn.Conv1d(outc, outc, ksize, stride=1, padding=1) + + def forward(self, x): + x = self.upsize(x) + x = self.conv1(x) # filters to lower the number of channels + x = self.act(x) + x = self.conv2(x) + out = x + return out +""" + +class SimpleDecoder(nn.Module): + """ + Scott trying Just making a basic expanding thingy + """ + def __init__(self, latent_dim, io_channels, out_length=32768, depth=16, add_noise=True): + super().__init__() + channels = [latent_dim,32,16,8,4,4,2] + scales = [2,2,4,4,2,2] + ksize = 3 + assert len(scales) == (len(channels)-1) + self.out_length = out_length + + + self.uplayers = nn.ModuleList( + [Upscale_new(channels[i],channels[i+1], ksize, scale=scales[i], + final_block=(i==len(scales)-1), add_noise=add_noise) for i in range(len(scales))] + ) + #self.final_conv = nn.Conv1d(channels[-1], channels[-1], ksize, stride=1, padding=1) + self.final_act = torch.tanh # so output waveform is on [-1,1] + + def forward(self, x): + # initially, x = z = latents. then we upscale it + for i in range(len(self.uplayers)): + #print(f"{i} 1 x.size() = ",x.size()) + x = self.uplayers[i](x) + x = 1.1*self.final_act(x) + return x[:,:,0:self.out_length] # crop to desired length, throw away the rest + diff --git a/defaults.ini b/defaults.ini index 9af1180..5b00172 100644 --- a/defaults.ini +++ b/defaults.ini @@ -4,7 +4,7 @@ name = splice-fastdiff # training data directory -training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 +training_dir = /home/ubuntu/datasets/Splice # the batch size batch_size = 8 @@ -19,7 +19,7 @@ num_workers = 12 sample_size = 32768 # Number of epochs between demos -demo_every = 50 +demo_every = 10 # Number of denoising steps for the demos demo_steps = 250 @@ -53,3 +53,9 @@ num_quantizers = 8 # If true training data is kept in RAM cache_training_data = False + +# number of sub-bands for the PQMF filter +pqmf_bands = 1 + +# number of heads for memcodes +num_heads = 8 diff --git a/run_zmc.py b/run_zmc.py new file mode 100755 index 0000000..3ae80b8 --- /dev/null +++ b/run_zmc.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 + +import argparse +from contextlib import contextmanager +from copy import deepcopy +import math +from pathlib import Path +import numpy as np + + +import sys +import torch +from torch import optim, nn +from torch.nn import functional as F +from torch.utils import data +from tqdm import trange +import pytorch_lightning as pl +from pytorch_lightning.utilities.distributed import rank_zero_only +from einops import rearrange + +import torchaudio +import wandb +from prefigure.prefigure import get_all_args, push_wandb_config +import auraloss + +from dataset.dataset import SampleDataset +from diffusion.model import SkipBlock, FourierFeatures, expand_to_planes, ema_update +from diffusion.pqmf import CachedPQMF as PQMF +from encoders.encoders import RAVEEncoder, ResConvBlock, SoundStreamXLEncoder +from decoders.decoders import RAVEGenerator, SimpleDecoder + +from decoders.decoders import multiscale_stft, Loudness, mod_sigmoid + + +from nwt_pytorch import Memcodes +from dvae.residual_memcodes import ResidualMemcodes + +# from RAVE core: +def 
get_beta_kl(step, warmup, min_beta, max_beta): + if step > warmup: return max_beta + t = step / warmup + min_beta_log = np.log(min_beta) + max_beta_log = np.log(max_beta) + beta_log = t * (max_beta_log - min_beta_log) + min_beta_log + return np.exp(beta_log) + + +def get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta): + return get_beta_kl(step % cycle_size, cycle_size // 2, min_beta, max_beta) + + +def get_beta_kl_cyclic_annealed(step, cycle_size, warmup, min_beta, max_beta): + min_beta = get_beta_kl(step, warmup, min_beta, max_beta) + return get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta) + + + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def get_crash_schedule(t): + sigma = torch.sin(t * math.pi / 2) ** 2 + alpha = (1 - sigma ** 2) ** 0.5 + return alpha_sigma_to_t(alpha, sigma) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + + +@torch.no_grad() +def sample(model, inputs): + """just runs the model in inference mode""" + with torch.cuda.amp.autocast(): + v = model(inputs).float() + return v + + +''' +@torch.no_grad() +def sample_old(model, x, steps, eta, logits): + """Draws samples from a model given starting noise.""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + alphas, sigmas = get_alphas_sigmas(get_crash_schedule(t)) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], logits).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. 
+ if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred +''' + + +class ToMode: + def __init__(self, mode): + self.mode = mode + + def __call__(self, image): + return image.convert(self.mode) + + +def ramp(x1, x2, y1, y2): + def wrapped(x): + if x <= x1: + return y1 + if x >= x2: + return y2 + fac = (x - x1) / (x2 - x1) + return y1 * (1 - fac) + y2 * fac + return wrapped + + +class ZQVAE(pl.LightningModule): + def __init__(self, global_args, + min_kl=1e-4, + max_kl=5e-1): + super().__init__() + + #self.encoder = Encoder(global_args.codebook_size, 2) + #self.encoder = SoundStreamXLEncoder(32, global_args.latent_dim, n_io_channels=2, strides=[2, 2, 4, 5, 8], c_mults=[2, 4, 4, 8, 16]) + self.loudness = Loudness(global_args.sample_rate, 512) + + self.pqmf_bands = global_args.pqmf_bands + + if self.pqmf_bands > 1: + self.pqmf = PQMF(2, 70, global_args.pqmf_bands) + + self.min_kl = min_kl + self.max_kl = max_kl + self.warmup = 1000000 + + #Model: + # Encoder part + self.encoder = RAVEEncoder(2 * global_args.pqmf_bands, 64, global_args.latent_dim, ratios=[2, 2, 2, 2, 4, 4]) + self.encoder_ema = deepcopy(self.encoder) + + # Latent middle rep: Memcodes (see below) + + # Decoder part + #self.diffusion = DiffusionDecoder(global_args.latent_dim, 2) + #self.decoder = RAVEDecoder(global_args.latent_dim, 2) + # default RAVE settings pulled from https://github.com/acids-ircam/RAVE/blob/master/train_rave.py + DATA_SIZE = 2 + CAPACITY = 64 + LATENT_SIZE = 128 + BIAS = True + NO_LATENCY = False + RATIOS = [4, 4, 2, 2, 2, 2] #[4, 4, 4, 2] + + MIN_KL = 1e-1 + MAX_KL = 1e-1 + CROPPED_LATENT_SIZE = 0 + FEATURE_MATCH = True + + LOUD_STRIDE = 1 + + USE_NOISE = False + NOISE_RATIOS = [4, 4, 4] + NOISE_BANDS = 5 + + D_CAPACITY = 16 + D_MULTIPLIER = 4 + D_N_LAYERS = 4 + + MODE = "hinge" + CKPT = None + + #no_latency=False + #PADDING_MODE = "causal" if no_latency else "centered" + PADDING_MODE = "centered" + '''self.decoder = RAVEGenerator(global_args.latent_dim, + capacity=CAPACITY, + data_size=DATA_SIZE, + ratios=RATIOS, + loud_stride=LOUD_STRIDE, + use_noise=USE_NOISE, + noise_ratios=NOISE_RATIOS, + noise_bands=NOISE_BANDS, + padding_mode=PADDING_MODE, + bias=True)''' + self.decoder = SimpleDecoder(global_args.latent_dim, io_channels=2) + self.decoder_ema = deepcopy(self.decoder) + + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + self.ema_decay = global_args.ema_decay + + self.num_quantizers = global_args.num_quantizers + self.quantized = (self.num_quantizers > 0) + if self.quantized: + if self.quantized: print(f"Making a quantizer.") + quantizer_class = ResidualMemcodes if global_args.num_quantizers > 1 else Memcodes + + quantizer_kwargs = {} + if global_args.num_quantizers > 1: + quantizer_kwargs["num_quantizers"] = global_args.num_quantizers + + self.quantizer = quantizer_class( + dim=global_args.latent_dim, + heads=global_args.num_heads, + num_codes=global_args.codebook_size, + temperature=1., + 
**quantizer_kwargs + ) + + self.quantizer_ema = deepcopy(self.quantizer) + #self.melstft_loss = auraloss.freq.MelSTFTLoss(global_args.sample_rate, device="cuda") + self.mrstft = auraloss.freq.MultiResolutionSTFTLoss() + + def lin_distance(self, x, y): + return torch.norm(x - y) / torch.norm(x) + + def log_distance(self, x, y): + return abs(torch.log(x + 1e-7) - torch.log(y + 1e-7)).mean() + + def distance(self, x, y): + scales = [2048, 1024, 512, 256, 128] + x = multiscale_stft(x, scales, .75) + y = multiscale_stft(y, scales, .75) + + lin = sum(list(map(self.lin_distance, x, y))) + log = sum(list(map(self.log_distance, x, y))) + + return lin + log + + def reparametrize(self, mean, scale): + std = nn.functional.softplus(scale) + 1e-4 + var = std * std + logvar = torch.log(var) + + z = torch.randn_like(mean) * std + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + if self.cropped_latent_size: + noise = torch.randn( + z.shape[0], + self.latent_size - self.cropped_latent_size, + z.shape[-1], + ).to(z.device) + z = torch.cat([z, noise], 1) + return z, kl + + + def encode(self, *args, **kwargs): + if self.training: + return self.encoder(*args, **kwargs) + return self.encoder_ema(*args, **kwargs) + + def decode(self, *args, **kwargs): + if self.training: + return self.decoder(*args, **kwargs) + return self.decoder_ema(*args, **kwargs) + + def configure_optimizers(self): + return optim.Adam([*self.encoder.parameters(), *self.decoder.parameters()], lr=2e-5) + + + def training_step(self, batch, batch_idx): + reals = batch[0] + encoder_input = reals + + #if self.pqmf_bands > 1: + # encoder_input = self.pqmf(reals) + + targets = deepcopy(reals) + + # Compute the model output and the loss. + with torch.cuda.amp.autocast(): + tokens = self.encoder(encoder_input).float() + + if self.num_quantizers > 0: + #Rearrange for Memcodes + tokens = rearrange(tokens, 'b d n -> b n d') + + #Quantize into memcodes + tokens, _ = self.quantizer(tokens) + + tokens = rearrange(tokens, 'b n d -> b d n') + + # p = torch.rand([reals.shape[0], 1], device=reals.device) + # quantized = torch.where(p > 0.2, quantized, torch.zeros_like(quantized)) + z = tokens # ? Zach? 
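+        # Decode the (optionally quantized) latents straight back to a waveform
+        # and train as a plain autoencoder: the reconstruction loss below is a
+        # weighted sum of sample-domain MSE and auraloss's multi-resolution
+        # STFT loss against the input audio.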
+ + with torch.cuda.amp.autocast(): + out_wave = self.decoder(z) + mse_loss = 2 * F.mse_loss(out_wave, targets) + mstft_loss = 0.1 * self.mrstft(out_wave, targets) + loss = mse_loss + mstft_loss + + + log_dict = { + 'train/loss': loss.detach(), + 'train/mse_loss': mse_loss.detach(), + 'train/mstft_loss': mstft_loss.detach(), + } + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + decay = 0.95 if self.current_epoch < 25 else self.ema_decay + ema_update(self.decoder, self.decoder_ema, decay) + ema_update(self.encoder, self.encoder_ema, decay) + + if self.num_quantizers > 0: + ema_update(self.quantizer, self.quantizer_ema, decay) + +class ExceptionCallback(pl.Callback): + def on_exception(self, trainer, module, err): + print(f'{type(err).__name__}: {err}', file=sys.stderr) + + + + + + +class DemoCallback(pl.Callback): + def __init__(self, demo_dl, global_args): + super().__init__() + self.demo_every = global_args.demo_every + self.demo_samples = global_args.sample_size + self.demo_steps = global_args.demo_steps + self.demo_dl = iter(demo_dl) + self.sample_rate = global_args.sample_rate + self.pqmf_bands = global_args.pqmf_bands + self.quantized = global_args.num_quantizers > 0 + + if self.pqmf_bands > 1: + self.pqmf = PQMF(2, 70, global_args.pqmf_bands) + + @rank_zero_only + @torch.no_grad() + def on_train_epoch_end(self, trainer, module): + #last_demo_step = -1 + #if (trainer.global_step - 1) % self.demo_every != 0 or last_demo_step == trainer.global_step: + if trainer.current_epoch % self.demo_every != 0: + return + + #last_demo_step = trainer.global_step + + demo_reals, _ = next(self.demo_dl) + + encoder_input = demo_reals + + if self.pqmf_bands > 1: + encoder_input = self.pqmf(demo_reals) + + encoder_input = encoder_input.to(module.device) + + demo_reals = demo_reals.to(module.device) + + #noise = torch.randn([demo_reals.shape[0], 2, self.demo_samples]).to(module.device) + + tokens = module.encoder_ema(encoder_input) + + if self.quantized: + + #Rearrange for Memcodes + tokens = rearrange(tokens, 'b d n -> b n d') + + tokens, _= module.quantizer_ema(tokens) + tokens = rearrange(tokens, 'b n d -> b d n') + + + #fakes = sample(module.decoder_ema, encoder_input) # , noise, self.demo_steps, 1, tokens) + fakes = module.decoder_ema(tokens) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + demo_reals = rearrange(demo_reals, 'b d n -> d (b n)') + + #demo_audio = torch.cat([demo_reals, fakes], -1) + + try: + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + reals_filename = f'reals_{trainer.global_step:08}.wav' + demo_reals = demo_reals.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(reals_filename, demo_reals, self.sample_rate) + + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + log_dict[f'real'] = wandb.Audio(reals_filename, + sample_rate=self.sample_rate, + caption=f'Real') + trainer.logger.experiment.log(log_dict, step=trainer.global_step) + except Exception as e: + print(f'{type(e).__name__}: {e}', file=sys.stderr) + +def main(): + args = get_all_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('Using device:', device) + torch.manual_seed(args.seed) + + train_set = SampleDataset([args.training_dir], args) + train_dl = data.DataLoader(train_set, 
args.batch_size, shuffle=True, + num_workers=args.num_workers, persistent_workers=True, pin_memory=True) + wandb_logger = pl.loggers.WandbLogger(project=args.name) + push_wandb_config(wandb_logger, args) + demo_dl = data.DataLoader(train_set, args.num_demos, shuffle=True) + + exc_callback = ExceptionCallback() + ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) + demo_callback = DemoCallback(demo_dl, args) + model = ZQVAE(args) + wandb_logger.watch(model) + + #torch.autograd.set_detect_anomaly(True) + + trainer = pl.Trainer( + gpus=torch.cuda.device_count(), + strategy='ddp', + precision=16, + accumulate_grad_batches={ + 0:1, + 1: args.accum_batches #Start without accumulation + # 5:2, + # 10:3, + # 12:4, + # 14:5, + # 16:6, + # 18:7, + # 20:8 + }, + callbacks=[ckpt_callback, demo_callback, exc_callback], + logger=wandb_logger, + log_every_n_steps=1, + max_epochs=10000000, + ) + + trainer.fit(model, train_dl) + +if __name__ == '__main__': + main() \ No newline at end of file From f51415836fc6c67add50932a189290575faccc5c Mon Sep 17 00:00:00 2001 From: Scott Hawley Date: Sat, 4 Jun 2022 18:01:35 +0000 Subject: [PATCH 3/3] code set for scaling runs, next switching to zach decoders branch --- dataset/dataset.py | 24 +++++++---------- decoders/decoders.py | 32 +++++++++++++---------- defaults.ini | 11 +++++++- diffusion/utils.py | 19 ++++++++++++-- dvae/dvae.py | 11 +++++--- run_zmc.py | 61 +++++++------------------------------------- train_fastdiff.py | 38 +++++++++++++-------------- 7 files changed, 91 insertions(+), 105 deletions(-) mode change 100644 => 100755 train_fastdiff.py diff --git a/dataset/dataset.py b/dataset/dataset.py index ae059b8..602460c 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -3,7 +3,7 @@ from torchaudio import transforms as T import random from glob import glob -from diffusion.utils import Stereo, PadCrop +from diffusion.utils import Stereo, PadCrop, PhaseFlipper, NormInputs class SampleDataset(torch.utils.data.Dataset): def __init__(self, paths, global_args): @@ -12,7 +12,9 @@ def __init__(self, paths, global_args): self.augs = torch.nn.Sequential( #RandomGain(0.9, 1.0), - PadCrop(global_args.sample_size), + PadCrop(global_args.sample_size, randomize=global_args.random_crop), + PhaseFlipper(), + NormInputs(do_norm=global_args.norm_inputs) ) self.encoding = torch.nn.Sequential( @@ -20,12 +22,8 @@ def __init__(self, paths, global_args): ) for path in paths: - self.filenames += glob(f'{path}/**/*.wav', recursive=True) - self.filenames += glob(f'{path}/**/*.flac', recursive=True) - self.filenames += glob(f'{path}/**/*.ogg', recursive=True) - self.filenames += glob(f'{path}/**/*.aiff', recursive=True) - self.filenames += glob(f'{path}/**/*.aif', recursive=True) - self.filenames += glob(f'{path}/**/*.mp3', recursive=True) + for ext in ['wav','flac','ogg','aiff','aif','mp3']: + self.filenames += glob(f'{path}/**/*.{ext}', recursive=True) self.sr = global_args.sample_rate @@ -51,7 +49,7 @@ def __getitem__(self, idx): audio = self.audio_files[idx] else: audio = self.load_file(audio_filename) - + audio = audio.clamp(-1, 1) #Run file-level augmentations @@ -83,12 +81,8 @@ def __init__(self, paths, global_args): for path in paths: - self.filenames += glob(f'{path}/**/*.wav', recursive=True) - self.filenames += glob(f'{path}/**/*.flac', recursive=True) - self.filenames += glob(f'{path}/**/*.ogg', recursive=True) - self.filenames += glob(f'{path}/**/*.aiff', recursive=True) - self.filenames += 
glob(f'{path}/**/*.aif', recursive=True) - self.filenames += glob(f'{path}/**/*.mp3', recursive=True) + for ext in ['wav','flac','ogg','aiff','aif','mp3']: + self.filenames += glob(f'{path}/**/*.{ext}', recursive=True) self.sr = global_args.sample_rate diff --git a/decoders/decoders.py b/decoders/decoders.py index 0abcd7c..dae9a87 100644 --- a/decoders/decoders.py +++ b/decoders/decoders.py @@ -448,8 +448,8 @@ def forward(self, x, add_noise: bool = True): waveform = torch.tanh(waveform.clone()) * mod_sigmoid(loudness) print("2 waveform.size() = ",waveform.size()) - #if add_noise: - # waveform = waveform + noise + if add_noise: + waveform = waveform + noise waveform = waveform.clone()[:,:,0:32768] #truncate @@ -459,31 +459,38 @@ def forward(self, x, add_noise: bool = True): def GenBlock(input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_block=False): if not final_block: return nn.Sequential( - nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + #nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + nn.Upsample(scale_factor=stride), + nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding=padding), nn.BatchNorm1d(output_channels), - nn.ReLU() + nn.LeakyReLU(0.2) ) else: # Final block return nn.Sequential( - nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + #nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding), + nn.Upsample(scale_factor=stride), + nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding=padding), #nn.Tanh() # save tanh for end of loop, to rescale it ) class Upscale_new(nn.Module): - def __init__(self, inc, outc, ksize, scale=2, final_block=False, add_noise=True): + def __init__(self, inc, outc, ksize, scale=2, final_block=False, add_noise=False): super().__init__() self.gb = GenBlock(inc, outc, final_block=final_block, stride=scale) self.conv_same_size1 = nn.Conv1d(outc,outc,ksize,stride=1,padding=1) self.conv_same_size2 = nn.Conv1d(outc,outc,ksize,stride=1,padding=1) self.add_noise = add_noise - self.act = nn.ReLU() + self.act = nn.Tanh() + #self.bn = nn.BatchNorm1d(outc) def forward(self, x): x = self.gb(x) if self.add_noise: # some way of letting the network better match Zach's crazy Splice dataset noise = torch.rand_like(x) * 2 - 1 + # somehow we want the noise to be switched on or off based on what input signal is, but we only have latents x morph = self.act(self.conv_same_size1(x * noise)) # let x serve as the switch to allow more or less noise + #morph = self.bn(morph) # output looked like it had a positive bias, so let's bn that x = self.conv_same_size2(x + morph) return x @@ -510,7 +517,7 @@ class SimpleDecoder(nn.Module): """ Scott trying Just making a basic expanding thingy """ - def __init__(self, latent_dim, io_channels, out_length=32768, depth=16, add_noise=True): + def __init__(self, latent_dim, io_channels, out_length=32768, depth=16, add_noise=False): super().__init__() channels = [latent_dim,32,16,8,4,4,2] scales = [2,2,4,4,2,2] @@ -518,19 +525,18 @@ def __init__(self, latent_dim, io_channels, out_length=32768, depth=16, add_nois assert len(scales) == (len(channels)-1) self.out_length = out_length - - self.uplayers = nn.ModuleList( + self.up_layers = nn.ModuleList( [Upscale_new(channels[i],channels[i+1], ksize, scale=scales[i], final_block=(i==len(scales)-1), add_noise=add_noise) for i in range(len(scales))] ) 
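+        # The overall upsampling factor is the product of `scales`
+        # (2*2*4*4*2*2 = 256), matching the 256x downsampling of the
+        # RAVEEncoder ratios [2, 2, 2, 2, 4, 4] in run_zmc.py, so a
+        # 32768-sample window corresponds to a 128-frame latent.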
#self.final_conv = nn.Conv1d(channels[-1], channels[-1], ksize, stride=1, padding=1) - self.final_act = torch.tanh # so output waveform is on [-1,1] + self.final_act = nn.Tanh() # so output waveform is on [-1,1] def forward(self, x): # initially, x = z = latents. then we upscale it - for i in range(len(self.uplayers)): + for i in range(len(self.up_layers)): #print(f"{i} 1 x.size() = ",x.size()) - x = self.uplayers[i](x) + x = self.up_layers[i](x) x = 1.1*self.final_act(x) return x[:,:,0:self.out_length] # crop to desired length, throw away the rest diff --git a/defaults.ini b/defaults.ini index 5b00172..719509a 100644 --- a/defaults.ini +++ b/defaults.ini @@ -19,7 +19,7 @@ num_workers = 12 sample_size = 32768 # Number of epochs between demos -demo_every = 10 +demo_every = 20 # Number of denoising steps for the demos demo_steps = 250 @@ -59,3 +59,12 @@ pqmf_bands = 1 # number of heads for memcodes num_heads = 8 + +# grab samples from random parts of input files +random_crop = True + +# normalize inputs? (for quiet sounds like IDMT guitar dataset) +norm_inputs = False + +# number of nodes / pods to run on +num_nodes = 2 diff --git a/diffusion/utils.py b/diffusion/utils.py index 768e190..a109187 100644 --- a/diffusion/utils.py +++ b/diffusion/utils.py @@ -8,18 +8,33 @@ def get_alphas_sigmas(t): return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) class PadCrop(nn.Module): - def __init__(self, n_samples): + def __init__(self, n_samples, randomize=False): super().__init__() self.n_samples = n_samples + self.randomize = randomize def __call__(self, signal): n, s = signal.shape - start = 0 #torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() end = start + self.n_samples output = signal.new_zeros([n, self.n_samples]) output[:, :min(s, self.n_samples)] = signal[:, start:end] return output +class PhaseFlipper(nn.Module): + "she was PHAAAAAAA-AAAASE FLIPPER, a random invert yeah" + def __call__(self, signal): + return -signal if (random.random() < 0.5) else signal + +class NormInputs(nn.Module): + "useful for quiet inputs. intended to be part of augmentation chain; not activated by default" + def __init__(self, do_norm=False): + super().__init__() + self.do_norm = do_norm + self.eps = 1e-2 + def __call__(self, signal): + return signal if (not self.do_norm) else signal/(torch.amax(signal,-1)[0] + self.eps) + class Mono(nn.Module): def __call__(self, signal): return torch.mean(signal, dim=0) if len(signal.shape) > 1 else signal diff --git a/dvae/dvae.py b/dvae/dvae.py index a3b527e..d7de12e 100644 --- a/dvae/dvae.py +++ b/dvae/dvae.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -import argparse +#import argparse +from prefigure.prefigure import get_all_args, push_wandb_config from contextlib import contextmanager from copy import deepcopy import math @@ -25,7 +26,7 @@ from encoders.encoders import RAVEEncoder, ResConvBlock, SoundStreamXLEncoder from nwt_pytorch import Memcodes -from residual_memcodes import ResidualMemcodes +from dvae.residual_memcodes import ResidualMemcodes class DiffusionDecoder(nn.Module): def __init__(self, latent_dim, io_channels, depth=16): @@ -167,7 +168,7 @@ def __init__(self, global_args): self.num_quantizers = global_args.num_quantizers if self.num_quantizers > 0: - print(f"Making a quantizer. quantized: {global_args.quantized}") + print(f"Making a quantizer. 
quantized: {global_args.num_quantizers}") quantizer_class = ResidualMemcodes if global_args.num_quantizers > 1 else Memcodes quantizer_kwargs = {} @@ -341,6 +342,8 @@ def on_train_epoch_end(self, trainer, module): print(f'{type(e).__name__}: {e}', file=sys.stderr) def main(): + args = get_all_args() + ''' p = argparse.ArgumentParser() p.add_argument('--training-dir', type=Path, required=True, help='the training data directory') @@ -390,6 +393,7 @@ def main(): help='number of sub-bands for the PQMF filter') args = p.parse_args() + ''' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print('Using device:', device) @@ -399,6 +403,7 @@ def main(): train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True, pin_memory=True) wandb_logger = pl.loggers.WandbLogger(project=args.name) + push_wandb_config(wandb_logger, args) demo_dl = data.DataLoader(train_set, args.num_demos, shuffle=True) exc_callback = ExceptionCallback() diff --git a/run_zmc.py b/run_zmc.py index 3ae80b8..289502b 100755 --- a/run_zmc.py +++ b/run_zmc.py @@ -80,49 +80,6 @@ def sample(model, inputs): return v -''' -@torch.no_grad() -def sample_old(model, x, steps, eta, logits): - """Draws samples from a model given starting noise.""" - ts = x.new_ones([x.shape[0]]) - - # Create the noise schedule - t = torch.linspace(1, 0, steps + 1)[:-1] - alphas, sigmas = get_alphas_sigmas(get_crash_schedule(t)) - - # The sampling loop - for i in trange(steps): - - # Get the model output (v, the predicted velocity) - with torch.cuda.amp.autocast(): - v = model(x, ts * t[i], logits).float() - - # Predict the noise and the denoised image - pred = x * alphas[i] - v * sigmas[i] - eps = x * sigmas[i] + v * alphas[i] - - # If we are not on the last timestep, compute the noisy image for the - # next timestep. 
- if i < steps - 1: - # If eta > 0, adjust the scaling factor for the predicted noise - # downward according to the amount of additional noise to add - ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ - (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() - adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() - - # Recombine the predicted noise and predicted denoised image in the - # correct proportions for the next step - x = pred * alphas[i + 1] + eps * adjusted_sigma - - # Add the correct amount of fresh noise - if eta: - x += torch.randn_like(x) * ddim_sigma - - # If we are on the last timestep, output the denoised image - return pred -''' - - class ToMode: def __init__(self, mode): self.mode = mode @@ -150,7 +107,7 @@ def __init__(self, global_args, #self.encoder = Encoder(global_args.codebook_size, 2) #self.encoder = SoundStreamXLEncoder(32, global_args.latent_dim, n_io_channels=2, strides=[2, 2, 4, 5, 8], c_mults=[2, 4, 4, 8, 16]) - self.loudness = Loudness(global_args.sample_rate, 512) + #self.loudness = Loudness(global_args.sample_rate, 512) self.pqmf_bands = global_args.pqmf_bands @@ -172,7 +129,7 @@ def __init__(self, global_args, #self.diffusion = DiffusionDecoder(global_args.latent_dim, 2) #self.decoder = RAVEDecoder(global_args.latent_dim, 2) # default RAVE settings pulled from https://github.com/acids-ircam/RAVE/blob/master/train_rave.py - DATA_SIZE = 2 + '''DATA_SIZE = 2 CAPACITY = 64 LATENT_SIZE = 128 BIAS = True @@ -200,7 +157,7 @@ def __init__(self, global_args, #no_latency=False #PADDING_MODE = "causal" if no_latency else "centered" PADDING_MODE = "centered" - '''self.decoder = RAVEGenerator(global_args.latent_dim, + self.decoder = RAVEGenerator(global_args.latent_dim, capacity=CAPACITY, data_size=DATA_SIZE, ratios=RATIOS, @@ -236,7 +193,6 @@ def __init__(self, global_args, ) self.quantizer_ema = deepcopy(self.quantizer) - #self.melstft_loss = auraloss.freq.MelSTFTLoss(global_args.sample_rate, device="cuda") self.mrstft = auraloss.freq.MultiResolutionSTFTLoss() def lin_distance(self, x, y): @@ -316,8 +272,8 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): out_wave = self.decoder(z) - mse_loss = 2 * F.mse_loss(out_wave, targets) - mstft_loss = 0.1 * self.mrstft(out_wave, targets) + mse_loss = 2 * F.mse_loss(out_wave, targets) # 2 is just based on experience, to balance the losses + mstft_loss = 0.1 * self.mrstft(out_wave, targets) # 0.2 is just based on experience, to balance the losses. 
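+        # Combined reconstruction loss; the weights actually used above are
+        # 2.0 on the MSE term and 0.1 on the multi-resolution STFT term.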
loss = mse_loss + mstft_loss @@ -327,7 +283,7 @@ def training_step(self, batch, batch_idx): 'train/mstft_loss': mstft_loss.detach(), } - self.log_dict(log_dict, prog_bar=True, on_step=True) + self.log_dict(log_dict, prog_bar=True, on_step=True, sync_dist=True) return loss def on_before_zero_grad(self, *args, **kwargs): @@ -450,7 +406,8 @@ def main(): trainer = pl.Trainer( gpus=torch.cuda.device_count(), - strategy='ddp', + num_nodes=args.num_nodes, + strategy='fsdp', precision=16, accumulate_grad_batches={ 0:1, @@ -472,4 +429,4 @@ def main(): trainer.fit(model, train_dl) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/train_fastdiff.py b/train_fastdiff.py old mode 100644 new mode 100755 index af263a9..3fed18d --- a/train_fastdiff.py +++ b/train_fastdiff.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -import argparse +#import argparse from prefigure.prefigure import get_all_args, push_wandb_config from contextlib import contextmanager from copy import deepcopy @@ -43,24 +43,27 @@ def alpha_sigma_to_t(alpha, sigma): return torch.atan2(sigma, alpha) / math.pi * 2 @torch.no_grad() -def sample(model, x, steps, eta, logits): +def sample(model, reals, specs, steps, eta): """Draws samples from a model given starting noise.""" - ts = x.new_ones([x.shape[0]]) + ts = reals.new_ones([reals.shape[0]]) # Create the noise schedule - t = torch.linspace(1, 0, steps + 1)[:-1] + t = torch.linspace(1, 0, steps + 1)[:-1].to(reals.device) + specs = specs.to(reals.device) alphas, sigmas = get_alphas_sigmas(get_crash_schedule(t)) # The sampling loop for i in trange(steps): + t_in = (ts * t[i]).unsqueeze(1).to(reals.device) + # Get the model output (v, the predicted velocity) with torch.cuda.amp.autocast(): - v = model(x, ts * t[i], logits).float() + v = model((reals, specs, t_in)).float() # Predict the noise and the denoised image - pred = x * alphas[i] - v * sigmas[i] - eps = x * sigmas[i] + v * alphas[i] + pred = reals * alphas[i] - v * sigmas[i] + eps = reals * sigmas[i] + v * alphas[i] # If we are not on the last timestep, compute the noisy image for the # next timestep. 
@@ -73,11 +76,11 @@ def sample(model, x, steps, eta, logits): # Recombine the predicted noise and predicted denoised image in the # correct proportions for the next step - x = pred * alphas[i + 1] + eps * adjusted_sigma + reals = pred * alphas[i + 1] + eps * adjusted_sigma # Add the correct amount of fresh noise if eta: - x += torch.randn_like(x) * ddim_sigma + reals += torch.randn_like(reals) * ddim_sigma # If we are on the last timestep, output the denoised image return pred @@ -90,13 +93,14 @@ def __init__(self, global_args): self.diffusion = FastDiff( audio_channels=2, cond_channels=80, + upsample_ratios=[8, 8, 4] ) self.rng = torch.quasirandom.SobolEngine(1, scramble=True) self.ema_decay = global_args.ema_decay def configure_optimizers(self): - return optim.Adam([*self.diffusion.parameters()], lr=2e-4) + return optim.Adam([*self.diffusion.parameters()], lr=2e-5) def training_step(self, batch, batch_idx): @@ -118,7 +122,6 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): v = self.diffusion((noised_reals, specs, t.unsqueeze(1))) loss = F.mse_loss(v, targets) - log_dict = { 'train/loss': loss.detach() @@ -152,19 +155,19 @@ def on_train_epoch_end(self, trainer, module): if trainer.current_epoch % self.demo_every != 0: return - demo_reals, _ = next(self.demo_dl) + demo_specs, demo_reals, _ = next(self.demo_dl) demo_reals = demo_reals.to(module.device) noise = torch.randn([demo_reals.shape[0], 2, self.demo_samples]).to(module.device) - fakes = sample(module.diffusion_ema, noise, self.demo_steps, 1) + fakes = sample(module.diffusion, noise, demo_specs, self.demo_steps, 1) # Put the demos together fakes = rearrange(fakes, 'b d n -> d (b n)') demo_reals = rearrange(demo_reals, 'b d n -> d (b n)') - demo_audio = torch.stack([demo_reals, fakes], dim=0) + #demo_audio = torch.stack([demo_reals, fakes], dim=0) try: log_dict = {} @@ -226,9 +229,6 @@ def main(): p.add_argument('--cache-training-data', type=bool, default=False, help='If true, training data is kept in RAM') - # p.add_argument('--val-set', type=str, required=True, - # help='the validation set') - args = p.parse_args() """ @@ -253,7 +253,7 @@ def main(): diffusion_trainer = pl.Trainer( gpus=args.num_gpus, - strategy='ddp', + strategy="ddp_find_unused_parameters_false", precision=16, accumulate_grad_batches={ 0:1, @@ -275,4 +275,4 @@ def main(): diffusion_trainer.fit(diffusion_model, train_dl) if __name__ == '__main__': - main() \ No newline at end of file + main()
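A quick shape sanity check for the decoder added in this series (a minimal
sketch, assuming decoders/decoders.py from these patches is importable and
using a hypothetical latent_dim of 64; substitute the real value from
defaults.ini):

    import torch
    from decoders.decoders import SimpleDecoder

    latent_dim = 64  # hypothetical value; take the real one from defaults.ini
    decoder = SimpleDecoder(latent_dim, io_channels=2)  # io_channels is currently unused inside SimpleDecoder
    z = torch.randn(4, latent_dim, 128)  # (batch, latent_dim, frames), e.g. 128 frames for a 32768-sample window at 256x downsampling
    with torch.no_grad():
        audio = decoder(z)
    print(audio.shape)  # (4, 2, <= 32768): 2 channels from the hard-coded channel list, time axis cropped to at most out_length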