diff --git a/FrEIA/modules/invertible_resnet.py b/FrEIA/modules/invertible_resnet.py index 8beaa15..26affa0 100644 --- a/FrEIA/modules/invertible_resnet.py +++ b/FrEIA/modules/invertible_resnet.py @@ -77,10 +77,10 @@ def forward(self, x, c=None, rev=False, jac=True): if not rev: out = (x - self.loc) / self.scale - log_jac_det = -utils.sum_except_batch(self.log_scale) + log_jac_det = -utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float()) else: out = self.scale * x + self.loc - log_jac_det = utils.sum_except_batch(self.log_scale) + log_jac_det = utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float()) return (out,), log_jac_det