implement mode-seeking regularizer
Byung-Hoon Kim committed Mar 18, 2019
1 parent 3542039 commit be39b9f
Showing 5 changed files with 40 additions and 25 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -45,6 +45,9 @@ Graphs
- WGAN-GP
- GEOGAN

Regularizers
- Mode-seek

Other model options are being updated.

## Implementation
@@ -77,9 +80,10 @@ Other model options are being updated.

## Reference
- GAN: [Generative Adversarial Nets](http://papers.nips.cc/paper/5423-generative-adversarial-nets)
- CGAN: [Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
- CGAN: [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
- DCGAN: [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)
- LSGAN: [Least Squares Generative Adversarial Networks](https://arxiv.org/abs/1611.04076)
- WGAN: [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
- WGAN-GP: [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)
- GEOGAN: [Geometric GAN](https://arxiv.org/abs/1705.02894)
- MODESEEK: [Mode Seeking Generative Adversarial Networks for Diverse Image Synthesis](https://arxiv.org/abs/1903.05628)
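
For context, the mode-seeking term introduced in the MSGAN paper referenced above maximizes the ratio of the distance between two generated images to the distance between the latent codes that produced them (paper notation, with d_I and d_z distances in image and latent space, c the conditional input, and z_1, z_2 two noise samples):

L_ms = max_G ( d_I(G(c, z_1), G(c, z_2)) / d_z(z_1, z_2) )

Because the training graph in graph.py below minimizes losses, this commit minimizes the reciprocal of that ratio, 1 / (ratio + eps), with a small eps for numerical stability, and adds it to the generator loss.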
51 changes: 30 additions & 21 deletions graph.py
@@ -29,22 +29,21 @@ def build_model(self, model, type):
self.generated_image = self.model.generator_model_output


def build_graph(self, type):
def build_graph(self, type, regularizer):
self.type = type
self.regularizer = regularizer

# LOSSES
with tf.device(self.device):
with tf.name_scope(self.scope+"_loss"):
# VANILLA GAN
if type=='gan':
with tf.name_scope("generator_loss"):
self.generator_loss = tf.losses.sigmoid_cross_entropy(\
self.generator_fake_loss = tf.losses.sigmoid_cross_entropy(\
multi_class_labels=tf.ones_like(self.model.discriminator_fake_model_output), \
logits=self.model.discriminator_fake_model_logit, \
weights=1.0, loss_collection="GENERATOR_LOSS", scope='generator_loss')

self.generator_loss = tf.add_n(tf.get_collection("GENERATOR_LOSS"), name='generator_loss')

with tf.name_scope("discriminator_loss"):
self.discriminator_real_loss = tf.losses.sigmoid_cross_entropy(\
multi_class_labels=tf.ones_like(self.model.discriminator_real_model_output), \
@@ -56,18 +55,14 @@ def build_graph(self, type):
logits=self.model.discriminator_fake_model_logit, \
weights=1.0, loss_collection="DISCRIMINATOR_LOSS", scope='discriminator_fake_loss')

self.discriminator_loss = tf.add_n(tf.get_collection("DISCRIMINATOR_LOSS"), name='discriminator_loss')

# LEAST-SQUARES GAN
elif type=='lsgan':
with tf.name_scope("generator_loss"):
self.generator_loss = tf.losses.mean_squared_error(\
self.generator_fake_loss = tf.losses.mean_squared_error(\
labels=tf.ones_like(self.model.discriminator_fake_model_output), \
predictions=self.model.discriminator_fake_model_logit, \
weights=0.5, loss_collection="GENERATOR_LOSS", scope='generator_loss')

self.generator_loss = tf.add_n(tf.get_collection("GENERATOR_LOSS"), name='generator_loss')

with tf.name_scope("discriminator_loss"):
self.discriminator_real_loss = tf.losses.mean_squared_error(\
labels=tf.ones_like(self.model.discriminator_real_model_output), \
@@ -79,17 +74,13 @@ def build_graph(self, type):
predictions=self.model.discriminator_fake_model_logit, \
weights=0.5, loss_collection="DISCRIMINATOR_LOSS", scope='discriminator_fake_loss')

self.discriminator_loss = tf.add_n(tf.get_collection("DISCRIMINATOR_LOSS"), name='discriminator_loss')

# WASSERSTEIN GAN
elif type=='wgan' or type=='wgan-gp':
with tf.name_scope("generator_loss"):
self.generator_loss = tf.losses.compute_weighted_loss(\
self.generator_fake_loss = tf.losses.compute_weighted_loss(\
losses=self.model.discriminator_fake_model_logit, \
weights=-1.0, loss_collection="GENERATOR_LOSS", scope='generator_loss')

self.generator_loss = tf.add_n(tf.get_collection("GENERATOR_LOSS"), name='generator_loss')

with tf.name_scope("discriminator_loss"):
self.discriminator_real_loss = tf.losses.compute_weighted_loss(\
losses=self.model.discriminator_real_model_logit, \
@@ -110,17 +101,13 @@ def build_graph(self, type):
self.gradient_penalty = tf.square(gradient_norm - 1)
tf.add_to_collection("DISCRIMINATOR_LOSS", self.gradient_penalty)

self.discriminator_loss = tf.add_n(tf.get_collection("DISCRIMINATOR_LOSS"), name='discriminator_loss')

# GEOMETRIC GAN
elif type=='geogan':
with tf.name_scope("generator_loss"):
self.generator_loss = tf.losses.compute_weighted_loss(\
self.generator_fake_loss = tf.losses.compute_weighted_loss(\
losses=self.model.discriminator_fake_model_logit, \
weights=-1.0, loss_collection="GENERATOR_LOSS", scope='generator_loss')

self.generator_loss = tf.add_n(tf.get_collection("GENERATOR_LOSS"), name='generator_loss')

with tf.name_scope("discriminator_loss"):
self.discriminator_real_loss = tf.losses.hinge_loss(\
labels=tf.ones_like(self.model.discriminator_real_model_output), \
@@ -132,11 +119,30 @@ def build_graph(self, type):
logits=self.model.discriminator_fake_model_logit, \
weights=1.0, loss_collection="DISCRIMINATOR_LOSS", scope='discriminator_fake_loss')

self.discriminator_loss = tf.add_n(tf.get_collection("DISCRIMINATOR_LOSS"), name='discriminator_loss')

else:
raise ValueError('unknown gan graph type: {}'.format(type))

# REGULARIZERS
with tf.device(self.device):
with tf.name_scope(self.scope+'_regularizer'):
if regularizer=='modeseek':
with tf.name_scope(self.regularizer):
_img1, _img2 = tf.split(self.generated_image, 2, axis=0, name='image_split')
_noise1, _noise2 = tf.split(self.noise, 2, axis=0, name='noise_split')  # NOTE: split the input noise batch (attribute name assumed), not the generated image again
_modeseek_loss = tf.reduce_mean(tf.abs(_img1-_img2)) / tf.reduce_mean(tf.abs(_noise1-_noise2))
self.regularizer_loss = 1 / (_modeseek_loss + 1e-8)
tf.add_to_collection("GENERATOR_LOSS", self.regularizer_loss)

elif regularizer=='spectralnorm':
raise NotImplementedError('{} is to be updated'.format(regularizer))

else:
pass

# GET FINAL LOSS
self.generator_loss = tf.add_n(tf.get_collection("GENERATOR_LOSS"), name='generator_loss')
self.discriminator_loss = tf.add_n(tf.get_collection("DISCRIMINATOR_LOSS"), name='discriminator_loss')

# IMAGES
with tf.device(self.device):
with tf.name_scope(self.scope+'_image'):
@@ -148,10 +154,12 @@ def build_graph(self, type):
with tf.device(self.device):
with tf.name_scope(self.scope+'_summary'+'_op'):
generator_loss_mean, generator_loss_mean_op = tf.metrics.mean(self.generator_loss, name='generator_loss', updates_collections=["GENERATOR_OPS"])
generator_fake_loss_mean, generator_fake_loss_mean_op = tf.metrics.mean(self.generator_fake_loss, name='generator_fake_loss', updates_collections=["GENERATOR_OPS"])
discriminator_loss_mean, discriminator_loss_mean_op = tf.metrics.mean(self.discriminator_loss, name='discriminator_loss', updates_collections=["DISCRIMINATOR_OPS"])
discriminator_real_loss_mean, discriminator_real_loss_mean_op = tf.metrics.mean(self.discriminator_real_loss, name='discriminator_real_loss', updates_collections=["DISCRIMINATOR_OPS"])
discriminator_fake_loss_mean, discriminator_fake_loss_mean_op = tf.metrics.mean(self.discriminator_fake_loss, name='discriminator_fake_loss', updates_collections=["DISCRIMINATOR_OPS"])
if 'gp' in type: gradient_penalty_mean, gradient_penalty_mean_op = tf.metrics.mean(self.gradient_penalty, name='gradient_penalty', updates_collections=["DISCRIMINATOR_OPS"])
if regularizer: regularizer_loss_mean, regularizer_loss_mean_op = tf.metrics.mean(self.regularizer_loss, name='regularizer_loss', updates_collections=["DISCRIMINATOR_OPS"])

with tf.name_scope(self.scope+'_summary'):
_ = tf.summary.scalar(name='generator_loss', tensor=generator_loss_mean, collections=["GENERATOR_SUMMARY"], family='01_loss_total')
@@ -161,6 +169,7 @@ def build_graph(self, type):
if 'gp' in type: _ = tf.summary.scalar(name='gradient_penalty', tensor=gradient_penalty_mean, collections=["DISCRIMINATOR_SUMMARY"], family='02_loss_discriminator')
_ = tf.summary.scalar(name='generator_learning_rate', tensor=self.generator_learning_rate, collections=["GENERATOR_SUMMARY"], family='03_hyperparameter')
_ = tf.summary.scalar(name='discriminator_learning_rate', tensor=self.discriminator_learning_rate, collections=["DISCRIMINATOR_SUMMARY"], family='03_hyperparameter')
if regularizer: _ = tf.summary.scalar(name='regularizer_loss', tensor=regularizer_loss_mean, collections=["DISCRIMINATOR_SUMMARY"], family='04_regularizer')

with tf.name_scope(self.scope+'_summary'+'_merge'):
self.generator_summary = tf.summary.merge(tf.get_collection("GENERATOR_SUMMARY"))
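
To summarize the new regularizer path in this file: the generated batch and its noise batch are each split in half, the mean absolute image distance is divided by the mean absolute noise distance, and the inverse of that ratio is added to the GENERATOR_LOSS collection so it ends up in the final tf.add_n. A minimal standalone sketch with plain TensorFlow ops (the function name and arguments are illustrative, not this repository's API):

import tensorflow as tf

def mode_seeking_regularizer(fake_images, noise, eps=1e-8):
    # Pair the batch up: first half against second half, for both the
    # generated images and the noise vectors that produced them.
    img1, img2 = tf.split(fake_images, 2, axis=0)
    z1, z2 = tf.split(noise, 2, axis=0)
    # Ratio of image distance to latent distance; a larger ratio means the
    # generator produces more diverse images for distinct latent codes.
    ratio = tf.reduce_mean(tf.abs(img1 - img2)) / tf.reduce_mean(tf.abs(z1 - z2))
    # Invert the ratio so it can be minimized alongside the generator loss;
    # eps guards against division by zero.
    return 1.0 / (ratio + eps)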
1 change: 1 addition & 0 deletions main.py
@@ -37,6 +37,7 @@ def _build(self):
dataset_type=self.base_option['dataset_type'], \
model_type=self.base_option['model_type'], \
graph_type=self.base_option['graph_type'], \
regularizer_type=self.base_option['regularizer_type'], \
scope=self.base_option['scope'])


6 changes: 3 additions & 3 deletions network.py
@@ -10,18 +10,18 @@ def __init__(self, device, config):
self.config = config


def build_network(self, batch_size, noise_shape, num_epoch, discriminator_learning_rate, generator_learning_rate, dataset_type, model_type, graph_type, scope):
def build_network(self, batch_size, noise_shape, num_epoch, discriminator_learning_rate, generator_learning_rate, dataset_type, model_type, graph_type, regularizer_type, scope):
# 1. BUILD DATASET OBJECT
self.dataset = DatasetGAN(batch_size=batch_size, noise_shape=noise_shape)
self.dataset.build_dataset(type=dataset_type)
self.image_shape, self.noise_shape, self.label_shape = self.dataset.get_shape()

# 2. BUILD MODEL AND GRAPH OBJECT
# 2. BUILD MODEL AND GRAPH OBJECTS
self.model = ModelGAN(device=self.device, scope=scope+"_model")
self.graph = GraphGAN(device=self.device, scope=scope+"_graph")
self.graph.define_nodes(image_shape=self.image_shape, noise_shape=self.noise_shape, label_shape=self.label_shape)
self.graph.build_model(model=self.model, type=model_type)
self.graph.build_graph(type=graph_type)
self.graph.build_graph(type=graph_type, regularizer=regularizer_type)

# 3. BUILD SESSION OBJECT
self.session = SessionGAN(config=self.config)
1 change: 1 addition & 0 deletions utils/option.py
@@ -14,6 +14,7 @@ def parse():
parser.add_argument('-tD', '--dataset_type', type=str, default='mnist', help='type of the dataset [mnist/cifar10/cifar100]')
parser.add_argument('-tM', '--model_type', type=str, default='gan', help='type of the GAN model [gan/cgan/dcgan]')
parser.add_argument('-tG', '--graph_type', type=str, default='gan', help='type of the GAN graph [gan/lsgan/wgan/wgan-gp/geogan]')
parser.add_argument('-tR', '--regularizer_type', type=str, default=None, help='type of the GAN regularizer [modeseek]')
parser.add_argument('-lD', '--discriminator_learning_rate', type=float, default=1e-3, help='learning rate of the discriminator training')
parser.add_argument('-lG', '--generator_learning_rate', type=float, default=1e-4, help='learning rate of the generator training')
parser.add_argument('-dS', '--savedir', type=str, default='./GANs', help='directory path to save the trained generator model and/or the resulting image')
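
The regularizer is opt-in: regularizer_type defaults to None, so existing runs are unchanged unless the new flag is passed (for example python main.py -tG gan -tR modeseek, keeping the other defaults; the exact invocation may vary with your setup). A self-contained check of the argparse behaviour, with the add_argument call copied from this diff:

import argparse

# Only the new option is defined here; the rest of the parser is omitted.
parser = argparse.ArgumentParser()
parser.add_argument('-tR', '--regularizer_type', type=str, default=None, help='type of the GAN regularizer [modeseek]')

print(parser.parse_args([]).regularizer_type)                   # None -> regularizer disabled
print(parser.parse_args(['-tR', 'modeseek']).regularizer_type)  # modeseek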
