diff --git a/bvae/models.py b/bvae/models.py index 7689c6b..80bb77a 100644 --- a/bvae/models.py +++ b/bvae/models.py @@ -107,11 +107,11 @@ def Build(self): mean = Conv2D(filters=self.latentSize, kernel_size=(1, 1), padding='same')(net) mean = GlobalAveragePooling2D()(mean) - stddev = Conv2D(filters=self.latentSize, kernel_size=(1, 1), + logvar = Conv2D(filters=self.latentSize, kernel_size=(1, 1), padding='same')(net) - stddev = GlobalAveragePooling2D()(stddev) + logvar = GlobalAveragePooling2D()(logvar) - sample = SampleLayer(self.latentConstraints, self.beta)([mean, stddev], training=self.training) + sample = SampleLayer(self.latentConstraints, self.beta)([mean, logvar], training=self.training) return Model(inputs=inLayer, outputs=sample)