diff --git a/models/dataset.py b/models/dataset.py new file mode 100644 index 0000000..ed1f3b6 --- /dev/null +++ b/models/dataset.py @@ -0,0 +1,179 @@ +"""Open datasets and process them to be used by a neural network.""" + + +import json +import os +import numpy as np +import h5py +import torch +from torch.utils.data import DataLoader, random_split +import functools +from PIL import Image + +CUDA = torch.cuda.is_available() + + +KWARGS = {'num_workers': 1, 'pin_memory': True} if CUDA else {} + + +def open_dataset(path, new_size, is_3d): + """Open datasets and processes data in ordor to make tensors. + + Parameters + ---------- + path : string, + Path. + new_size : int. + is_3d : boolean + If 2d or 3d. + + Returns + ------- + tensor + Images in black and white. + """ + if not os.path.exists(path): + raise OSError + if path.lower().endswith(".h5"): + data_dict = h5py.File(path, 'r') + all_datasets = data_dict['particles'][:] + else: + all_datasets = np.load(path) + dataset = np.asarray(all_datasets) + img_shape = dataset.shape + n_imgs = img_shape[0] + new_dataset = [] + if is_3d: + dataset = torch.Tensor(dataset) + dataset = normalization_linear(dataset) + dataset = dataset.reshape((len(dataset), 1)+dataset.shape[1:]) + else: + if img_shape.ndim == 3: + for i in range(n_imgs): + new_dataset.append(np.asarray(Image.fromarray( + dataset[i]).resize([new_size, new_size]))) + elif img_shape.ndim == 4: + for i in range(n_imgs): + new_dataset.append(np.asarray(Image.fromarray( + dataset[i][0]).resize([new_size, new_size]))) + dataset = torch.Tensor(new_dataset) + dataset = normalization_linear(dataset) + if len(img_shape) != 4: + dataset = dataset.reshape( + (img_shape[0], 1, img_shape[1], img_shape[1])) + return dataset + + +def normalization_linear(dataset): + """Normalize a tensor. + + Parameters + ---------- + dataset : tensor, + Images. + + Returns + ------- + dataset : tensor, + Normalized images. + """ + for i, data in enumerate(dataset): + min_data = torch.min(data) + max_data = torch.max(data) + if max_data == min_data: + raise ZeroDivisionError + dataset[i] = (data - min_data) / (max_data - min_data) + return dataset + + +def split_dataset(dataset, batch_size, frac_val): + """Separate data in train and validation sets. + + Parameters + ---------- + dataset : tensor, + Images. + batch_size : int, + Batch_size. + frac_val : float, + Ratio between validation and training datasets. + + Returns + ------- + trainset : tensor + Training images. + testset : tensor + Test images. + trainloader : tensor + Ready to be used by the NN for training images. + testloader : tensor + Ready to be used by the NN for test images. + """ + n_imgs = len(dataset) + n_val = int(n_imgs*frac_val) + trainset, testset = random_split(dataset, [n_imgs-n_val, n_val]) + trainloader = DataLoader( + trainset, batch_size=batch_size, shuffle=True, **KWARGS) + testloader = DataLoader( + testset, batch_size=batch_size, shuffle=False, **KWARGS) + return trainset, testset, trainloader, testloader + + +def hinted_tuple_hook(obj): + """Transform a list into tuple. + + Parameters + ---------- + obj : *, + Value of a dic. + + Returns + ------- + tuple, + Transform the value of a dic into dic. + obj: *, + Value of a dic. + """ + if '__tuple__' in obj: + return tuple(obj['items']) + return obj + + +def load_parameters(path): + """Load metadata for the VAE. + + Parameters + ---------- + path : string, + Path to the file. + + Returns + ------- + paths : dic, + Path to the data. + shapes: dic, + Shape of every dataset. + constants: dic, + Meta information for the vae. + search_space: dic, + Meta information for the vae. + meta_param_names: dic, + Names of meta parameters. + """ + with open(path) as json_file: + parameters = json.load(json_file, object_hook=hinted_tuple_hook) + paths = parameters["paths"] + shapes = parameters["shape"] + constants = parameters["constants"] + search_space = parameters["search_space"] + meta_param_names = parameters["meta_param_names"] + constants["conv_dim"] = len(constants["img_shape"][1:]) + constants["dataset_name"] = paths["simulated_2d"] + constants["dim_data"] = functools.reduce( + (lambda x, y: x * y), constants["img_shape"]) + return paths, shapes, constants, search_space, meta_param_names + + +if __name__ == "__main__": + PATHS, SHAPES, CONSTANTS, SEARCH_SPACE, META_PARAM_NAMES = load_parameters( + "vae_parameters.json") diff --git a/models/neural_network.py b/models/neural_network.py new file mode 100644 index 0000000..c9636a3 --- /dev/null +++ b/models/neural_network.py @@ -0,0 +1,276 @@ +import operator +from functools import reduce +import functools +from geomstats.geometry.special_orthogonal import SpecialOrthogonal +import torch +import os +import numpy as np +import torch.nn as nn +import latent_space_computation +from reparameterize import SO3reparameterize, N0reparameterize, \ + AlgebraMean, QuaternionMean, S2S1Mean, S2S2Mean, EulerYzyMean, ThetaMean,\ + VectorMean, SO2reparameterize, Nreparameterize + +import cnn_initialization + +import lie_tools + + +os.environ["GEOMSTATS_BACKEND"] = "pytorch" + +CUDA = torch.cuda.is_available() + +DEVICE = torch.device('cuda' if CUDA else 'cpu') + + +os.environ["GEOMSTATS_BACKEND"] = "pytorch" + +CUDA = torch.cuda.is_available() + +OUT_PAD = 0 + + +class Encoder(nn.Module): + """This class compute the Encoder""" + + def __init__(self, config): + """ + Initialization of the encoder. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(Encoder, self).__init__() + self.config = config + self.wigner_dim = config["wigner_dim"] + self.latent_mode = config["latent_mode"] + self.latent_space = config["latent_space"] + self.mean_mode = config["mean_mode"] + self.fixed_sigma = None + self.transpose = True + self.item_rep = config["item_rep"] + self.rep_copies = config["rep_copies"] + self.conv_dim = config["dimension"] + self.n_enc_lay = config["n_enc_lay"] + + self.compression_conv = cnn_initialization.CompressionConv(config) + self.fcs_infeatures = int( + (config["img_shape"][-1]**self.conv_dim)/(2**(self.n_enc_lay*(self.conv_dim-1)+1))) + + self.init_layer_latent_space() + self.reparameterize = nn.ModuleList([self.rep_group]) + self.init_action_net() + + def init_action_net(self): + if self.latent_space == "so3": + self.action_net = latent_space_computation.ActionNetSo3( + self.wigner_dim) + elif self.latent_space == "so2": + self.action_net = latent_space_computation.ActionNetSo2() + else: + self.action_net = latent_space_computation.ActionNetRL() + + def init_layer_latent_space(self): + print(self.mean_mode) + if self.latent_space == 'so3': + normal = N0reparameterize(self.fcs_infeatures, z_dim=3, + fixed_sigma=self.fixed_sigma) + if self.mean_mode == 'alg': + mean_module = AlgebraMean(self.fcs_infeatures) + elif self.mean_mode == 'q': + mean_module = QuaternionMean(self.fcs_infeatures) + elif self.mean_mode == 's2s1': + mean_module = S2S1Mean(self.fcs_infeatures) + elif self.mean_mode == 's2s2': + mean_module = S2S2Mean(self.fcs_infeatures) + elif self.mean_mode == 'eulyzy': + mean_module = EulerYzyMean(self.fcs_infeatures) + self.rep_group = SO3reparameterize(normal, mean_module, k=10) + self.group_dims = 9 + elif self.latent_space == "so2": + normal = N0reparameterize(self.fcs_infeatures, z_dim=1, + fixed_sigma=self.fixed_sigma) + if self.mean_mode == "theta": + mean_module = ThetaMean(self.fcs_infeatures) + elif self.mean_mode == "v": + mean_module = VectorMean(self.fcs_infeatures) + self.rep_group = SO2reparameterize(normal, mean_module) + else: + normal = N0reparameterize(self.fcs_infeatures, + z_dim=self.config["latent_dim"], + fixed_sigma=self.fixed_sigma) + self.rep_group = Nreparameterize( + self.fcs_infeatures, z_dim=self.config["latent_dim"], fixed_sigma=self.fixed_sigma) + + def forward(self, h, n=1): + """ + Compute the passage through the neural network + + Parameters + ---------- + h : tensor, image or voxel. + + Returns + ------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + + """ + h = self.compression_conv(h) + rot_mat_enc = [r(h, n) for r in self.reparameterize][0] + self.rot_mat_enc = rot_mat_enc[0] + if self.latent_space == "so3": + self.eayzy = lie_tools.group_matrix_to_eazyz( + rot_mat_enc[0]) + batch_size = self.eayzy.shape[0] + items = self.action_net(self.eayzy) + elif self.latent_space == "so2": + items = self.action_net(self.rot_mat_enc) + items = nn.functional.normalize(items, eps=1e-30) + if len(items) != 0: + assert len(items[items.T[0]**2 + items.T[1]**2 < 0.999] + ) == 0, print(items, items.shape) + else: + items = self.action_net(self.rot_mat_enc) + return items, self.rot_mat_enc + + def kl(self): + kl = [r.kl() for r in self.reparameterize] + return kl + + +SO3 = SpecialOrthogonal(3, point_type="vector") + + +class VaeConv(nn.Module): + """This class compute the VAE""" + + def __init__(self, config): + """ + Initialization of the VAE. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(VaeConv, self).__init__() + self.config = config + self.img_shape = config["img_shape"] + self.conv_dim = config["dimension"] + self.with_sigmoid = config["with_sigmoid"] + self.n_encoder_blocks = config["n_enc_lay"] + self.n_decoder_blocks = config["n_dec_lay"] + + self.encoder = Encoder( + self.config) + + self.decoder = cnn_initialization.DecoderConv( + self.config) + + def forward(self, x): + """ + Compute the passage through the neural network + + Parameters + ---------- + x : tensor, image or voxel. + + Returns + ------- + res : tensor, image or voxel. + scale_b: tensor, image or voxel. + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + + """ + z, matrix = self.encoder(x) + res, scale_b = self.decoder(z) + return res, scale_b, z + + +def reparametrize(mu, logvar, n_samples=1): + """ + Transform the probabilistic latent space into a deterministic latent space + + Parameters + ---------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + n_samples : int, number of samples. + + Returns + ------- + z_flat : tensor, deterministic latent space + + """ + n_batch_data, latent_dim = mu.shape + + std = logvar.mul(0.5).exp_() + std_expanded = std.expand( + n_samples, n_batch_data, latent_dim) + mu_expanded = mu.expand( + n_samples, n_batch_data, latent_dim) + + if CUDA: + eps = torch.cuda.FloatTensor( + n_samples, n_batch_data, latent_dim).normal_() + else: + eps = torch.FloatTensor(n_samples, n_batch_data, latent_dim).normal_() + eps = torch.autograd.Variable(eps) + + z = eps * std_expanded + mu_expanded + z_flat = z.reshape(n_samples * n_batch_data, latent_dim) + z_flat = z_flat.squeeze(dim=1) + return z_flat + + +def sample_from_q(mu, logvar, n_samples=1): + """ + Transform a probabilistic latent space into a deterministic latent space + + Parameters + ---------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + n_samples : int, number of samples. + + Returns + ------- + tensor, deterministic latent space. + + """ + return reparametrize(mu, logvar, n_samples) + + +def sample_from_prior(latent_dim, n_samples=1): + """ + Transform a probabilistic latent space into a deterministic latent. + + Parameters + ---------- + latent_dim : int, latent dimension. + n_samples : int, optional, number of sample. + + Returns + ------- + tensor, deterministic latent space. + + """ + if CUDA: + mu = torch.cuda.FloatTensor(n_samples, latent_dim).fill_(0) + logvar = torch.cuda.FloatTensor(n_samples, latent_dim).fill_(0) + else: + mu = torch.zeros(n_samples, latent_dim) + logvar = torch.zeros(n_samples, latent_dim) + return reparametrize(mu, logvar) diff --git a/models/nn.py b/models/nn.py new file mode 100644 index 0000000..48b9757 --- /dev/null +++ b/models/nn.py @@ -0,0 +1,890 @@ +"""This file is creating Convolutional Neural Networks.""" + +import math +from geomstats.geometry.special_orthogonal import SpecialOrthogonal +import functools +from functools import reduce +import operator +import numpy as np +import os +import torch +import torch.nn as nn +import cryo_dataset as ds +from pinchon_hoggan_dense import rot_mat, Jd +os.environ["GEOMSTATS_BACKEND"] = "pytorch" + +CUDA = torch.cuda.is_available() + + +def test(): + if CUDA: + path_vae = "Cryo/VAE_Cryo_V3/vae_parameters.json" + else: + path_vae = "vae_parameters.json" + PATHS, SHAPES, CONSTANTS, SEARCH_SPACE, _ = ds.load_parameters( + path_vae) + CONSTANTS.update(SEARCH_SPACE) + CONSTANTS["latent_space_definition"] = 0 + CONSTANTS["latent_dim"] = 10 + CONSTANTS["skip_z"] = True + CONSTANTS["n_gan_lay"] = 1 + enc = EncoderConv(CONSTANTS) + A = torch.zeros(20, 1, 64, 64) + B = enc.forward(A) + dec = DecoderConv(CONSTANTS) + C = dec.forward(B[2]) + dis = Discriminator(CONSTANTS) + D = dis(A) + return D + + +DEVICE = torch.device('cuda' if CUDA else 'cpu') + +NN_CONV = { + 2: nn.Conv2d, + 3: nn.Conv3d} +NN_CONV_TRANSPOSE = { + 2: nn.ConvTranspose2d, + 3: nn.ConvTranspose3d} + +NN_BATCH_NORM = { + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d} + +LATENT_SPACE_DEFINITION = { + 0: "(R^L)", + 1: "(SO(3)xT)", + 2: "(SO(3)xT)z0->R^L", + 3: "Wig((SO(3)xT))z0->R^L", + 4: "Wig((S2xS2xT))z0->R^L"} + +OUT_PAD = 0 + + +def conv_parameters(conv_dim, kernel_size, stride, padding, dilation): + """ + Construction of arrays of constants for 2d or 3d problems. + + Parameters + ---------- + conv_dim : int, 2 or 3 for 2d or 3d problems. + kernel_size : int, kernel size. + stride : int, stride. + padding : int, padding. + dilation : int, dilation. + + + Returns + ------- + kernel_size : array, kernel size for a 2d or 3d problem + stride : array, stride for a 2d or 3d problem + padding : array, padding for a 2d or 3d problem + """ + if type(kernel_size) is int: + kernel_size = np.repeat(kernel_size, conv_dim) + if type(stride) is int: + stride = np.repeat(stride, conv_dim) + if type(padding) is int: + padding = np.repeat(padding, conv_dim) + if type(dilation) is int: + dilation = np.repeat(dilation, conv_dim) + if len(kernel_size) != conv_dim: + raise ValueError + + if len(stride) != conv_dim: + raise ValueError + if len(padding) != conv_dim: + raise ValueError + if len(dilation) != conv_dim: + raise ValueError + return kernel_size, stride, padding, dilation + + +def conv_transpose_input_size(out_shape, in_channels, kernel_size, stride, + padding, dilation, output_padding=OUT_PAD): + """ + Compute the in_shape of a layer by knowing the output shape. + + Parameters + ---------- + out_shape : tuple, out shape of the layer. + in_channels : int, number of in channel. + kernel_size : int, kernel size. + stride : int, stride. + padding : int,padding. + dilation : int, dilation. + output_padding : int optional, out pad, the default is OUT_PAD. + + Returns + ------- + tuple, shape of the information before passing the layer. + """ + conv_dim = len(out_shape[1:]) + kernel_size, stride, padding, dilation = conv_parameters( + conv_dim, kernel_size, stride, padding, dilation) + if type(output_padding) is int: + output_padding = np.repeat(output_padding, conv_dim) + + def one_dim(i_dim): + """Inverts the formula giving the output shape.""" + shape_i_dim = ( + ((out_shape[i_dim+1] + + 2 * padding[i_dim] + - dilation[i_dim] * (kernel_size[i_dim] - 1) + - output_padding[i_dim] - 1) + // stride[i_dim]) + + 1) + + if shape_i_dim % 1 != 0: + raise ValueError + return int(shape_i_dim) + + in_shape = [one_dim(i_dim) for i_dim in range(conv_dim)] + in_shape = tuple(in_shape) + + return (in_channels,) + in_shape + + +def conv_output_size(in_shape, out_channels, kernel_size, stride, padding, + dilation): + """ + Compute the output shape by knowing the input shape of a layer + + Parameters + ---------- + in_shape : tuple, shape of the input of the layer. + out_channels : int, number of output channels. + kernel_size : int, kernel size. + stride : int, stride. + padding : int,padding. + dilation : int, dilation. + Returns + ------- + out_shape : tuple, shape of the output of the layer. + """ + out_shape = conv_transpose_input_size( + out_shape=in_shape, + in_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=0, + dilation=dilation) + # out_shape = (out_shape[0], out_shape[1], out_shape[2]) + return out_shape + + +class EncoderConv(nn.Module): + """This class compute the Encoder""" + + def __init__(self, config): + """ + Initialization of the encoder. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(EncoderConv, self).__init__() + self.config = config + self.latent_dim = int(config["latent_dim"]/2) + self.img_shape = config["img_shape"] + self.conv_dim = config["conv_dim"] + self.n_blocks = config["n_enc_lay"] + self.enc_c = config["enc_c"] + self.enc_ks = config["enc_ks"] + self.enc_str = config["enc_str"] + self.enc_pad = config["enc_pad"] + self.enc_dil = config["enc_dil"] + self.nn_conv = NN_CONV[self.conv_dim] + self.nn_batch_norm = NN_BATCH_NORM[self.conv_dim] + self.latent_space_definition = config["latent_space_definition"] + self.z0 = torch.zeros(self.latent_dim).to(DEVICE) + self.z0[0] = 1 + self.z1 = torch.zeros(self.latent_dim).to(DEVICE) + self.z1[1] = 1 + + # activation functions + self.leakyrelu = nn.LeakyReLU(0.2) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + # encoder + self.blocks = torch.nn.ModuleList() + + next_in_channels = self.img_shape[0] + next_in_shape = self.img_shape + for i in range(self.n_blocks): + enc_c_factor = 2 ** i + enc = self.nn_conv( + in_channels=next_in_channels, + out_channels=self.enc_c * enc_c_factor, + kernel_size=self.enc_ks, + stride=self.enc_str, + padding=self.enc_pad) + bn = self.nn_batch_norm(enc.out_channels) + + self.blocks.append(enc) + self.blocks.append(bn) + enc_out_shape = self.enc_conv_output_size( + in_shape=next_in_shape, + out_channels=enc.out_channels) + next_in_shape = enc_out_shape + next_in_channels = enc.out_channels + + self.last_out_shape = next_in_shape + + self.fcs_infeatures = functools.reduce( + (lambda x, y: x * y), self.last_out_shape) + + self.fc_mu, self.fc_logvar = self.init_layer_latent_space() + + def init_layer_latent_space(self): + if self.latent_space_definition == 0: + mu = nn.Linear( + in_features=self.fcs_infeatures, out_features=self.latent_dim*2) + logvar = nn.Linear( + in_features=self.fcs_infeatures, out_features=self.latent_dim*2) + elif self.latent_space_definition in [1, 2, 3]: + mu = nn.Linear( + in_features=self.fcs_infeatures, out_features=5) + logvar = nn.Linear( + in_features=self.fcs_infeatures, out_features=5) + elif self.latent_space_definition == 4: + mu = nn.Linear( + in_features=self.fcs_infeatures, out_features=8) + logvar = nn.Linear( + in_features=self.fcs_infeatures, out_features=5) + return mu, logvar + + def enc_conv_output_size(self, in_shape, out_channels): + """ + Compute the output shape of a layer + + Parameters + ---------- + in_shape : tuple, input shape + out_channels : int, number of channels. + + Returns + ------- + out_shape : tuple, shape of the output of the layer + + """ + return conv_output_size( + in_shape, out_channels, + kernel_size=self.enc_ks, + stride=self.enc_str, + padding=self.enc_pad, + dilation=self.enc_dil) + + def forward(self, h): + """ + Compute the passage through the neural network + + Parameters + ---------- + h : tensor, image or voxel. + + Returns + ------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + + """ + for i in range(self.n_blocks): + h = self.blocks[2*i](h) + h = self.blocks[2*i+1](h) + h = self.leakyrelu(h) + h = h.view(-1, self.fcs_infeatures) + mu = self.fc_mu(h) + logvar = self.fc_logvar(h) + if self.latent_space_definition == 0: + z = reparametrize(mu, logvar) + return mu, logvar, z, torch.zeros(mu.shape) + elif self.latent_space_definition in [1, 2, 3]: + rot_mats_no_var, translation, logvar_rotmat = reparametrize_so3( + mu, logvar) + else: + rot_mats_no_var, translation, logvar_rotmat = reparametrize_s2s2( + mu, logvar) + rot_mats = add_logvar(rot_mats_no_var, logvar_rotmat) + if self.latent_space_definition == 1: + vec_rot = rot_mats.reshape(-1, 9) + z = torch.cat((vec_rot, translation), dim=1) + return mu, logvar, z, rot_mats + wig_matrix = convert_rot_mat_ten_to_wig_mat_ten( + rot_mats, self.latent_dim) + z0 = torch.matmul(wig_matrix, self.z0) + z1 = torch.matmul(wig_matrix, self.z1) + z = torch.cat((z0, z1, translation), dim=1) + return mu, logvar, z, rot_mats + + +def reparametrize_so3(mu, logvar): + mu_rotmat = mu.T[:-2].T + mu_transl = mu.T[-2:].T + logvar_rotmat = logvar.T[:-2].T + logvar_transl = logvar.T[-2:].T + rot_mats_no_var = rot_mat_tensor(mu_rotmat, 3) + translation = reparametrize(mu_transl, logvar_transl) + return rot_mats_no_var, translation, logvar_rotmat + + +def reparametrize_s2s2(mu, logvar): + mu_rotmat = mu.T[:-2].T + mu_transl = mu.T[-2:].T + logvar_rotmat = logvar.T[:-2].T + logvar_transl = logvar.T[-2:].T + translation = reparametrize(mu_transl, logvar_transl) + rot_mats_no_var = rotation_vectors(mu_rotmat) + return rot_mats_no_var, translation, logvar_rotmat + + +def rotation_vectors(z): + z = z.reshape((-1, 6)) + u = z.T[:3].T + v = z.T[3:].T + e1 = nn.functional.normalize(u) + v2 = v - (e1*v).sum(-1, keepdim=True) * e1 + e2 = nn.functional.normalize(v2) + e3 = torch.cross(e1, e2) + matrices = torch.cat([e1, e2, e3]). reshape((-1, 3, 3)) + return matrices + + +def map_to_lie_algebra(v): + """Map a point in R^N to the tangent space at the identity, i.e. + to the Lie Algebra + Arg: + v = vector in R^N, (..., 3) in our case + Return: + R = v converted to Lie Algebra element, (3,3) in our case""" + # make sure this is a sample from R^3 + assert v.shape[-1] == 3 + + R_x = v.new_tensor([[0., 0., 0.], + [0., 0., -1.], + [0., 1., 0.]]) + + R_y = v.new_tensor([[0., 0., 1.], + [0., 0., 0.], + [-1., 0., 0.]]) + + R_z = v.new_tensor([[0., -1., 0.], + [1., 0., 0.], + [0., 0., 0.]]) + + R = R_x * v[..., 0, None, None] + \ + R_y * v[..., 1, None, None] + \ + R_z * v[..., 2, None, None] + return R + + +def rodrigues(v): + theta = v.norm(p=2, dim=-1, keepdim=True) + # normalize K + K = map_to_lie_algebra(v / theta) + + I = torch.eye(3, device=v.device, dtype=v.dtype) + R = I + torch.sin(theta)[..., None]*K \ + + (1. - torch.cos(theta))[..., None]*(K@K) + return R + + +def add_logvar(rot_mats_no_var, logvar_rotmat): + logvar_mat = rodrigues(logvar_rotmat) + return rot_mats_no_var @ logvar_mat + + +def convert_rot_mat_to_eul_ang(matrix): + y1, y2, z = 0, 0, 0 + sy = math.sqrt(matrix[0][1]**2+matrix[2][1]**2) + if sy > 10**(-6): + y1 = math.atan2(matrix[2][1], -matrix[0][1]) + y2 = math.atan2(matrix[1][2], matrix[1][0]) + z = math.atan2(sy, matrix[1][1]) + else: + y1 = 0 + y2 = math.atan2(matrix[1][0], matrix[0][0]) + z = 0 + return torch.Tensor([y1, z, y2]) + + +def convert_rot_mat_ten_to_wig_mat_ten(matrices, latent_dim): + batch_size = matrices.shape[0] + euler_tens = torch.zeros((batch_size, 3)) + for i in range(batch_size): + euler_tens[i] = convert_rot_mat_to_eul_ang(matrices[i]) + wig_matrices = rot_mat_tensor(euler_tens, latent_dim) + return wig_matrices + + +SO3 = SpecialOrthogonal(3, point_type="vector") + + +def transform_into_so3(mu, logvar, latent_dim, n_samples=1): + n_batch_data, latent_shape = mu.shape + sigma = logvar.mul(0.5).exp_() + if CUDA: + eps = torch.cuda.FloatTensor( + n_samples, n_batch_data, latent_shape).normal_() + else: + eps = torch.FloatTensor(n_samples, n_batch_data, + latent_shape).normal_() + eps = eps.reshape(eps.shape[1:]) + tang_mu = eps*sigma + z = mu+tang_mu + matrices = rot_mat_tensor(z, latent_dim) + return matrices + + +def rot_mat_tensor(tens, latent_dim): + batch_size = tens.shape[0] + J = Jd[int((latent_dim-1)/2)].to(DEVICE) + matrices = torch.Tensor(batch_size, latent_dim, latent_dim).to(DEVICE) + for i in range(batch_size): + alpha = tens[i][0] + beta = tens[i][1] + gamma = tens[i][2] + matrices[i] = rot_mat(alpha, beta, gamma, int((latent_dim-1)/2), J) + return matrices + + +class DecoderConv(nn.Module): + """This class compute the decoder""" + + def dec_conv_transpose_input_size(self, out_shape, in_channels): + """ + Compute the in_shape of a layer by knowing the output shape. + + Parameters + ---------- + out_shape : tuple, out shape of the layer. + in_channels : int, number of in channel. + + Returns + ------- + tuple, shape of the information before passing the layer. + + """ + return conv_transpose_input_size( + out_shape=out_shape, + in_channels=in_channels, + kernel_size=self.dec_ks, + stride=self.dec_str, + padding=self.dec_pad, + dilation=self.dec_dil) + + def block(self, out_shape, dec_c_factor): + """ + Compute every layer + + Parameters + ---------- + out_shape : tuple, shape of the output of the layer. + dec_c_factor : int, decode factor. + + Returns + ------- + batch_norm : layer, layer of the NN. + conv_transpose : layer, layer of the NN. + in_shape : tuple, shape of the input of the layer + + """ + out_channels = out_shape[0] + in_channels = self.dec_c * dec_c_factor + + batch_norm = self.nn_batch_norm( + num_features=out_channels, + eps=1.e-3) + + conv_transpose = self.nn_conv_transpose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.dec_ks, + stride=self.dec_str) + + in_shape = self.dec_conv_transpose_input_size( + out_shape=out_shape, + in_channels=in_channels) + return batch_norm, conv_transpose, in_shape + + def end_block(self, out_shape, dec_c_factor): + """ + Compute the last layer of the NN + + Parameters + ---------- + out_shape : tuple, out shape + dec_c_factor : int, decode factor + + Returns + ------- + conv_transpose : torch.nn.modules.conv.ConvTranspose, a layer of my NN + in_shape : tuple, input shape of the layer. + + """ + out_channels = out_shape[0] + in_channels = self.dec_c * dec_c_factor + + conv_transpose = self.nn_conv_transpose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.dec_ks, + stride=self.dec_str) + + in_shape = self.dec_conv_transpose_input_size( + out_shape=out_shape, + in_channels=in_channels) + return conv_transpose, in_shape + + def __init__(self, config): + """ + Initialization of the encoder. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(DecoderConv, self).__init__() + self.config = config + self.latent_dim = config["latent_dim"] + self.latent_space_definition = config["latent_space_definition"] + self.with_sigmoid = config["with_sigmoid"] + self.img_shape = config["img_shape"] + self.conv_dim = config["conv_dim"] + self.n_blocks = config["n_dec_lay"] + self.dec_c = config["dec_c"] + self.dec_ks = config["dec_ks"] + self.dec_str = config["dec_str"] + self.dec_pad = config["dec_pad"] + self.dec_dil = config["dec_dil"] + self.conv_dim = config["conv_dim"] + self.skip_z = config["skip_z"] + self.nn_conv_transpose = NN_CONV_TRANSPOSE[self.conv_dim] + self.nn_batch_norm = NN_BATCH_NORM[self.conv_dim] + + if self.latent_space_definition > 0: + self.latent_dim += 2 + if self.latent_space_definition == 1: + self.latent_dim = 11 + + # activation functions + self.leakyrelu = nn.LeakyReLU(0.2) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + # decoder - layers in reverse order + conv_transpose_recon, required_in_shape_r = self.end_block( + out_shape=self.img_shape, dec_c_factor=2 ** (self.n_blocks-1)) + conv_transpose_scale, required_in_shape_s = self.end_block( + out_shape=self.img_shape, dec_c_factor=2 ** (self.n_blocks-1)) + + self.conv_transpose_recon = conv_transpose_recon + self.conv_transpose_scale = conv_transpose_scale + + if np.all(required_in_shape_r != required_in_shape_s): + raise ValueError + required_in_shape = required_in_shape_r + + blocks_reverse = torch.nn.ModuleList() + for i in reversed(range(self.n_blocks-1)): + dec_c_factor = 2 ** i + + batch_norm, conv_tranpose, in_shape = self.block( + out_shape=required_in_shape, + dec_c_factor=dec_c_factor) + shape_h = reduce(operator.mul, required_in_shape, 1) + w_z = nn.Linear(self.latent_dim, shape_h, bias=False) + + blocks_reverse.append(w_z) + blocks_reverse.append(batch_norm) + blocks_reverse.append(conv_tranpose) + + required_in_shape = in_shape + + self.blocks = blocks_reverse[::-1] + self.in_shape = required_in_shape + + self.fcs_infeatures = functools.reduce( + (lambda x, y: x * y), self.in_shape) + + self.l0 = nn.Linear( + in_features=self.latent_dim, out_features=self.fcs_infeatures) + + def forward(self, z): + """ + Compute the passage through the neural network + + Parameters + ---------- + z : tensor, latent space. + + Returns + ------- + recon : tensor, image or voxel. + scale_b: tensor, image or voxel. + + """ + h1 = self.relu(self.l0(z)) + h = h1.view((-1,) + self.in_shape) + + for i in range(self.n_blocks-1): + h = self.blocks[3*i](h) + h = self.blocks[3*i+1](h) + if self.skip_z: + z1 = self.blocks[3*i+2](z).reshape(h.shape) + h = self.leakyrelu(h+z1) + + recon = self.conv_transpose_recon(h) + scale_b = self.conv_transpose_scale(h) + + if self.with_sigmoid: + recon = self.sigmoid(recon) + return recon, scale_b + + +class VaeConv(nn.Module): + """This class compute the VAE""" + + def __init__(self, config): + """ + Initialization of the VAE. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(VaeConv, self).__init__() + self.config = config + self.latent_dim = config["latent_dim"] + self.img_shape = config["img_shape"] + self.conv_dim = config["conv_dim"] + self.with_sigmoid = config["with_sigmoid"] + self.n_encoder_blocks = config["n_enc_lay"] + self.n_decoder_blocks = config["n_dec_lay"] + + self.encoder = EncoderConv( + config) + + self.decoder = DecoderConv( + config) + + def forward(self, x): + """ + Compute the passage through the neural network + + Parameters + ---------- + x : tensor, image or voxel. + + Returns + ------- + res : tensor, image or voxel. + scale_b: tensor, image or voxel. + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + + """ + mu, logvar, z, matrix = self.encoder(x) + res, scale_b = self.decoder(z) + return res, scale_b, mu, logvar, z + + +def reparametrize(mu, logvar, n_samples=1): + """ + Transform the probabilistic latent space into a deterministic latent space + + Parameters + ---------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + n_samples : int, number of samples. + + Returns + ------- + z_flat : tensor, deterministic latent space + + """ + n_batch_data, latent_dim = mu.shape + + std = logvar.mul(0.5).exp_() + std_expanded = std.expand( + n_samples, n_batch_data, latent_dim) + mu_expanded = mu.expand( + n_samples, n_batch_data, latent_dim) + + if CUDA: + eps = torch.cuda.FloatTensor( + n_samples, n_batch_data, latent_dim).normal_() + else: + eps = torch.FloatTensor(n_samples, n_batch_data, latent_dim).normal_() + eps = torch.autograd.Variable(eps) + + z = eps * std_expanded + mu_expanded + z_flat = z.reshape(n_samples * n_batch_data, latent_dim) + z_flat = z_flat.squeeze(dim=1) + return z_flat + + +def sample_from_q(mu, logvar, n_samples=1): + """ + Transform a probabilistic latent space into a deterministic latent space + + Parameters + ---------- + mu : tensor, latent space mu. + logvar : tensor, latent space sigma. + n_samples : int, number of samples. + + Returns + ------- + tensor, deterministic latent space. + + """ + return reparametrize(mu, logvar, n_samples) + + +def sample_from_prior(latent_dim, n_samples=1): + """ + Transform a probabilistic latent space into a deterministic latent. + + Parameters + ---------- + latent_dim : int, latent dimension. + n_samples : int, optional, number of sample. + + Returns + ------- + tensor, deterministic latent space. + + """ + if CUDA: + mu = torch.cuda.FloatTensor(n_samples, latent_dim).fill_(0) + logvar = torch.cuda.FloatTensor(n_samples, latent_dim).fill_(0) + else: + mu = torch.zeros(n_samples, latent_dim) + logvar = torch.zeros(n_samples, latent_dim) + return reparametrize(mu, logvar) + + +class Discriminator(nn.Module): + """This class compute the GAN""" + + def dis_conv_output_size(self, in_shape, out_channels): + """ + Compute the output shape by knowing the input shape of a layer. + + Parameters + ---------- + in_shape : tuple, shape of the input of the layer. + out_channels : int, number of output channels. + Returns + ------- + tuple, shape of the output of the layer. + + """ + return conv_output_size( + in_shape, out_channels, + kernel_size=self.dis_ks, + stride=self.dis_str, + padding=self.dis_pad, + dilation=self.dis_dil) + + def __init__(self, config): + """ + Initialization of the GAN. + + Parameters + ---------- + config : dic, principal constants to build a encoder. + + Returns + ------- + None. + + """ + super(Discriminator, self).__init__() + self.img_shape = config["img_shape"] + self.dis_c = config["dis_c"] + self.dis_ks = config["dis_ks"] + self.dis_str = config["dis_str"] + self.dis_pad = config["dis_pad"] + self.dis_dil = config["dis_dil"] + self.n_blocks = config["n_gan_lay"] + self.config = config + self.batch_size = config["batch_size"] + self.latent_dim = config["latent_dim"] + self.img_shape = config["img_shape"] + self.conv_dim = config["conv_dim"] + self.nn_conv = NN_CONV[self.conv_dim] + self.nn_batch_norm = NN_BATCH_NORM[self.conv_dim] + self.batch_size = config["batch_size"] + + # activation functions + self.leakyrelu = nn.LeakyReLU(0.2) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + self.blocks = torch.nn.ModuleList() + self.dis_out_shape = self.img_shape + next_in_channels = self.img_shape[0] + for i in range(self.n_blocks): + dis_c_factor = 2 ** i + dis = self.nn_conv( + in_channels=next_in_channels, + out_channels=self.dis_c * dis_c_factor, + kernel_size=self.dis_ks, + stride=self.dis_str, + padding=self.dis_pad) + bn = self.nn_batch_norm(dis.out_channels) + self.blocks.append(dis) + self.blocks.append(bn) + next_in_channels = dis.out_channels + self.dis_out_shape = self.dis_conv_output_size( + in_shape=self.dis_out_shape, + out_channels=next_in_channels) + + self.fcs_infeatures = functools.reduce( + (lambda x, y: x * y), self.dis_out_shape) + self.fc1 = nn.Linear(in_features=self.fcs_infeatures, + out_features=1) + + def forward(self, x): + """ + Forward pass of the discriminator is to take an image + and output probability of the image being generated by the prior + versus the learned approximation of the posterior. + Parameters + ---------- + x : tensor, image or voxel. + + Returns + ------- + prob: float, between 0 and 1 the probability of being a true image. + """ + h = x + for i in range(self.n_blocks): + h = self.leakyrelu(self.blocks[2*i+1](self.blocks[2*i](h))) + h = h.view(-1, self.fcs_infeatures) + h_feature = self.fc1(h) + prob = self.sigmoid(h_feature) + prob = prob.view(-1, 1) + + return prob, 0, 0 diff --git a/models/reparameterize.py b/models/reparameterize.py new file mode 100644 index 0000000..ed5fab8 --- /dev/null +++ b/models/reparameterize.py @@ -0,0 +1,331 @@ +import math +import numpy as np +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Normal + +from lie_tools import logsumexp, rot_mat_wigner, rodrigues,\ + quaternions_to_group_matrix, s2s1rodrigues, s2s2_gram_schmidt + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + + +class Nreparameterize(nn.Module): + """Reparametrize Gaussian variable.""" + + def __init__(self, input_dim, z_dim): + super().__init__() + + self.input_dim = input_dim + self.z_dim = z_dim + self.sigma_linear = nn.Linear(input_dim, z_dim) + self.mu_linear = nn.Linear(input_dim, z_dim) + self.return_means = False + + self.mu, self.sigma, self.z = None, None, None + + def forward(self, x, n=1): + self.mu = self.mu_linear(x) + self.sigma = F.softplus(self.sigma_linear(x)) + self.z = self.nsample(n=n) + return self.z + + def kl(self): + return -0.5 * torch.sum(1 + 2 * self.sigma.log() - self.mu.pow(2) - self.sigma ** 2, -1) + + def log_posterior(self): + return self._log_posterior(self.z) + + def _log_posterior(self, z): + return Normal(self.mu, self.sigma).log_prob(z).sum(-1) + + def log_prior(self): + return Normal(torch.zeros_like(self.mu), torch.ones_like(self.sigma)).log_prob(self.z).sum(-1) + + def nsample(self, n=1): + if self.return_means: + return self.mu.expand(n, -1, -1) + eps = Normal(torch.zeros_like(self.mu), + torch.ones_like(self.mu)).sample((n,)) + return self.mu + eps * self.sigma + + def deterministic(self): + """Set to return means.""" + self.return_means = True + + +class N0reparameterize(nn.Module): + """Reparametrize zero mean Gaussian Variable.""" + + def __init__(self, input_dim, z_dim, fixed_sigma=None): + super().__init__() + + self.input_dim = input_dim + self.z_dim = z_dim + self.sigma_linear = nn.Linear(input_dim, z_dim) + self.return_means = False + if fixed_sigma is not None: + self.register_buffer('fixed_sigma', torch.tensor(fixed_sigma)) + else: + self.fixed_sigma = None + + self.sigma = None + self.z = None + + def forward(self, x, n=1): + if self.fixed_sigma is not None: + self.sigma = x.new_full((x.shape[0], self.z_dim), self.fixed_sigma) + else: + self.sigma = F.softplus(self.sigma_linear(x)) + self.z = self.nsample(n=n) + return self.z + + def kl(self): + return -0.5 * torch.sum(1 + 2 * self.sigma.log() - self.sigma ** 2, -1) + + def log_posterior(self): + return self._log_posterior(self.z) + + def _log_posterior(self, z): + return Normal(torch.zeros_like(self.sigma), self.sigma).log_prob(z).sum(-1) + + def log_prior(self): + return Normal(torch.zeros_like(self.sigma), torch.ones_like(self.sigma)).log_prob(self.z).sum(-1) + + def nsample(self, n=1): + if self.return_means: + return torch.zeros_like(self.sigma).expand(n, -1, -1) + eps = Normal(torch.zeros_like(self.sigma), + torch.ones_like(self.sigma)).sample((n,)) + return eps * self.sigma + + def deterministic(self): + """Set to return means.""" + self.return_means = True + + +class AlgebraMean(nn.Module): + """Module to map R^3 -> SO(3) with Algebra method.""" + + def __init__(self, input_dims): + super().__init__() + self.map = nn.Linear(input_dims, 3) + + def forward(self, x): + return rodrigues(self.map(x)) + + +class QuaternionMean(nn.Module): + def __init__(self, input_dims): + super().__init__() + self.map = nn.Linear(input_dims, 4) + + def forward(self, x): + return quaternions_to_group_matrix(self.map(x)) + + +class S2S1Mean(nn.Module): + """Module to map R^5 -> SO(3) with S2S1 method.""" + + def __init__(self, input_dims): + super().__init__() + self.s2_map = nn.Linear(input_dims, 3) + self.s1_map = nn.Linear(input_dims, 2) + + def forward(self, x): + s2_el = self.s2_map(x) + s2_el = s2_el/s2_el.norm(p=2, dim=-1, keepdim=True) + + s1_el = self.s1_map(x) + s1_el = s1_el/s1_el.norm(p=2, dim=-1, keepdim=True) + + return s2s1rodrigues(s2_el, s1_el) + + +class S2S2Mean(nn.Module): + """Module to map R^6 -> SO(3) with S2S2 method.""" + + def __init__(self, input_dims): + super().__init__() + self.map = nn.Linear(input_dims, 6) + + # Start with big outputs + self.map.weight.data.uniform_(-10, 10) + self.map.bias.data.uniform_(-10, 10) + + def forward(self, x): + v = self.map(x).double().view(-1, 2, 3) + v1, v2 = v[:, 0], v[:, 1] + return s2s2_gram_schmidt(v1, v2).float() + + +class EulerYzyMean(nn.Module): + """Module to map R^3 -> SO(3) with Euler yzy method.""" + + def __init__(self, input_dims): + super().__init__() + self.map = nn.Linear(input_dims, 3) + + # Start with big outputs + self.map.weight.data.uniform_(-10, 10) + self.map.bias.data.uniform_(-10, 10) + + def forward(self, x): + mu_euler = self.map(x) + return rot_mat_tensor(mu_euler, 3) + + +class ThetaMean(nn.Module): + def __init__(self, input_dims): + super().__init__() + # Start with big outputs + self.map = nn.Linear(input_dims, 1) + self.map.weight.data.uniform_(-10, 10) + self.map.bias.data.uniform_(-10, 10) + + def forward(self, x): + mu_theta = self.map(x) + return mu_theta + + +class RLMean(nn.Module): + def __init__(self, input_dims, output_dims): + super().__init__() + # Start with big outputs + self.map = nn.Linear(input_dims, output_dims) + self.map.weight.data.uniform_(-10, 10) + self.map.bias.data.uniform_(-10, 10) + + def forward(self, x): + mu_rl = self.map(x) + return mu_rl + + +class VectorMean(nn.Module): + def __init__(self, input_dims): + super().__init__() + # Start with big outputs + self.map = nn.Linear(input_dims, 2) + self.map.weight.data.uniform_(-10, 10) + self.map.bias.data.uniform_(-10, 10) + + def forward(self, x): + vec = self.map(x) + norm_vec = torch.nn.functional.normalize(vec) + #mu_theta = torch.atan2(norm_vec[:, 1], norm_vec[:, 0]) + return norm_vec # mu_theta + + +class SO2reparameterize(nn.Module): + + def __init__(self, reparameterize, mean_module): + super().__init__() + + self.mean_module = mean_module + self.reparameterize = reparameterize + self.input_dim = self.reparameterize.input_dim + self.return_means = False + self.mu_lie, self.v, self.z = None, None, None + + def forward(self, x, n=1): + self.mu_lie = self.mean_module(x) + self.v = self.reparameterize(x, n) + self.z = self.nsample(n=n) + return self.z + + def nsample(self, n=1): + if self.return_means: + return self.mu_lie.expand(n, *[-1]*len(self.mu_lie.shape)) + return self.mu_lie + self.v + + def deterministic(self): + """Set to return means.""" + self.return_means = True + self.reparameterize.deterministic() + + def kl(self): + return -0.5 * torch.sum(1 + self.v - self.mu_lie.pow(2) - self.v.exp()) + + +class SO3reparameterize(nn.Module): + """Reparametrize SO(3) latent variable. + It uses an inner zero mean Gaussian reparametrization module, which it + exp-maps to a identity centered random SO(3) variable. The mean_module + deterministically outputs a mean. + """ + + def __init__(self, reparameterize, mean_module, k=10): + super().__init__() + + self.mean_module = mean_module + self.reparameterize = reparameterize + self.input_dim = self.reparameterize.input_dim + assert self.reparameterize.z_dim == 3 + self.k = k + self.return_means = False + + self.mu_lie, self.v, self.z = None, None, None + + def forward(self, x, n=1): + self.mu_lie = self.mean_module(x) + self.v = self.reparameterize(x, n) + self.z = self.nsample(n=n) + return self.z + + def kl(self): + log_q_z_x = self.log_posterior() + log_p_z = self.log_prior() + kl = log_q_z_x - log_p_z + return kl.mean(0) + + def log_posterior(self): + theta = self.v.norm(p=2, dim=-1, keepdim=True) # [n,B,1] + u = self.v / theta # [n,B,3] + + angles = 2 * math.pi * torch.arange( + -self.k, self.k+1, device=u.device, dtype=self.v.dtype) # [2k+1] + + theta_hat = theta[..., None, :] + angles[:, None] # [n,B,2k+1,1] + + clamp = 1e-3 + x = u[..., None, :] * theta_hat # [n,B,2k+1,3] + + # [n,(2k+1),B,3] or [n,(2k+1),B] + log_p = self.reparameterize._log_posterior( + x.permute([0, 2, 1, 3]).contiguous()) + + if len(log_p.size()) == 4: + log_p = log_p.sum(-1) # [n,(2k+1),B] + + log_p = log_p.permute([0, 2, 1]) # [n,B,(2k+1)] + + theta_hat_squared = torch.clamp(theta_hat ** 2, min=clamp) + + log_p.contiguous() + cos_theta_hat = torch.cos(theta_hat) + + # [n,B,(2k+1),1] + log_vol = torch.log(theta_hat_squared / + torch.clamp(2 - 2 * cos_theta_hat, min=clamp)) + log_p = log_p + log_vol.sum(-1) + log_p = logsumexp(log_p, -1) + + return log_p + + def log_prior(self): + prior = torch.tensor([- np.log(8 * (np.pi ** 2))], + device=self.z.device) + return prior.expand_as(self.z[..., 0, 0]) + + def nsample(self, n=1): + if self.return_means: + return self.mu_lie.expand(n, *[-1]*len(self.mu_lie.shape)) + v_lie = rodrigues(self.v) + return self.mu_lie @ v_lie + + def deterministic(self): + """Set to return means.""" + self.return_means = True + self.reparameterize.deterministic() diff --git a/models/train_utils.py b/models/train_utils.py new file mode 100644 index 0000000..55aa710 --- /dev/null +++ b/models/train_utils.py @@ -0,0 +1,440 @@ +"""Utils to factorize code for learning and visualization.""" + +import glob +import logging +import os +import torch +import torch.nn as tnn +from scipy.spatial.transform import Rotation as R + +import nn +import dataset as ds + +CUDA = torch.cuda.is_available() +DEVICE = torch.device("cuda" if CUDA else "cpu") + +CKPT_PERIOD = 1 + +W_INIT, B_INIT, NONLINEARITY_INIT = ( + {0: [[1.0], [0.0]], + 1: [[1.0, 0.0], [0.0, 1.0]]}, + {0: [0.0, 0.0], + 1: [0.01935, -0.02904]}, + 'softplus') + + +def init_xavier_normal(m): + """ + Initiate weigth of a Neural Network with xavier weigth. + + Parameters + ---------- + m : Neural Network. + + Returns + ------- + None. + + """ + if type(m) is tnn.Linear: + tnn.init.xavier_normal_(m.weight) + if type(m) is tnn.Conv2d: + tnn.init.xavier_normal_(m.weight) + + +def init_kaiming_normal(m): + """ + Initiate weigth of a Neural Network with kaiming weigth. + + Parameters + ---------- + m : Neural Network. + + Returns + ------- + None. + + """ + if type(m) is tnn.Linear: + tnn.init.kaiming_normal_(m.weight) + if type(m) is tnn.Conv2d: + tnn.init.kaiming_normal_(m.weight) + + +def init_custom(m): + """ + Initiate weigth of a Neural Network with own custom weigth. + + Parameters + ---------- + m : Neural Network. + + Returns + ------- + None. + + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def init_function(weights_init='xavier'): + """ + Choose the function to initialize the weight of NN. + + Parameters + ---------- + weights_init : string, optional, initiate weights. + + Raises + ------ + NotImplementedError + DESCRIPTION. + + Returns + ------- + function depending on initialization goal. + + """ + if weights_init == 'xavier': + return init_xavier_normal + if weights_init == 'kaiming': + return init_kaiming_normal + if weights_init == 'custom': + return init_custom + raise NotImplementedError( + "This weight initialization is not implemented.") + + +def init_modules_and_optimizers(train_params, config): + """ + Initialization of the different modules and optimizer of the NN. + + Parameters + ---------- + train_params : dic, meta parameters for the NN. + config : dic, meta parameters for the NN. + + Returns + ------- + modules : dic, dic of modules encoder and decoder and gan of the NN. + optimizers : dic, dic of optimizer of the NN. + + """ + modules = {} + optimizers = {} + lr = train_params['lr'] + beta1 = train_params['beta1'] + beta2 = train_params['beta2'] + vae = nn.VaeConv(config).to(DEVICE) + + modules['encoder'] = vae.encoder + modules['decoder'] = vae.decoder + + if 'adversarial' in train_params['reconstructions']: + discriminator = nn.Discriminator(config).to(DEVICE) + modules['discriminator_reconstruction'] = discriminator + + if 'adversarial' in train_params['regularizations']: + discriminator = nn.Discriminator(config).to(DEVICE) + modules['discriminator_regularization'] = discriminator + + # Optimizers + optimizers['encoder'] = torch.optim.Adam( + modules['encoder'].parameters(), lr=lr, betas=(beta1, beta2)) + optimizers['decoder'] = torch.optim.Adam( + modules['decoder'].parameters(), lr=lr, betas=(beta1, beta2)) + + if 'adversarial' in train_params['reconstructions']: + optimizers['discriminator_reconstruction'] = torch.optim.Adam( + modules['discriminator_reconstruction'].parameters(), + lr=train_params['lr'], + betas=(train_params['beta1'], train_params['beta2'])) + + if 'adversarial' in train_params['regularizations']: + optimizers['discriminator_regularization'] = torch.optim.Adam( + modules['discriminator_regularization'].parameters(), + lr=train_params['lr'], + betas=(train_params['beta1'], train_params['beta2'])) + + return modules, optimizers + + +def init_training(train_dir, train_params, config): + """ + Initialization; Load ckpts or init. + + Parameters + ---------- + train_dir : string, dir where to save the modules. + train_params : dic, meta parameters for the NN. + config : dic, meta parameters for the NN. + + Returns + ------- + modules : dic, dic of modules encoder and decoder and gan of the NN. + optimizers : dic, dic of optimizer of the NN. + start_epoch : int, the number of epoch the NN has already done. + train_losses_all_epochs : list, value of the train_loss for every epoch. + val_losses_all_epochs : list, value of the val_loss for every epoch. + """ + start_epoch = 0 + train_losses_all_epochs = [] + val_losses_all_epochs = [] + + modules, optimizers = init_modules_and_optimizers( + train_params, config) + + path_base = os.path.join(train_dir, 'epoch_*_checkpoint.pth') + ckpts = glob.glob(path_base) + if len(ckpts) == 0: + weights_init = train_params['weights_init'] + logging.info( + "No checkpoints found. Initializing with %s.", weights_init) + + for module_name, module in modules.items(): + module.apply(init_function(weights_init)) + + else: + ckpts_ids_and_paths = [ + (int(f.split('_')[-2]), f) for f in ckpts] + _, ckpt_path = max( + ckpts_ids_and_paths, key=lambda item: item[0]) + logging.info("Found checkpoints. Initializing with %s.", ckpt_path) + if torch.cuda.is_available(): + def map_location(storage): return storage.cuda() + else: + map_location = 'cpu' + ckpt = torch.load(ckpt_path, map_location=map_location) + # ckpt = torch.load(ckpt_path, map_location=DEVICE) + for module_name in modules: + module = modules[module_name] + optimizer = optimizers[module_name] + module_ckpt = ckpt[module_name] + module.load_state_dict(module_ckpt['module_state_dict']) + optimizer.load_state_dict( + module_ckpt['optimizer_state_dict']) + start_epoch = ckpt['epoch'] + 1 + train_losses_all_epochs = ckpt['train_losses'] + val_losses_all_epochs = ckpt['val_losses'] + + return (modules, optimizers, start_epoch, + train_losses_all_epochs, val_losses_all_epochs) + + +def save_checkpoint(epoch, modules, optimizers, dir_path, + train_losses_all_epochs, val_losses_all_epochs, + nn_architecture, train_params): + """ + Save NN's weights at a precise epoch. + + Parameters + ---------- + epoch : int, current epoch. + modules : dic, dic of modules encoder and decoder and gan of the NN. + optimizers : dic, dic of optimizer of the NN. + dir_path : string, dir where to save modules + train_losses_all_epochs : list, value of the train_loss for every epoch. + val_losses_all_epochs : list, value of the val_loss for every epoch. + nn_architecture : dic, meta parameters for the NN. + train_params : dic, meta parameters for the NN. + + Returns + ------- + None. + + """ + checkpoint = {} + for module_name in modules.keys(): + module = modules[module_name] + optimizer = optimizers[module_name] + checkpoint[module_name] = { + 'module_state_dict': module.state_dict(), + 'optimizer_state_dict': optimizer.state_dict()} + checkpoint['epoch'] = epoch + checkpoint['train_losses'] = train_losses_all_epochs + checkpoint['val_losses'] = val_losses_all_epochs + checkpoint['nn_architecture'] = nn_architecture + checkpoint['train_params'] = train_params + checkpoint_path = os.path.join( + dir_path, 'epoch_%d_checkpoint.pth' % epoch) + torch.save(checkpoint, checkpoint_path) + + +def load_checkpoint(output, epoch_id=None): + """ + Loads a NN and all information about it at one expecting stage + of the learning + + Parameters + ---------- + output : string, dir where a NN has been saved. + epoch_id : int, optional. The default is None. + + Raises + ------ + ValueError + DESCRIPTION. + + Returns + ------- + ckpt : NN ,just loaded NN network. + + """ + if epoch_id is None: + ckpts = glob.glob( + '%s/checkpoint_*/epoch_*_checkpoint.pth' % output) + if len(ckpts) == 0: + raise ValueError('No checkpoints found.') + ckpts_ids_and_paths = [(int(f.split('_')[-2]), f) for f in ckpts] + _, ckpt_path = max( + ckpts_ids_and_paths, key=lambda item: item[0]) + else: + # Load module corresponding to epoch_id + ckpt_path = f"{output}/checkpoint_{epoch_id:0>6d}/" + \ + "epoch_{epoch_id}_checkpoint.pth" + + print(ckpt_path) + if not os.path.isfile(ckpt_path): + raise ValueError( + 'No checkpoints found for epoch %d in output %s.' % ( + epoch_id, output)) + + print('Found checkpoint. Getting: %s.' % ckpt_path) + ckpt = torch.load(ckpt_path, map_location=DEVICE) + return ckpt + + +def load_module_state(output, module, module_name, epoch_id=None): + """ + Affects weights of the considered epoch_id to NN's weights. + + Parameters + ---------- + output : string, dir where to find the NN. + module : NN, NN with initialized weight. + module_name : string, name of the considered module + epoch_id : int, optional. Epoch we are interested in. The default is None. + + Returns + ------- + module : NN, NN with the weight of the NN after the epoch_id. + + """ + ckpt = load_checkpoint( + output=output, epoch_id=epoch_id) + + module_ckpt = ckpt[module_name] + module.load_state_dict(module_ckpt['module_state_dict']) + + return module + + +def get_under_dic_cons(const, list_arg): + """ + Take a sub dictionnary of a dictionnary. + + Parameters + ---------- + const : dic. + list_arg : list, liste of keys you want to save. + + Returns + ------- + new_dic : dic, sub dic of const. + + """ + new_dic = {} + for key in list_arg: + if key in const: + new_dic[key] = const[key] + return new_dic + + +def quaternion_to_euler(labels): + """ + Transform the quaternion representation of rotation in zyx euler + representation. + + Parameters + ---------- + labels : dataframe, description of the orientation of each image. + + Returns + ------- + liste : list, liste of triples describibg the rotation with the zyx + euler angles. + + """ + n = len(labels) + liste = [] + for i in range(n): + A = labels['rotation_quaternion'].iloc[i].replace(' ]', ']') + A = A.replace(' ', ' ') + A = A.replace(' ', ' ') + A = A.replace(' ', ' ') + A = A.replace(' ]', ']') + A = A.replace(' ', ' ') + A = A[1:-1].split(' ') + B = list(map(float, A)) + r = R.from_quat(B) + liste.append(r.as_euler('zyx', degrees=True)) + return liste + + +def load_module(output, module_name='encoder', epoch_id=None): + """ + Affects weights of the considered epoch_id to NN's weights. + + Parameters + ---------- + output : string, dir where to find the NN. + module : NN, NN with initialized weight. + module_name : string, name of the considered module + epoch_id : int, optional. Epoch we are interested in. The default is None. + + Returns + ------- + module : NN, NN with the weight of the NN after the epoch_id. + + """ + ckpt = load_checkpoint( + output=output, epoch_id=epoch_id) + nn_architecture = ckpt['nn_architecture'] + nn_architecture['conv_dim'] = ds.CONSTANTS['conv_dim'] + nn_architecture.update(get_under_dic_cons( + ds.CONSTANTS, ds.META_PARAM_NAMES)) + nn_type = nn_architecture['nn_type'] + print('Loading %s from network of architecture: %s...' % ( + module_name, nn_type)) + vae = nn.VaeConv(nn_architecture) + modules = {} + modules['encoder'] = vae.encoder + modules['decoder'] = vae.decoder + module = modules[module_name].to(DEVICE) + module_ckpt = ckpt[module_name] + module.load_state_dict(module_ckpt['module_state_dict']) + + return module + + +def get_logging_shape(tensor): + """ + Convert shape of a tensor into a string. + + Parameters + ---------- + tensor : tensor. + + Returns + ------- + logging_shape : string, shape of the tensor. + + """ + shape = tensor.shape + logging_shape = '(' + ('%s, ' * len(shape) % tuple(shape))[:-2] + ')' + return logging_shape diff --git a/models/vae_parameters.json b/models/vae_parameters.json new file mode 100644 index 0000000..be1269e --- /dev/null +++ b/models/vae_parameters.json @@ -0,0 +1,77 @@ +{ + "paths": + { + "simulated_2d": "cryo_sim_128x128.npy", + "real_2d30": "class2D_30_sort.h5", + "real_2d39": "class2D_39_sort.h5", + "real_2d93": "class2D_93_sort.h5", + "refine_3D": "refine3D_180x180_sort.h5", + "class_3d9090": "class3D_90x90_sort.h5", + "simulated_3d": "concat_simulated.npy", + "simulated_3d_noise": "cryo_sim_128x128.npy", + "real_3d": "data.hdf5", + "4points": "4points.npy", + "4points1": "4points1.npy", + "4points_3d": "3d_images.npy" + }, + + "shape": + { + "simulated_3d": {"items": [1, 320, 320],"__tuple__":true}, + "simulated_2d": {"items": [1, 128, 128],"__tuple__":true}, + "4points": {"items": [1, 128, 128],"__tuple__":true}, + "4points1": {"items": [1, 64, 64],"__tuple__":true}, + "4points_3d": {"items": [64, 64, 64],"__tuple__":true} + }, + + "constants": + { + "img_shape": {"items": [1, 128, 128],"__tuple__":true}, + "with_sigmoid": "true", + "out_channels": "32, 64", + "is_3d": false, + "enc_ks": 4, + "enc_str": 2, + "enc_pad": 1, + "enc_dil": 1, + "enc_c": 1, + "dec_ks": 3, + "dec_str": 1, + "dec_pad": 0, + "dec_dil": 1, + "dec_c": 1, + "dis_ks": 4, + "dis_str": 2, + "dis_pad":1, + "dis_dil": 1, + "dis_c": 1, + "regularizations": ["kullbackleibler"], + "class_2d": 39, + "weights_init": "xavier", + "nn_type": "conv", + "beta1": 0.9, + "beta2": 0.999, + "frac_val": 0.2, + "bce": true, + "reconstructions": {"items": ["bce_on_intensities", "adversarial"], "__tuple__":true}, + "skip_z": false + }, + + "search_space": + { + "n_enc_lay": 2, + "n_dec_lay":2, + "latent_dim": 3, + "batch_size": 20, + "adversarial": false, + "n_gan_lay": 3, + "lr": 0.001, + "regu_factor": 0.003, + "lambda_regu": 0.2, + "lambda_adv": 0.2, + "reconstructions": {"items":["bce_on_intensities", "adversarial"],"__tuple__":true} + }, + + "meta_param_names":["enc_ks","enc_str", "enc_pad", "enc_dil", "enc_c", "dec_ks", "dec_str","dec_pad", "dec_dil", "dec_c", "dis_ks", "dis_str", "dis_pad","dis_dil", "dis_c"] + +}