From 6912465ea3412d18e2a4f3c5f5c00e0495bad74a Mon Sep 17 00:00:00 2001
From: Lars <37488165+LarsKue@users.noreply.github.com>
Date: Wed, 4 Oct 2023 14:38:07 +0200
Subject: [PATCH] Adjusted ActNorm to work as described in the paper (#167)

* Adjusted ActNorm to work as described in the paper

* Fix off-by-one

* Fix log jacobian computation
---
 FrEIA/modules/invertible_resnet.py | 26 +++++++++++++++++++++-----
 tests/test_invertible_resnet.py    |  8 ++++++--
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/FrEIA/modules/invertible_resnet.py b/FrEIA/modules/invertible_resnet.py
index 1759120..26affa0 100644
--- a/FrEIA/modules/invertible_resnet.py
+++ b/FrEIA/modules/invertible_resnet.py
@@ -30,7 +30,8 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None):
 
         self.register_buffer("is_initialized", torch.tensor(False))
 
-        dims = next(iter(dims_in))
+        dims = list(next(iter(dims_in)))
+        dims[1:] = [1] * len(dims[1:])
 
         self.log_scale = nn.Parameter(torch.empty(1, *dims))
         self.loc = nn.Parameter(torch.empty(1, *dims))
@@ -42,9 +43,24 @@ def scale(self):
         return torch.exp(self.log_scale)
 
     def initialize(self, batch: torch.Tensor):
+        if batch.ndim != self.log_scale.ndim:
+            raise ValueError(f"Expected batch of dimension {self.log_scale.ndim}, but got {batch.ndim}.")
+
+        # we draw the mean and std over all dimensions except the channel dimension
+        dims = [0] + list(range(2, batch.ndim))
+
+        loc = torch.mean(batch, dim=dims, keepdim=True)
+        scale = torch.std(batch, dim=dims, keepdim=True)
+
+        # check for zero std
+        if torch.any(torch.isclose(scale, torch.tensor(0.0))):
+            raise ValueError("Failed to initialize ActNorm: One or more channels have zero standard deviation.")
+
+        # slice here to avoid silent device move
+        self.log_scale.data[:] = torch.log(scale)
+        self.loc.data[:] = loc
+
         self.is_initialized.data = torch.tensor(True)
-        self.log_scale.data = torch.log(torch.std(batch, dim=0, keepdim=True))
-        self.loc.data = torch.mean(batch, dim=0, keepdim=True)
 
     def output_dims(self, input_dims):
         assert len(input_dims) == 1, "Can only use one input"
@@ -61,10 +77,10 @@ def forward(self, x, c=None, rev=False, jac=True):
 
         if not rev:
             out = (x - self.loc) / self.scale
-            log_jac_det = -utils.sum_except_batch(self.log_scale)
+            log_jac_det = -utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float())
         else:
             out = self.scale * x + self.loc
-            log_jac_det = utils.sum_except_batch(self.log_scale)
+            log_jac_det = utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float())
 
         return (out,), log_jac_det
 
diff --git a/tests/test_invertible_resnet.py b/tests/test_invertible_resnet.py
index 0ea1759..f266d91 100644
--- a/tests/test_invertible_resnet.py
+++ b/tests/test_invertible_resnet.py
@@ -36,8 +36,12 @@ def test_conv(self):
         self.assertStandardMoments(y_)
 
     def assertStandardMoments(self, data):
-        self.assertTrue(torch.allclose(torch.mean(data, dim=0), torch.zeros(data.shape[-1]), atol=1e-7))
-        self.assertTrue(torch.allclose(torch.std(data, dim=0), torch.ones(data.shape[-1])))
+        dims = [0] + list(range(2, data.ndim))
+        mean = torch.mean(data, dim=dims)
+        std = torch.std(data, dim=dims)
+
+        self.assertTrue(torch.allclose(mean, torch.zeros_like(mean), atol=1e-7))
+        self.assertTrue(torch.allclose(std, torch.ones_like(std)))
 
 
 class IResNetTest(unittest.TestCase):
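
Note (not part of the patch): a minimal standalone sketch of why the new log-Jacobian
scaling is correct. With per-channel parameters broadcast over a (C, H, W) input, the
Jacobian of the forward pass is diagonal with each entry 1/scale_c repeated H*W times,
so log|det J| = -H * W * sum_c log(scale_c), which is exactly the factor
torch.prod(torch.tensor(x.shape[2:])) the patch multiplies in. The shapes and names
below are illustrative, not taken from FrEIA.

    import torch

    C, H, W = 3, 5, 5
    log_scale = torch.randn(C, 1, 1)
    loc = torch.randn(C, 1, 1)

    def forward(x_flat):
        # per-channel affine map, mirroring ActNorm's forward pass (rev=False)
        x = x_flat.reshape(C, H, W)
        return ((x - loc) / torch.exp(log_scale)).reshape(-1)

    x = torch.randn(C * H * W)

    # brute force: log-determinant of the full autograd Jacobian
    J = torch.autograd.functional.jacobian(forward, x)
    brute_force = torch.logdet(J)

    # closed form used in the patch: -sum(log_scale) * prod(spatial dims)
    closed_form = -log_scale.sum() * (H * W)

    assert torch.allclose(brute_force, closed_form, atol=1e-5)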