From 9338b1167315f011ec11d48450a621c7c8f6dd49 Mon Sep 17 00:00:00 2001 From: Artem Maevskiy Date: Wed, 14 Oct 2020 23:46:24 +0300 Subject: [PATCH] dynamic stepping; logloss fix --- models/model_v4.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/models/model_v4.py b/models/model_v4.py index 884a9b9..929094f 100644 --- a/models/model_v4.py +++ b/models/model_v4.py @@ -41,7 +41,7 @@ def gen_loss_cramer(d_real, d_fake, d_fake_2): def logloss(x): - return tf.where(x < -30., -x, tf.math.log(1. + tf.math.exp(-x))) + return tf.nn.softplus(-x) def disc_loss_js(d_real, d_fake): @@ -69,6 +69,11 @@ def __init__(self, config): assert not (self.js and self.cramer) self.stochastic_stepping = config['stochastic_stepping'] + self.dynamic_stepping = config.get('dynamic_stepping', False) + if self.dynamic_stepping: + assert not self.stochastic_stepping + self.dynamic_stepping_threshold = config['dynamic_stepping_threshold'] + self.latent_dim = config['latent_dim'] architecture_descr = config['architecture'] @@ -225,5 +230,9 @@ def training_step(self, feature_batch, target_batch): self.step_counter.assign(0) else: result = self.disc_step(feature_batch, target_batch) - self.step_counter.assign_add(1) + if self.dynamic_stepping: + if result['disc_loss'] < self.dynamic_stepping_threshold: + self.step_counter.assign(self.num_disc_updates) + else: + self.step_counter.assign_add(1) return result