From 257718978a90b28d4354aad7d96d93539a86d4e9 Mon Sep 17 00:00:00 2001 From: Armand Date: Tue, 20 Jun 2023 13:50:17 +0200 Subject: [PATCH 01/12] more descriptive spline coupling error, when discriminant is negative --- FrEIA/modules/splines/rational_quadratic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py index 6656fb8..51b5fea 100644 --- a/FrEIA/modules/splines/rational_quadratic.py +++ b/FrEIA/modules/splines/rational_quadratic.py @@ -164,7 +164,7 @@ def rational_quadratic_spline(x: torch.Tensor, # Eq 29 in the appendix of the paper discriminant = b ** 2 - 4 * a * c - assert torch.all(discriminant >= 0) + assert torch.all(discriminant >= 0), f"Discriminant must be positive, but is violated by {torch.min(discriminant)}" xi = 2 * c / (-b - torch.sqrt(discriminant)) From 3ca46d42eddbc6391283199e99997c174759d32d Mon Sep 17 00:00:00 2001 From: Armand Date: Tue, 20 Jun 2023 13:51:57 +0200 Subject: [PATCH 02/12] permutation in AllInOneBlock now via index tensor (saves resolution**2 parameters in image case) --- FrEIA/modules/all_in_one_block.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 2f07cbe..48c7b03 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -102,10 +102,13 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and self.splits = [split_len1, split_len2] try: - self.permute_function = {0: F.linear, - 1: F.conv1d, - 2: F.conv2d, - 3: F.conv3d}[self.input_rank] + if permute_soft or learned_householder_permutation: + self.permute_function = {0: F.linear, + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d}[self.input_rank] + else: + self.permute_function = lambda x, p: x[:, p] except KeyError: raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.") @@ -143,6 +146,7 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and if permute_soft: w = special_ortho_group.rvs(channels) else: + w_index = torch.randperm(channels, requires_grad=False) w = np.zeros((channels, channels)) for i, j in enumerate(np.random.permutation(channels)): w[i, j] = 1. @@ -155,11 +159,14 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and self.w_perm = None self.w_perm_inv = None self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False) - else: + elif permute_soft: self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), requires_grad=False) self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), requires_grad=False) + else: + self.w_perm = nn.Parameter(w_index, requires_grad=False) + self.w_perm_inv = nn.Parameter(torch.argsort(w_index), requires_grad=False) if subnet_constructor is None: raise ValueError("Please supply a callable subnet_constructor " From ae74fd4cd2be0435ed9ce70fa540a81dfc84c737 Mon Sep 17 00:00:00 2001 From: Armand Date: Thu, 17 Aug 2023 10:21:20 +0200 Subject: [PATCH 03/12] fixed soft clamping in all in one block (higher values no longer scale model output) --- FrEIA/modules/all_in_one_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 48c7b03..81d6a2a 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -229,7 +229,7 @@ def _affine(self, x, a, rev=False): a *= 0.1 ch = x.shape[1] - sub_jac = self.clamp * torch.tanh(a[:, :ch]) + sub_jac = self.clamp * torch.tanh(a[:, :ch]/self.clamp) if self.GIN: sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) From d748f44412fb0bde2b29c3e0fc0dbe024bafb749 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 14:33:38 +0200 Subject: [PATCH 04/12] added a new parameter domain_clamping to BinnedSplineBase, which soft-clamps the width of the spline to (-domain_clampling, domain_clamping). Since the total width before clamping is always > 0, effectively the domain is clamped to (0, domain_clamping). --- FrEIA/modules/splines/binned.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py index 5c6607f..b36034c 100644 --- a/FrEIA/modules/splines/binned.py +++ b/FrEIA/modules/splines/binned.py @@ -64,7 +64,7 @@ class BinnedSplineBase(InvertibleModule): def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None, min_bin_sizes: Tuple[float] = (0.1, 0.1), default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0), - identity_tails: bool = False) -> None: + identity_tails: bool = False, domain_clamping: float = None) -> None: """ Args: bins: number of bins to use @@ -75,6 +75,7 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[ default_domain: tuple of (left, right, bottom, top) default spline domain values these values will be used as the starting domain (when the network outputs zero) identity_tails: whether to use identity tails for the spline + domain_clamping: clamping value for the domain """ if dims_c is None: dims_c = [] @@ -92,6 +93,13 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[ assert default_domain[3] - default_domain[2] >= min_bin_sizes[1] * bins, \ "{bins} bins of size {min_bin_sizes[1]} are too large for domain {default_domain[2]} to {default_domain[3]}" + if domain_clamping is not None: + self.clamp_domain = lambda domain: domain_clamping * torch.tanh( + domain / domain_clamping + ) + else: + self.clamp_domain = lambda domain: domain + self.register_buffer("bins", torch.tensor(bins, dtype=torch.int32)) self.register_buffer("min_bin_sizes", torch.as_tensor(min_bin_sizes, dtype=torch.float32)) self.register_buffer("default_domain", torch.as_tensor(default_domain, dtype=torch.float32)) @@ -143,6 +151,7 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, total_width = parameters["total_width"] shift = np.log(np.e - 1) total_width = self.default_width * F.softplus(total_width + shift) + total_width = self.clamp_domain(total_width) parameters["left"] = -total_width / 2 parameters["bottom"] = -total_width / 2 @@ -161,7 +170,16 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, parameters["widths"] = self.min_bin_sizes[0] + F.softplus(parameters["widths"] + xshift) parameters["heights"] = self.min_bin_sizes[1] + F.softplus(parameters["heights"] + yshift) - + + domain_width = torch.sum(parameters["widths"], dim=-1, keepdim=True) + domain_height = torch.sum(parameters["heights"], dim=-1, keepdim=True) + width_resize = self.clamp_domain(domain_width) / domain_width + height_resize = self.clamp_domain(domain_height) / domain_height + + parameters["widths"] = parameters["widths"] * width_resize + parameters["heights"] = parameters["heights"] * height_resize + parameters["left"] = parameters["left"] * width_resize + parameters["bottom"] = parameters["bottom"] * height_resize return parameters From f5690e3c8364c02553d0aebeb3ca0f3a0f07ecd2 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 14:39:25 +0200 Subject: [PATCH 05/12] constructed actnorm with full input shape instead of just the channels to be in line with the original actnorm implementation of glow. This full shape was already used when actnorm is initialized, but the inconsistency in the constructor prevented loading of >1D actnorm models. --- FrEIA/modules/invertible_resnet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/FrEIA/modules/invertible_resnet.py b/FrEIA/modules/invertible_resnet.py index efac11f..7e551cd 100644 --- a/FrEIA/modules/invertible_resnet.py +++ b/FrEIA/modules/invertible_resnet.py @@ -30,9 +30,8 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None): self.register_buffer("is_initialized", torch.tensor(False)) - dim = next(iter(dims_in))[0] - self.log_scale = nn.Parameter(torch.empty(1, dim)) - self.loc = nn.Parameter(torch.empty(1, dim)) + self.log_scale = nn.Parameter(torch.empty(1, *dims_in)) + self.loc = nn.Parameter(torch.empty(1, *dims_in)) if init_data is not None: self.initialize(init_data) From 120c520cbe75b2d671aa620b119a62710882f5ee Mon Sep 17 00:00:00 2001 From: Armand Rousselot Date: Fri, 18 Aug 2023 14:57:03 +0200 Subject: [PATCH 06/12] fixed shape for actnorm constructor --- FrEIA/modules/invertible_resnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/FrEIA/modules/invertible_resnet.py b/FrEIA/modules/invertible_resnet.py index 7e551cd..1759120 100644 --- a/FrEIA/modules/invertible_resnet.py +++ b/FrEIA/modules/invertible_resnet.py @@ -30,8 +30,9 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None): self.register_buffer("is_initialized", torch.tensor(False)) - self.log_scale = nn.Parameter(torch.empty(1, *dims_in)) - self.loc = nn.Parameter(torch.empty(1, *dims_in)) + dims = next(iter(dims_in)) + self.log_scale = nn.Parameter(torch.empty(1, *dims)) + self.loc = nn.Parameter(torch.empty(1, *dims)) if init_data is not None: self.initialize(init_data) From ff7221a9c6084afe48848d6ea586a1b926ff414b Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 15:56:40 +0200 Subject: [PATCH 07/12] deleted unused hard permutation matrix in AllInOne block, more explicit docstring for BinnedSplineBase domain_clamping --- FrEIA/modules/all_in_one_block.py | 3 --- FrEIA/modules/splines/binned.py | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 81d6a2a..0a74570 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -147,9 +147,6 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and w = special_ortho_group.rvs(channels) else: w_index = torch.randperm(channels, requires_grad=False) - w = np.zeros((channels, channels)) - for i, j in enumerate(np.random.permutation(channels)): - w[i, j] = 1. if self.householder: # instead of just the permutation matrix w, the learned housholder diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py index b36034c..44d46fb 100644 --- a/FrEIA/modules/splines/binned.py +++ b/FrEIA/modules/splines/binned.py @@ -75,7 +75,8 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[ default_domain: tuple of (left, right, bottom, top) default spline domain values these values will be used as the starting domain (when the network outputs zero) identity_tails: whether to use identity tails for the spline - domain_clamping: clamping value for the domain + domain_clamping: clamping value for the domain, if float, + clamp spline width and height to (-domain_clamping, domain_clamping) """ if dims_c is None: dims_c = [] From 5896c1bdd15925a7e5346c467a5f4adc9f33ca92 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 17:05:30 +0200 Subject: [PATCH 08/12] made clamp_domain a class method instead of lambda, changed discriminant assertion in rational quadratic splines runtime error, raise runtime error in allinone block if x does not match dims --- FrEIA/modules/all_in_one_block.py | 3 +++ FrEIA/modules/splines/binned.py | 18 +++++++++++------- FrEIA/modules/splines/rational_quadratic.py | 3 ++- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 0a74570..50a6ee6 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -239,6 +239,9 @@ def _affine(self, x, a, rev=False): def forward(self, x, c=[], rev=False, jac=True): '''See base class docstring''' + if x.shape[1:] != self.dims_in[0][1:]: + raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, " + f"got {x.shape}.") if self.householder: self.w_perm = self._construct_householder_permutation() if rev or self.reverse_pre_permute: diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py index 44d46fb..ad20009 100644 --- a/FrEIA/modules/splines/binned.py +++ b/FrEIA/modules/splines/binned.py @@ -94,19 +94,14 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[ assert default_domain[3] - default_domain[2] >= min_bin_sizes[1] * bins, \ "{bins} bins of size {min_bin_sizes[1]} are too large for domain {default_domain[2]} to {default_domain[3]}" - if domain_clamping is not None: - self.clamp_domain = lambda domain: domain_clamping * torch.tanh( - domain / domain_clamping - ) - else: - self.clamp_domain = lambda domain: domain - self.register_buffer("bins", torch.tensor(bins, dtype=torch.int32)) self.register_buffer("min_bin_sizes", torch.as_tensor(min_bin_sizes, dtype=torch.float32)) self.register_buffer("default_domain", torch.as_tensor(default_domain, dtype=torch.float32)) self.register_buffer("identity_tails", torch.tensor(identity_tails, dtype=torch.bool)) self.register_buffer("default_width", torch.as_tensor(default_domain[1] - default_domain[0], dtype=torch.float32)) + self.domain_clamping = domain_clamping + # The default parameters are # parameter constraints count # 1. the leftmost bin edge - 1 @@ -140,6 +135,15 @@ def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str return dict(zip(keys, values)) + def clamp_domain(self, domain: torch.Tensor) -> torch.Tensor: + """ + Clamp domain to the a size between (-domain_clamping, domain_clamping) + """ + if self.domain_clamping is None: + return domain + else: + return self.domain_clamping * torch.tanh(domain / self.domain_clamping) + def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Constrain Parameters to meet certain conditions (e.g. positivity) diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py index 51b5fea..e8673d0 100644 --- a/FrEIA/modules/splines/rational_quadratic.py +++ b/FrEIA/modules/splines/rational_quadratic.py @@ -164,7 +164,8 @@ def rational_quadratic_spline(x: torch.Tensor, # Eq 29 in the appendix of the paper discriminant = b ** 2 - 4 * a * c - assert torch.all(discriminant >= 0), f"Discriminant must be positive, but is violated by {torch.min(discriminant)}" + if not torch.all(discriminant >= 0): + raise(RuntimeError(f"Discriminant must be positive, but is violated by {torch.min(discriminant)}")) xi = 2 * c / (-b - torch.sqrt(discriminant)) From e4b1253592476464e1ec03741f23f8abe4148067 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 17:11:40 +0200 Subject: [PATCH 09/12] unwrap input in dimension check AllInOne block --- FrEIA/modules/all_in_one_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 50a6ee6..3e6fa18 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -239,7 +239,7 @@ def _affine(self, x, a, rev=False): def forward(self, x, c=[], rev=False, jac=True): '''See base class docstring''' - if x.shape[1:] != self.dims_in[0][1:]: + if x.shape[0][1:] != self.dims_in[0][1:]: raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, " f"got {x.shape}.") if self.householder: From b983024c5d369cad25d542d8cbfb27790c743712 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 17:17:27 +0200 Subject: [PATCH 10/12] proper unwrap in AIO block, changed from torch.tensor(numpy) to torch.from_numpy(numpy) and tensor.view() to tensor.view().contiguous() --- FrEIA/modules/all_in_one_block.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 3e6fa18..8ff3708 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -155,11 +155,11 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True) self.w_perm = None self.w_perm_inv = None - self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False) + self.w_0 = nn.Parameter(torch.from_numpy(w), requires_grad=False) elif permute_soft: - self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), + self.w_perm = nn.Parameter(torch.from_numpy(w).view(channels, channels, *([1] * self.input_rank)).contiguous(), requires_grad=False) - self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), + self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).view(channels, channels, *([1] * self.input_rank)).contiguous(), requires_grad=False) else: self.w_perm = nn.Parameter(w_index, requires_grad=False) @@ -239,7 +239,7 @@ def _affine(self, x, a, rev=False): def forward(self, x, c=[], rev=False, jac=True): '''See base class docstring''' - if x.shape[0][1:] != self.dims_in[0][1:]: + if x[0].shape[1:] != self.dims_in[0][1:]: raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, " f"got {x.shape}.") if self.householder: From 31ef2b8f75ad73548bb3463970ba334012813540 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 17:28:10 +0200 Subject: [PATCH 11/12] fixed typing in AIO input check --- FrEIA/modules/all_in_one_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 8ff3708..631ea8a 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -239,9 +239,9 @@ def _affine(self, x, a, rev=False): def forward(self, x, c=[], rev=False, jac=True): '''See base class docstring''' - if x[0].shape[1:] != self.dims_in[0][1:]: + if tuple(x[0].shape[1:]) != self.dims_in[0]: raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, " - f"got {x.shape}.") + f"got {tuple(x[0].shape[1:])}.") if self.householder: self.w_perm = self._construct_householder_permutation() if rev or self.reverse_pre_permute: From 406837555da889c67c14f7e44db5f86a27ff3e8a Mon Sep 17 00:00:00 2001 From: Armand Date: Wed, 23 Aug 2023 18:33:51 +0200 Subject: [PATCH 12/12] soft permutes in AIO now initialize of type float --- FrEIA/modules/all_in_one_block.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 631ea8a..07bb1e6 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -155,11 +155,11 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True) self.w_perm = None self.w_perm_inv = None - self.w_0 = nn.Parameter(torch.from_numpy(w), requires_grad=False) + self.w_0 = nn.Parameter(torch.from_numpy(w).float(), requires_grad=False) elif permute_soft: - self.w_perm = nn.Parameter(torch.from_numpy(w).view(channels, channels, *([1] * self.input_rank)).contiguous(), + self.w_perm = nn.Parameter(torch.from_numpy(w).float().view(channels, channels, *([1] * self.input_rank)).contiguous(), requires_grad=False) - self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).view(channels, channels, *([1] * self.input_rank)).contiguous(), + self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).float().view(channels, channels, *([1] * self.input_rank)).contiguous(), requires_grad=False) else: self.w_perm = nn.Parameter(w_index, requires_grad=False)