diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 07bb1e6..3ad846e 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -223,12 +223,12 @@ def _affine(self, x, a, rev=False): # the entire coupling coefficient tensor is scaled down by a # factor of ten for stability and easier initialization. - a *= 0.1 + a = a * 0.1 ch = x.shape[1] sub_jac = self.clamp * torch.tanh(a[:, :ch]/self.clamp) if self.GIN: - sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) + sub_jac = sub_jac - torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) if not rev: return (x * torch.exp(sub_jac) + a[:, ch:], @@ -279,7 +279,7 @@ def forward(self, x, c=[], rev=False, jac=True): # trick to get the total number of non-channel dimensions: # number of elements of the first channel of the first batch member n_pixels = x_out[0, :1].numel() - log_jac_det += (-1)**rev * n_pixels * global_scaling_jac + log_jac_det = log_jac_det + (-1)**rev * n_pixels * global_scaling_jac return (x_out,), log_jac_det