diff --git a/omni_anomaly/model.py b/omni_anomaly/model.py index ac57747..3831121 100644 --- a/omni_anomaly/model.py +++ b/omni_anomaly/model.py @@ -144,7 +144,7 @@ def get_training_loss(self, x, n_z=None): latent_log_probs=chain.vi.latent_log_probs, axis=chain.vi.axis ) - loss = tf.reduce_mean(vi.training.sgvb()) + loss = tf.reduce_mean(chain.vi.training.sgvb()) return loss def get_score(self, x, n_z=None,