Skip to content

Commit e1ff1ec

Browse files
committed
🐶 Trick to prevent expoded loss in melgan_stft.
1 parent 43f3297 commit e1ff1ec

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/melgan_stft/train_melgan_stft.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ def compute_per_example_generator_losses(self, batch, outputs):
114114
sc_loss, mag_loss = calculate_2d_loss(
115115
audios, tf.squeeze(y_hat, -1), self.stft_loss
116116
)
117+
118+
# trick to prevent loss expoded here
119+
sc_loss = tf.where(sc_loss >= 15.0, 0.0, sc_loss)
120+
mag_loss = tf.where(mag_loss >= 15.0, 0.0, mag_loss)
121+
122+
# compute generator loss
117123
gen_loss = 0.5 * (sc_loss + mag_loss)
118124

119125
if self.steps >= self.config["discriminator_train_start_steps"]:

0 commit comments

Comments
 (0)