Commit
refactor for convenience, make baseline_model10_10 a class
Anastasia Gracheva committed Feb 19, 2020
1 parent 997de85 commit 5ebaa5b
Showing 7 changed files with 168 additions and 160 deletions.
8 changes: 6 additions & 2 deletions metrics/__init__.py
@@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import PIL


def _gaussian_fit(img):
assert img.ndim == 2
assert (img >= 0).all()
@@ -24,6 +25,7 @@ def _gaussian_fit(img):
).reshape(2, 2)
return mu, cov


def _get_val_metric_single(img):
"""Returns a vector of gaussian fit results to the image.
The components are: [mu0, mu1, sigma0^2, sigma1^2, covariance, integral]
@@ -36,10 +38,13 @@ def _get_val_metric_single(img):

return np.array((*mu, *cov.diagonal(), cov[0, 1], img.sum()))


_METRIC_NAMES = ['Mean0', 'Mean1', 'Sigma0^2', 'Sigma1^2', 'Cov01', 'Sum']


get_val_metric = np.vectorize(_get_val_metric_single, signature='(m,n)->(k)')


def get_val_metric_v(imgs):
"""Returns a vector of gaussian fit results to the image.
The components are: [mu0, mu1, sigma0^2, sigma1^2, covariance, integral]
@@ -84,6 +89,7 @@ def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy
img = PIL.Image.open(buf)
return np.array(img.getdata(), dtype=np.uint8).reshape(1, img.size[0], img.size[1], -1)


def make_metric_plots(images_real, images_gen):
plots = {}
try:
@@ -96,5 +102,3 @@ def make_metric_plots(images_real, images_gen):
pass

return plots
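
Note (not part of the commit): a minimal usage sketch for the metric helpers above, assuming the repository root is on PYTHONPATH. The random images are hypothetical stand-ins satisfying the non-negativity assert in _gaussian_fit.

# Hypothetical usage sketch; random images replace real data.
import numpy as np
from metrics import get_val_metric, _METRIC_NAMES

imgs = np.random.rand(5, 10, 10)   # batch of five 10x10 "images"
stats = get_val_metric(imgs)       # shape (5, 6): one row of fit results per image
for name, value in zip(_METRIC_NAMES, stats[0]):
    print(f"{name}: {value:.4f}")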


131 changes: 131 additions & 0 deletions models/baseline_10x10.py
@@ -0,0 +1,131 @@
import tensorflow as tf


def get_generator(activation, kernel_init, latent_dim):
generator = tf.keras.Sequential([
tf.keras.layers.Dense(units=64, activation=activation, input_shape=(latent_dim,)),

tf.keras.layers.Reshape((4, 4, 4)),

tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.UpSampling2D(), # 8x8

tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same' , activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='valid', activation=activation, kernel_initializer=kernel_init), # 6x6
tf.keras.layers.UpSampling2D(), # 12x12

tf.keras.layers.Conv2D(filters=8, kernel_size=3, padding='valid', activation=activation, kernel_initializer=kernel_init), # 10x10
tf.keras.layers.Conv2D(filters=1, kernel_size=1, padding='valid', activation=tf.keras.activations.relu, kernel_initializer=kernel_init),

tf.keras.layers.Reshape((10, 10)),
], name='generator')
return generator


def get_discriminator(activation, kernel_init, dropout_rate):
discriminator = tf.keras.Sequential([
tf.keras.layers.Reshape((10, 10, 1), input_shape=(10, 10)),

tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same', activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.Dropout(dropout_rate),
tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='valid', activation=activation, kernel_initializer=kernel_init), # 8x8
tf.keras.layers.Dropout(dropout_rate),

tf.keras.layers.MaxPool2D(), # 4x4

tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.Dropout(dropout_rate),
tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation=activation, kernel_initializer=kernel_init),
tf.keras.layers.Dropout(dropout_rate),

tf.keras.layers.MaxPool2D(), # 2x2

tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='valid', activation=activation, kernel_initializer=kernel_init), # 1x1
tf.keras.layers.Dropout(dropout_rate),

tf.keras.layers.Reshape((64,)),

tf.keras.layers.Dense(units=128, activation=activation),
tf.keras.layers.Dropout(dropout_rate),

tf.keras.layers.Dense(units=1, activation=None),
], name='discriminator')
return discriminator


def disc_loss(d_real, d_fake):
    # Wasserstein critic loss: push scores down on fakes, up on reals.
    return tf.reduce_mean(d_fake - d_real)


def gen_loss(d_real, d_fake):
    # Negated critic loss; the d_real term is constant w.r.t. the generator
    # and only shifts the reported value.
    return tf.reduce_mean(d_real - d_fake)


class BaselineModel10x10:
def __init__(self, activation=tf.keras.activations.relu, kernel_init='glorot_uniform',
dropout_rate=0.2, lr=1e-4, latent_dim=32, gp_lambda=10., num_disc_updates=3):
self.disc_opt = tf.keras.optimizers.RMSprop(lr)
self.gen_opt = tf.keras.optimizers.RMSprop(lr)
self.latent_dim = latent_dim
self.gp_lambda = gp_lambda
self.num_disc_updates = num_disc_updates

self.generator = get_generator(activation=activation, kernel_init=kernel_init, latent_dim=latent_dim)
self.discriminator = get_discriminator(activation=activation, kernel_init=kernel_init, dropout_rate=dropout_rate)

self.step_counter = tf.Variable(0, dtype='int32', trainable=False)

def make_fake(self, size):
return self.generator(
tf.random.normal(shape=(size, self.latent_dim), dtype='float32')
)

    def gradient_penalty(self, real, fake):
        # One-sided WGAN-GP penalty: the critic's gradient norm at random
        # interpolates between real and fake is penalized only where it exceeds 1.
        alpha = tf.random.uniform(shape=[len(real), 1, 1])
        interpolates = alpha * real + (1 - alpha) * fake
        with tf.GradientTape() as t:
            t.watch(interpolates)
            d_int = self.discriminator(interpolates)
        grads = tf.reshape(t.gradient(d_int, interpolates), [len(real), -1])
        return tf.reduce_mean(tf.maximum(tf.norm(grads, axis=-1) - 1, 0)**2)

@tf.function
def calculate_losses(self, batch):
fake = self.make_fake(len(batch))
d_real = self.discriminator(batch)
d_fake = self.discriminator(fake)

d_loss = disc_loss(d_real, d_fake) + self.gp_lambda * self.gradient_penalty(batch, fake)
g_loss = gen_loss(d_real, d_fake)
return {'disc_loss': d_loss, 'gen_loss': g_loss}

def disc_step(self, batch):
batch = tf.convert_to_tensor(batch)

with tf.GradientTape() as t:
losses = self.calculate_losses(batch)

grads = t.gradient(losses['disc_loss'], self.discriminator.trainable_variables)
self.disc_opt.apply_gradients(zip(grads, self.discriminator.trainable_variables))
return losses

def gen_step(self, batch):
batch = tf.convert_to_tensor(batch)

with tf.GradientTape() as t:
losses = self.calculate_losses(batch)

grads = t.gradient(losses['gen_loss'], self.generator.trainable_variables)
self.gen_opt.apply_gradients(zip(grads, self.generator.trainable_variables))
return losses

    @tf.function
    def training_step(self, batch):
        # Alternate num_disc_updates discriminator steps with one generator step.
        if self.step_counter == self.num_disc_updates:
result = self.gen_step(batch)
self.step_counter.assign(0)
else:
result = self.disc_step(batch)
self.step_counter.assign_add(1)
return result
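
Note (not part of the commit): a minimal sketch of driving the new class directly; the random batch is a hypothetical stand-in for real 10x10 samples.

# Hypothetical usage sketch for BaselineModel10x10.
import numpy as np
from models.baseline_10x10 import BaselineModel10x10

model = BaselineModel10x10(latent_dim=32, num_disc_updates=3)
batch = np.random.rand(32, 10, 10).astype('float32')

losses = model.training_step(batch)   # one discriminator or generator update
fake = model.make_fake(16)            # generated images, shape (16, 10, 10)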
138 changes: 0 additions & 138 deletions models/baseline_10x10/__init__.py

This file was deleted.

File renamed without changes.
7 changes: 4 additions & 3 deletions models/training.py
@@ -3,11 +3,12 @@
from tqdm import trange


-def train(data_train, data_val, train_step_fn, loss_eval_fn, num_epochs, batch_size, train_writer=None, val_writer=None, callbacks=[]):
+def train(data_train, data_val, train_step_fn, loss_eval_fn, num_epochs, batch_size,
+          train_writer=None, val_writer=None, callbacks=[]):
for i_epoch in range(num_epochs):
print("Working on epoch #{}".format(i_epoch), flush=True)

-        tf.keras.backend.set_learning_phase(1) # training
+        tf.keras.backend.set_learning_phase(1)  # training

shuffle_ids = np.random.permutation(len(data_train))
losses_train = {}
@@ -20,7 +21,7 @@ def train(data_train, data_val, train_step_fn, loss_eval_fn, num_epochs, batch_s
losses_train[k] = losses_train.get(k, 0) + l.numpy() * len(batch)
losses_train = {k : l / len(data_train) for k, l in losses_train.items()}

-        tf.keras.backend.set_learning_phase(0) # testing
+        tf.keras.backend.set_learning_phase(0)  # testing

losses_val = {k : l.numpy() for k, l in loss_eval_fn(data_val).items()}
for f in callbacks:
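
Wiring note: the point of the refactor is that the bound methods of BaselineModel10x10 can serve as the train_step_fn and loss_eval_fn callbacks expected by train(). A sketch under the assumption that the data are arrays of shape (N, 10, 10):

# Hypothetical wiring of the class into the training loop; the arrays are
# random placeholders for real training and validation sets.
import numpy as np
from models.baseline_10x10 import BaselineModel10x10
from models.training import train

model = BaselineModel10x10()
data_train = np.random.rand(1024, 10, 10).astype('float32')
data_val = np.random.rand(128, 10, 10).astype('float32')

train(data_train, data_val,
      train_step_fn=model.training_step,
      loss_eval_fn=model.calculate_losses,
      num_epochs=5, batch_size=32)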
4 changes: 3 additions & 1 deletion plotting/__init__.py
@@ -9,6 +9,7 @@ def _bootstrap_error(data, function, num_bs=100):
bs_data = np.random.choice(data, size=(num_bs, len(data)), replace=True)
return np.array([function(bs_sample) for bs_sample in bs_data]).std()


def _get_stats(arr):
class Obj:
pass
@@ -22,6 +23,7 @@ class Obj:

return result


def compare_two_dists(d_real, d_gen, label, tag=None, nbins=100):
ax = plt.gca()
bins = np.linspace(
@@ -39,7 +41,7 @@ def compare_two_dists(d_real, d_gen, label, tag=None, nbins=100):
leg_entry = 'gen'

plt.hist(d_real, bins=bins, density=True, label='real')
-    plt.hist(d_gen , bins=bins, density=True, label=leg_entry, histtype='step', linewidth=2.)
+    plt.hist(d_gen, bins=bins, density=True, label=leg_entry, histtype='step', linewidth=2.)

string = '\n'.join([
f"real: mean = {stats_real.mean :.4f} +/- {stats_real.mean_err :.4f}",
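Note (not part of the commit): a minimal sketch of compare_two_dists on synthetic samples; the distributions and label are invented for illustration.

# Hypothetical usage sketch for compare_two_dists.
import numpy as np
import matplotlib.pyplot as plt
from plotting import compare_two_dists

d_real = np.random.normal(0.0, 1.0, size=10_000)
d_gen = np.random.normal(0.1, 1.1, size=10_000)

plt.figure()
compare_two_dists(d_real, d_gen, label='Sum')
plt.savefig('compare_sum.png')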
