Skip to content

Commit

Permalink
refactor SampleLayer into its own module
Browse files Browse the repository at this point in the history
  • Loading branch information
alecGraves committed May 17, 2019
1 parent d78d8e3 commit 810506b
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 161 deletions.
15 changes: 8 additions & 7 deletions bvae/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,20 @@ def test():
batchSize = 8
latentSize = 100

img = load_img(os.path.join('..','images', 'img.jpg'), target_size=inputShape[:-1])
img = load_img(os.path.join(os.path.dirname(__file__), '..','images', 'img.jpg'), target_size=inputShape[:-1])
img.show()

img = np.array(img, dtype=np.float32) / 255 - 0.5
img = np.array(img, dtype=np.float32) * (2/255) - 1
# print(np.min(img))
# print(np.max(img))
# print(np.mean(img))

img = np.array([img]*batchSize) # make fake batches to improve GPU utilization

# This is how you build the autoencoder
encoder = Darknet19Encoder(inputShape, latentSize=latentSize, latentConstraints='bvae', beta=69, capacity=15, randomSample=True)
encoder = Darknet19Encoder(inputShape, latentSize=latentSize, latentConstraints='bvae', beta=69)
decoder = Darknet19Decoder(inputShape, latentSize=latentSize)
bvae = AutoEncoder(encoder, decoder)

bvae.ae.compile(optimizer='adam', loss='mean_absolute_error')
while True:
bvae.ae.fit(img, img,
Expand All @@ -53,9 +56,7 @@ def test():
print(latentVec)

pred = bvae.ae.predict(img) # get the reconstructed image
pred[pred > 0.5] = 0.5 # clean it up a bit
pred[pred < -0.5] = -0.5
pred = np.uint8((pred + 0.5)* 255) # convert to regular image values
pred = np.uint8((pred + 1)* 255/2) # convert to regular image values

pred = Image.fromarray(pred[0])
pred.show() # display popup
Expand Down
101 changes: 3 additions & 98 deletions bvae/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''
model_utils.py
contains custom layers, etc. for building mdoels.
contains custom blocks, etc. for building mdoels.
created by shadySource
Expand All @@ -10,111 +10,16 @@
from tensorflow.python.keras.layers import (InputLayer, Conv2D, Conv2DTranspose,
BatchNormalization, LeakyReLU, MaxPool2D, UpSampling2D,
Reshape, GlobalAveragePooling2D, Layer)
from tensorflow.python.keras import backend as K

class ConvBnLRelu(object):
def __init__(self, filters, kernelSize, strides=1):
self.filters = filters
self.kernelSize = kernelSize
self.strides = strides
# return conv + bn + leaky_relu model
def __call__(self, net):
def __call__(self, net, training=None):
net = Conv2D(self.filters, self.kernelSize, strides=self.strides, padding='same')(net)
net = BatchNormalization()(net)
net = BatchNormalization()(net, training=training)
net = LeakyReLU()(net)
return net

class SampleLayer(Layer):
'''
Keras Layer to grab a random sample from a distribution (by multiplication)
Computes "(normal)*stddev + mean" for the vae sampling operation
(written for tf backend)
Additionally,
Applies regularization to the latent space representation.
Can perform standard regularization or B-VAE regularization.
call:
pass in mean then stddev layers to sample from the distribution
ex.
sample = SampleLayer('bvae', 16)([mean, stddev])
'''
def __init__(self, latent_regularizer='bvae', beta=100., capacity=0., randomSample=True, **kwargs):
'''
args:
------
latent_regularizer : str
Either 'bvae', 'vae', or 'no'
Determines whether regularization is applied
to the latent space representation.
beta : float
beta > 1, used for 'bvae' latent_regularizer,
(Unused if 'bvae' not selected)
capacity : float
used for 'bvae' to try to break input down to a set number
of basis. (e.g. at 25, the network will try to use
25 dimensions of the latent space)
(unused if 'bvae' not selected)
randomSample : bool
whether or not to use random sampling when selecting from distribution.
if false, the latent vector equals the mean, essentially turning this into a
standard autoencoder.
------
ex.
sample = SampleLayer('bvae', 16)([mean, stddev])
'''
self.reg = latent_regularizer
self.beta = beta
self.capacity = capacity
self.random = randomSample
if K.image_data_format() == "channels_last":
self.sum_axis = -1
else:
self.sum_axis = 0

super(SampleLayer, self).__init__(**kwargs)

def build(self, input_shape):
# save the shape for distribution sampling

super(SampleLayer, self).build(input_shape) # needed for layers

def call(self, x):
if len(x) != 2:
raise Exception('input layers must be a list: mean and stddev')
if len(x[0].shape) != 2 or len(x[1].shape) != 2:
raise Exception('input shape is not a vector [batchSize, latentSize]')

mean = x[0]
stddev = K.abs(x[1])

# trick to allow setting batch at train/eval time
if mean.shape[0].value == None or stddev.shape[0].value == None:
return mean + 0*stddev

if self.reg == 'bvae':
# kl divergence:
latent_loss = -0.5 * K.mean(K.sum(1 + stddev
- K.square(mean)
- K.exp(stddev), axis=self.sum_axis))
# use beta to force less usage of vector space:
# also try to use <capacity> dimensions of the space:
latent_loss = self.beta * K.abs(latent_loss)
self.add_loss(latent_loss, x)
elif self.reg == 'vae':
# kl divergence:
latent_loss = -0.5 * K.mean(K.sum(1 + stddev
- K.square(mean)
- K.exp(stddev), axis=self.sum_axis))
self.add_loss(latent_loss, x)

epsilon = K.random_normal(shape=stddev.shape,
mean=0., stddev=1.)
if self.random:
# 'reparameterization trick':
return mean + K.exp(stddev/2) * epsilon
else: # do not perform random sampling, simply grab the impulse value
return mean + 0*stddev # Keras needs the *0 so the gradinent is not None

def compute_output_shape(self, input_shape):
return input_shape[0]
105 changes: 49 additions & 56 deletions bvae/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
Reshape, GlobalAveragePooling2D)
from tensorflow.python.keras.models import Model

from model_utils import ConvBnLRelu, SampleLayer
from model_utils import ConvBnLRelu
from sample_layer import SampleLayer

class Architecture(object):
'''
Expand All @@ -30,6 +31,8 @@ def __init__(self, inputShape=None, batchSize=None, latentSize=None):
latentSize : int
the number of dimensions in the two output distribution vectors -
mean and std-deviation
latentSize : Bool or None
True forces resampling, False forces no resampling, None chooses based on K.learning_phase()
'''
self.inputShape = inputShape
self.batchSize = batchSize
Expand All @@ -51,8 +54,7 @@ class Darknet19Encoder(Architecture):
https://github.com/pjreddie/darknet/blob/master/cfg/darknet19.cfg
'''
def __init__(self, inputShape=(256, 256, 3), batchSize=None,
latentSize=1000, latentConstraints='bvae', beta=100., capacity=0.,
randomSample=True):
latentSize=1000, latentConstraints='bvae', beta=100., training=None):
'''
params
-------
Expand All @@ -63,53 +65,43 @@ def __init__(self, inputShape=(256, 256, 3), batchSize=None,
beta : float
beta > 1, used for 'bvae' latent_regularizer
(Unused if 'bvae' not selected, default 100)
capacity : float
used for 'bvae' to try to break input down to a set number
of basis. (e.g. at 25, the network will try to use
25 dimensions of the latent space)
(unused if 'bvae' not selected)
randomSample : bool
whether or not to use random sampling when selecting from distribution.
if false, the latent vector equals the mean, essentially turning this into a
standard autoencoder.
'''
self.latentConstraints = latentConstraints
self.beta = beta
self.latentCapacity = capacity
self.randomSample = randomSample
self.training=training
super().__init__(inputShape, batchSize, latentSize)

def Build(self):
# create the input layer for feeding the netowrk
inLayer = Input(self.inputShape, self.batchSize)
net = ConvBnLRelu(32, kernelSize=3)(inLayer) # 1
net = ConvBnLRelu(32, kernelSize=3)(inLayer, training=self.training) # 1
net = MaxPool2D((2, 2), strides=(2, 2))(net)

net = ConvBnLRelu(64, kernelSize=3)(net) # 2
net = ConvBnLRelu(64, kernelSize=3)(net, training=self.training) # 2
net = MaxPool2D((2, 2), strides=(2, 2))(net)

net = ConvBnLRelu(128, kernelSize=3)(net) # 3
net = ConvBnLRelu(64, kernelSize=1)(net) # 4
net = ConvBnLRelu(128, kernelSize=3)(net) # 5
net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) # 3
net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training) # 4
net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training) # 5
net = MaxPool2D((2, 2), strides=(2, 2))(net)

net = ConvBnLRelu(256, kernelSize=3)(net) # 6
net = ConvBnLRelu(128, kernelSize=1)(net) # 7
net = ConvBnLRelu(256, kernelSize=3)(net) # 8
net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) # 6
net = ConvBnLRelu(128, kernelSize=1)(net, training=self.training) # 7
net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training) # 8
net = MaxPool2D((2, 2), strides=(2, 2))(net)

net = ConvBnLRelu(512, kernelSize=3)(net) # 9
net = ConvBnLRelu(256, kernelSize=1)(net) # 10
net = ConvBnLRelu(512, kernelSize=3)(net) # 11
net = ConvBnLRelu(256, kernelSize=1)(net) # 12
net = ConvBnLRelu(512, kernelSize=3)(net) # 13
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) # 9
net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) # 10
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) # 11
net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training) # 12
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training) # 13
net = MaxPool2D((2, 2), strides=(2, 2))(net)

net = ConvBnLRelu(1024, kernelSize=3)(net) # 14
net = ConvBnLRelu(512, kernelSize=1)(net) # 15
net = ConvBnLRelu(1024, kernelSize=3)(net) # 16
net = ConvBnLRelu(512, kernelSize=1)(net) # 17
net = ConvBnLRelu(1024, kernelSize=3)(net) # 18
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 14
net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) # 15
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 16
net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training) # 17
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training) # 18

# variational encoder output (distributions)
mean = Conv2D(filters=self.latentSize, kernel_size=(1, 1),
Expand All @@ -119,13 +111,13 @@ def Build(self):
padding='same')(net)
stddev = GlobalAveragePooling2D()(stddev)

sample = SampleLayer(self.latentConstraints, self.beta,
self.latentCapacity, self.randomSample)([mean, stddev])
sample = SampleLayer(self.latentConstraints, self.beta)([mean, stddev], training=self.training)

return Model(inputs=inLayer, outputs=sample)

class Darknet19Decoder(Architecture):
def __init__(self, inputShape=(256, 256, 3), batchSize=None, latentSize=1000):
def __init__(self, inputShape=(256, 256, 3), batchSize=None, latentSize=1000, training=None):
self.training=training
super().__init__(inputShape, batchSize, latentSize)

def Build(self):
Expand All @@ -138,38 +130,39 @@ def Build(self):

# TODO try inverting num filter arangement (e.g. 512, 1204, 512, 1024, 512)
# and also try (1, 3, 1, 3, 1) for the filter shape
net = ConvBnLRelu(1024, kernelSize=3)(net)
net = ConvBnLRelu(512, kernelSize=1)(net)
net = ConvBnLRelu(1024, kernelSize=3)(net)
net = ConvBnLRelu(512, kernelSize=1)(net)
net = ConvBnLRelu(1024, kernelSize=3)(net)
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(512, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(1024, kernelSize=3)(net, training=self.training)

net = UpSampling2D((2, 2))(net)
net = ConvBnLRelu(512, kernelSize=3)(net)
net = ConvBnLRelu(256, kernelSize=1)(net)
net = ConvBnLRelu(512, kernelSize=3)(net)
net = ConvBnLRelu(256, kernelSize=1)(net)
net = ConvBnLRelu(512, kernelSize=3)(net)
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(256, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(512, kernelSize=3)(net, training=self.training)

net = UpSampling2D((2, 2))(net)
net = ConvBnLRelu(256, kernelSize=3)(net)
net = ConvBnLRelu(128, kernelSize=1)(net)
net = ConvBnLRelu(256, kernelSize=3)(net)
net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(128, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(256, kernelSize=3)(net, training=self.training)

net = UpSampling2D((2, 2))(net)
net = ConvBnLRelu(128, kernelSize=3)(net)
net = ConvBnLRelu(64, kernelSize=1)(net)
net = ConvBnLRelu(128, kernelSize=3)(net)
net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training)
net = ConvBnLRelu(128, kernelSize=3)(net, training=self.training)

net = UpSampling2D((2, 2))(net)
net = ConvBnLRelu(64, kernelSize=3)(net)
net = ConvBnLRelu(64, kernelSize=3)(net, training=self.training)

net = UpSampling2D((2, 2))(net)
net = ConvBnLRelu(64, kernelSize=1)(net)
net = ConvBnLRelu(32, kernelSize=3)(net)
# net = ConvBnLRelu(3, kernelSize=1)(net)
net = ConvBnLRelu(32, kernelSize=3)(net, training=self.training)
net = ConvBnLRelu(64, kernelSize=1)(net, training=self.training)

# net = ConvBnLRelu(3, kernelSize=1)(net, training=self.training)
net = Conv2D(filters=self.inputShape[-1], kernel_size=(1, 1),
padding='same')(net)
padding='same', activation="tanh")(net)

return Model(inLayer, net)

Expand Down
Loading

0 comments on commit 810506b

Please sign in to comment.