From 810506b1a142da49b3cc7eddcc4bb32856d5e51c Mon Sep 17 00:00:00 2001 From: Alec Graves Date: Fri, 17 May 2019 14:10:38 -0400 Subject: [PATCH] refactor SampleLayer into its own module --- bvae/ae.py | 15 ++++--- bvae/model_utils.py | 101 ++--------------------------------------- bvae/models.py | 105 ++++++++++++++++++++----------------------- bvae/sample_layer.py | 94 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 154 insertions(+), 161 deletions(-) create mode 100644 bvae/sample_layer.py diff --git a/bvae/ae.py b/bvae/ae.py index 7f04821..6f8c388 100644 --- a/bvae/ae.py +++ b/bvae/ae.py @@ -31,17 +31,20 @@ def test(): batchSize = 8 latentSize = 100 - img = load_img(os.path.join('..','images', 'img.jpg'), target_size=inputShape[:-1]) + img = load_img(os.path.join(os.path.dirname(__file__), '..','images', 'img.jpg'), target_size=inputShape[:-1]) img.show() - img = np.array(img, dtype=np.float32) / 255 - 0.5 + img = np.array(img, dtype=np.float32) * (2/255) - 1 +# print(np.min(img)) +# print(np.max(img)) +# print(np.mean(img)) + img = np.array([img]*batchSize) # make fake batches to improve GPU utilization # This is how you build the autoencoder - encoder = Darknet19Encoder(inputShape, latentSize=latentSize, latentConstraints='bvae', beta=69, capacity=15, randomSample=True) + encoder = Darknet19Encoder(inputShape, latentSize=latentSize, latentConstraints='bvae', beta=69) decoder = Darknet19Decoder(inputShape, latentSize=latentSize) bvae = AutoEncoder(encoder, decoder) - bvae.ae.compile(optimizer='adam', loss='mean_absolute_error') while True: bvae.ae.fit(img, img, @@ -53,9 +56,7 @@ def test(): print(latentVec) pred = bvae.ae.predict(img) # get the reconstructed image - pred[pred > 0.5] = 0.5 # clean it up a bit - pred[pred < -0.5] = -0.5 - pred = np.uint8((pred + 0.5)* 255) # convert to regular image values + pred = np.uint8((pred + 1)* 255/2) # convert to regular image values pred = Image.fromarray(pred[0]) pred.show() # display popup diff --git 
a/bvae/model_utils.py b/bvae/model_utils.py index bec7fa4..c33a230 100644 --- a/bvae/model_utils.py +++ b/bvae/model_utils.py @@ -1,6 +1,6 @@ ''' model_utils.py -contains custom layers, etc. for building mdoels. +contains custom blocks, etc. for building mdoels. created by shadySource @@ -10,7 +10,6 @@ from tensorflow.python.keras.layers import (InputLayer, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, MaxPool2D, UpSampling2D, Reshape, GlobalAveragePooling2D, Layer) -from tensorflow.python.keras import backend as K class ConvBnLRelu(object): def __init__(self, filters, kernelSize, strides=1): @@ -18,103 +17,9 @@ def __init__(self, filters, kernelSize, strides=1): self.kernelSize = kernelSize self.strides = strides # return conv + bn + leaky_relu model - def __call__(self, net): + def __call__(self, net, training=None): net = Conv2D(self.filters, self.kernelSize, strides=self.strides, padding='same')(net) - net = BatchNormalization()(net) + net = BatchNormalization()(net, training=training) net = LeakyReLU()(net) return net -class SampleLayer(Layer): - ''' - Keras Layer to grab a random sample from a distribution (by multiplication) - Computes "(normal)*stddev + mean" for the vae sampling operation - (written for tf backend) - - Additionally, - Applies regularization to the latent space representation. - Can perform standard regularization or B-VAE regularization. - - call: - pass in mean then stddev layers to sample from the distribution - ex. - sample = SampleLayer('bvae', 16)([mean, stddev]) - ''' - def __init__(self, latent_regularizer='bvae', beta=100., capacity=0., randomSample=True, **kwargs): - ''' - args: - ------ - latent_regularizer : str - Either 'bvae', 'vae', or 'no' - Determines whether regularization is applied - to the latent space representation. - beta : float - beta > 1, used for 'bvae' latent_regularizer, - (Unused if 'bvae' not selected) - capacity : float - used for 'bvae' to try to break input down to a set number - of basis. (e.g. 
at 25, the network will try to use - 25 dimensions of the latent space) - (unused if 'bvae' not selected) - randomSample : bool - whether or not to use random sampling when selecting from distribution. - if false, the latent vector equals the mean, essentially turning this into a - standard autoencoder. - ------ - ex. - sample = SampleLayer('bvae', 16)([mean, stddev]) - ''' - self.reg = latent_regularizer - self.beta = beta - self.capacity = capacity - self.random = randomSample - if K.image_data_format() == "channels_last": - self.sum_axis = -1 - else: - self.sum_axis = 0 - - super(SampleLayer, self).__init__(**kwargs) - - def build(self, input_shape): - # save the shape for distribution sampling - - super(SampleLayer, self).build(input_shape) # needed for layers - - def call(self, x): - if len(x) != 2: - raise Exception('input layers must be a list: mean and stddev') - if len(x[0].shape) != 2 or len(x[1].shape) != 2: - raise Exception('input shape is not a vector [batchSize, latentSize]') - - mean = x[0] - stddev = K.abs(x[1]) - - # trick to allow setting batch at train/eval time - if mean.shape[0].value == None or stddev.shape[0].value == None: - return mean + 0*stddev - - if self.reg == 'bvae': - # kl divergence: - latent_loss = -0.5 * K.mean(K.sum(1 + stddev - - K.square(mean) - - K.exp(stddev), axis=self.sum_axis)) - # use beta to force less usage of vector space: - # also try to use dimensions of the space: - latent_loss = self.beta * K.abs(latent_loss) - self.add_loss(latent_loss, x) - elif self.reg == 'vae': - # kl divergence: - latent_loss = -0.5 * K.mean(K.sum(1 + stddev - - K.square(mean) - - K.exp(stddev), axis=self.sum_axis)) - self.add_loss(latent_loss, x) - - epsilon = K.random_normal(shape=stddev.shape, - mean=0., stddev=1.) 
- if self.random: - # 'reparameterization trick': - return mean + K.exp(stddev/2) * epsilon - else: # do not perform random sampling, simply grab the impulse value - return mean + 0*stddev # Keras needs the *0 so the gradinent is not None - - def compute_output_shape(self, input_shape): - return input_shape[0] \ No newline at end of file diff --git a/bvae/models.py b/bvae/models.py index 2d0addc..7689c6b 100644 --- a/bvae/models.py +++ b/bvae/models.py @@ -13,7 +13,8 @@ Reshape, GlobalAveragePooling2D) from tensorflow.python.keras.models import Model -from model_utils import ConvBnLRelu, SampleLayer +from model_utils import ConvBnLRelu +from sample_layer import SampleLayer class Architecture(object): ''' @@ -30,6 +31,8 @@ def __init__(self, inputShape=None, batchSize=None, latentSize=None): latentSize : int the number of dimensions in the two output distribution vectors - mean and std-deviation + latentSize : Bool or None + True forces resampling, False forces no resampling, None chooses based on K.learning_phase() ''' self.inputShape = inputShape self.batchSize = batchSize @@ -51,8 +54,7 @@ class Darknet19Encoder(Architecture): https://github.com/pjreddie/darknet/blob/master/cfg/darknet19.cfg ''' def __init__(self, inputShape=(256, 256, 3), batchSize=None, - latentSize=1000, latentConstraints='bvae', beta=100., capacity=0., - randomSample=True): + latentSize=1000, latentConstraints='bvae', beta=100., training=None): ''' params ------- @@ -63,53 +65,43 @@ def __init__(self, inputShape=(256, 256, 3), batchSize=None, beta : float beta > 1, used for 'bvae' latent_regularizer (Unused if 'bvae' not selected, default 100) - capacity : float - used for 'bvae' to try to break input down to a set number - of basis. (e.g. at 25, the network will try to use - 25 dimensions of the latent space) - (unused if 'bvae' not selected) - randomSample : bool - whether or not to use random sampling when selecting from distribution. 
- if false, the latent vector equals the mean, essentially turning this into a - standard autoencoder. ''' self.latentConstraints = latentConstraints self.beta = beta - self.latentCapacity = capacity - self.randomSample = randomSample + self.training=training super().__init__(inputShape, batchSize, latentSize) def Build(self): # create the input layer for feeding the netowrk inLayer = Input(self.inputShape, self.batchSize) - net = ConvBnLRelu(32, kernelSize=3)(inLayer) # 1 + net = ConvBnLRelu(32, kernelSize=3)(inLayer, training=self.training) # 1 net = MaxPool2D((2, 2), strides=(2, 2))(net) - net = ConvBnLRelu(64, kernelSize=3)(net) # 2 + net = ConvBnLRelu(64, kernelSize=3)(net, training=self.training) # 2 net = MaxPool2D((2, 2), strides=(2, 2))(net) - net = ConvBnLRelu(128, kernelSize=3)(net) # 3 - net = ConvBnLRelu(64, kernelSize=1)(net) # 4 - net = ConvBnLRelu(128, kernelSize=3)(net) # 5 + net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) # 3 + net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training) # 4 + net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) # 5 net = MaxPool2D((2, 2), strides=(2, 2))(net) - net = ConvBnLRelu(256, kernelSize=3)(net) # 6 - net = ConvBnLRelu(128, kernelSize=1)(net) # 7 - net = ConvBnLRelu(256, kernelSize=3)(net) # 8 + net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) # 6 + net = ConvBnLRelu(128, kernelSize=1)(net, training=self.training) # 7 + net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) # 8 net = MaxPool2D((2, 2), strides=(2, 2))(net) - net = ConvBnLRelu(512, kernelSize=3)(net) # 9 - net = ConvBnLRelu(256, kernelSize=1)(net) # 10 - net = ConvBnLRelu(512, kernelSize=3)(net) # 11 - net = ConvBnLRelu(256, kernelSize=1)(net) # 12 - net = ConvBnLRelu(512, kernelSize=3)(net) # 13 + net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) # 9 + net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) # 10 + net = ConvBnLRelu(512, 
kernelSize=3)(net, training=self.training) # 11 + net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) # 12 + net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) # 13 net = MaxPool2D((2, 2), strides=(2, 2))(net) - net = ConvBnLRelu(1024, kernelSize=3)(net) # 14 - net = ConvBnLRelu(512, kernelSize=1)(net) # 15 - net = ConvBnLRelu(1024, kernelSize=3)(net) # 16 - net = ConvBnLRelu(512, kernelSize=1)(net) # 17 - net = ConvBnLRelu(1024, kernelSize=3)(net) # 18 + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 14 + net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) # 15 + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 16 + net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) # 17 + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 18 # variational encoder output (distributions) mean = Conv2D(filters=self.latentSize, kernel_size=(1, 1), @@ -119,13 +111,13 @@ def Build(self): padding='same')(net) stddev = GlobalAveragePooling2D()(stddev) - sample = SampleLayer(self.latentConstraints, self.beta, - self.latentCapacity, self.randomSample)([mean, stddev]) + sample = SampleLayer(self.latentConstraints, self.beta)([mean, stddev], training=self.training) return Model(inputs=inLayer, outputs=sample) class Darknet19Decoder(Architecture): - def __init__(self, inputShape=(256, 256, 3), batchSize=None, latentSize=1000): + def __init__(self, inputShape=(256, 256, 3), batchSize=None, latentSize=1000, training=None): + self.training=training super().__init__(inputShape, batchSize, latentSize) def Build(self): @@ -138,38 +130,39 @@ def Build(self): # TODO try inverting num filter arangement (e.g. 
512, 1204, 512, 1024, 512) # and also try (1, 3, 1, 3, 1) for the filter shape - net = ConvBnLRelu(1024, kernelSize=3)(net) - net = ConvBnLRelu(512, kernelSize=1)(net) - net = ConvBnLRelu(1024, kernelSize=3)(net) - net = ConvBnLRelu(512, kernelSize=1)(net) - net = ConvBnLRelu(1024, kernelSize=3)(net) + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) net = UpSampling2D((2, 2))(net) - net = ConvBnLRelu(512, kernelSize=3)(net) - net = ConvBnLRelu(256, kernelSize=1)(net) - net = ConvBnLRelu(512, kernelSize=3)(net) - net = ConvBnLRelu(256, kernelSize=1)(net) - net = ConvBnLRelu(512, kernelSize=3)(net) + net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) net = UpSampling2D((2, 2))(net) - net = ConvBnLRelu(256, kernelSize=3)(net) - net = ConvBnLRelu(128, kernelSize=1)(net) - net = ConvBnLRelu(256, kernelSize=3)(net) + net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(128, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) net = UpSampling2D((2, 2))(net) - net = ConvBnLRelu(128, kernelSize=3)(net) - net = ConvBnLRelu(64, kernelSize=1)(net) - net = ConvBnLRelu(128, kernelSize=3)(net) + net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training) + net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) net = UpSampling2D((2, 
2))(net) - net = ConvBnLRelu(64, kernelSize=3)(net) + net = ConvBnLRelu(64, kernelSize=3)(net, training=self.training) net = UpSampling2D((2, 2))(net) - net = ConvBnLRelu(64, kernelSize=1)(net) - net = ConvBnLRelu(32, kernelSize=3)(net) - # net = ConvBnLRelu(3, kernelSize=1)(net) + net = ConvBnLRelu(32, kernelSize=3)(net, training=self.training) + net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training) + + # net = ConvBnLRelu(3, kernelSize=1)(net, training=self.training) net = Conv2D(filters=self.inputShape[-1], kernel_size=(1, 1), - padding='same')(net) + padding='same', activation="tanh")(net) return Model(inLayer, net) diff --git a/bvae/sample_layer.py b/bvae/sample_layer.py new file mode 100644 index 0000000..42eb6c3 --- /dev/null +++ b/bvae/sample_layer.py @@ -0,0 +1,94 @@ +''' +sample_layer.py +contains keras SampleLayer for bvae + +created by shadySource + +THE UNLICENSE +''' + +from tensorflow.python.keras.layers import Layer +from tensorflow.python.keras import backend as K + + +class SampleLayer(Layer): + ''' + Keras Layer to grab a random sample from a distribution (by multiplication) + Computes "(normal)*logvar + mean" for the vae sampling operation + (written for tf backend) + + Additionally, + Applies regularization to the latent space representation. + Can perform standard regularization or B-VAE regularization. + + call: + pass in mean then logvar layers to sample from the distribution + ex. + sample = SampleLayer('bvae', 16)([mean, logvar]) + ''' + def __init__(self, latent_regularizer='bvae', beta=100., **kwargs): + ''' + args: + ------ + latent_regularizer : str + Either 'bvae', 'vae', or 'no' + Determines whether regularization is applied + to the latent space representation. + beta : float + beta > 1, used for 'bvae' latent_regularizer, + (Unused if 'bvae' not selected) + ------ + ex. 
+ sample = SampleLayer('bvae', 16)([mean, logvar]) + ''' + if latent_regularizer.lower() in ['bvae', 'vae']: + self.reg = latent_regularizer + else: + self.reg = None + + if self.reg == 'bvae': + self.beta = beta + elif self.reg == 'vae': + self.beta = 1. + + super(SampleLayer, self).__init__(**kwargs) + + def build(self, input_shape): + # save the shape for distribution sampling + super(SampleLayer, self).build(input_shape) # needed for layers + + def call(self, x, training=None): + if len(x) != 2: + raise Exception('input layers must be a list: mean and logvar') + if len(x[0].shape) != 2 or len(x[1].shape) != 2: + raise Exception('input shape is not a vector [batchSize, latentSize]') + + mean = x[0] + logvar = x[1] + + # trick to allow setting batch at train/eval time + if mean.shape[0].value == None or logvar.shape[0].value == None: + return mean + 0*logvar # Keras needs the *0 so the gradient is not None + + if self.reg is not None: + # kl divergence: + latent_loss = -0.5 * (1 + logvar + - K.square(mean) + - K.exp(logvar)) + latent_loss = K.sum(latent_loss, axis=-1) # sum over latent dimension + latent_loss = K.mean(latent_loss, axis=0) # avg over batch + + # use beta to force less usage of vector space: + latent_loss = self.beta * latent_loss + self.add_loss(latent_loss, x) + + def reparameterization_trick(): + epsilon = K.random_normal(shape=logvar.shape, + mean=0., logvar=1.) + stddev = K.exp(logvar*0.5) + return mean + stddev * epsilon * inf + + return K.in_train_phase(reparameterization_trick, mean + 0*logvar, training=training) + + def compute_output_shape(self, input_shape): + return input_shape[0] \ No newline at end of file