Merge pull request #165 from vislearn/VariousModuleImprovements
Various module improvements
fdraxler authored Aug 23, 2023
2 parents a4d3a7d + 4068375 commit 2eb14cd
Showing 4 changed files with 49 additions and 18 deletions.
31 changes: 19 additions & 12 deletions FrEIA/modules/all_in_one_block.py
@@ -102,10 +102,13 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
         self.splits = [split_len1, split_len2]

         try:
-            self.permute_function = {0: F.linear,
-                                     1: F.conv1d,
-                                     2: F.conv2d,
-                                     3: F.conv3d}[self.input_rank]
+            if permute_soft or learned_householder_permutation:
+                self.permute_function = {0: F.linear,
+                                         1: F.conv1d,
+                                         2: F.conv2d,
+                                         3: F.conv3d}[self.input_rank]
+            else:
+                self.permute_function = lambda x, p: x[:, p]
         except KeyError:
             raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.")
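As an aside on the new branch: indexing with a permutation is equivalent to applying the corresponding permutation matrix, so a plain lookup can replace the matrix multiply when no soft or learned permutation is needed. A minimal sketch (not part of the commit):

import torch
import torch.nn.functional as F

# Sketch only: column indexing by a permutation p equals applying the
# permutation matrix built from p via F.linear.
channels = 4
x = torch.randn(2, channels)
p = torch.randperm(channels)

P = torch.eye(channels)[p]               # permutation matrix derived from p
assert torch.allclose(x[:, p], F.linear(x, P))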

@@ -143,9 +146,7 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
         if permute_soft:
             w = special_ortho_group.rvs(channels)
         else:
-            w = np.zeros((channels, channels))
-            for i, j in enumerate(np.random.permutation(channels)):
-                w[i, j] = 1.
+            w_index = torch.randperm(channels, requires_grad=False)

         if self.householder:
             # instead of just the permutation matrix w, the learned housholder
@@ -154,12 +155,15 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
             self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
             self.w_perm = None
             self.w_perm_inv = None
-            self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
-        else:
-            self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)),
+            self.w_0 = nn.Parameter(torch.from_numpy(w).float(), requires_grad=False)
+        elif permute_soft:
+            self.w_perm = nn.Parameter(torch.from_numpy(w).float().view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                        requires_grad=False)
-            self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)),
+            self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).float().view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                            requires_grad=False)
+        else:
+            self.w_perm = nn.Parameter(w_index, requires_grad=False)
+            self.w_perm_inv = nn.Parameter(torch.argsort(w_index), requires_grad=False)

         if subnet_constructor is None:
             raise ValueError("Please supply a callable subnet_constructor "
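The index-based branch stores torch.argsort(w_index) as the inverse permutation. A minimal sketch (not from the commit) of why argsort inverts a permutation:

import torch

# Sketch only: permuting and then indexing with the argsort of the
# permutation restores the original column order.
channels = 6
w_index = torch.randperm(channels)
w_index_inv = torch.argsort(w_index)

x = torch.randn(3, channels)
assert torch.equal(x[:, w_index][:, w_index_inv], x)
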
@@ -222,7 +226,7 @@ def _affine(self, x, a, rev=False):
         a *= 0.1
         ch = x.shape[1]

-        sub_jac = self.clamp * torch.tanh(a[:, :ch])
+        sub_jac = self.clamp * torch.tanh(a[:, :ch]/self.clamp)
         if self.GIN:
             sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True)
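The revised soft clamp, clamp * tanh(a / clamp), keeps the log-scale bounded in (-clamp, clamp) but behaves like the identity for small activations, whereas the old clamp * tanh(a) has slope clamp at zero. A minimal sketch (not part of the commit):

import torch

# Sketch only: old vs. new soft clamp of the coupling log-scale.
clamp = 2.0
a = torch.linspace(-6.0, 6.0, 7)

old = clamp * torch.tanh(a)          # slope `clamp` at zero
new = clamp * torch.tanh(a / clamp)  # slope 1 at zero, same bound
print(old)
print(new)                           # both stay inside (-clamp, clamp)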

@@ -235,6 +239,9 @@

     def forward(self, x, c=[], rev=False, jac=True):
         '''See base class docstring'''
+        if tuple(x[0].shape[1:]) != self.dims_in[0]:
+            raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, "
+                               f"got {tuple(x[0].shape[1:])}.")
         if self.householder:
             self.w_perm = self._construct_householder_permutation()
             if rev or self.reverse_pre_permute:
6 changes: 3 additions & 3 deletions FrEIA/modules/invertible_resnet.py
@@ -30,9 +30,9 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None):

         self.register_buffer("is_initialized", torch.tensor(False))

-        dim = next(iter(dims_in))[0]
-        self.log_scale = nn.Parameter(torch.empty(1, dim))
-        self.loc = nn.Parameter(torch.empty(1, dim))
+        dims = next(iter(dims_in))
+        self.log_scale = nn.Parameter(torch.empty(1, *dims))
+        self.loc = nn.Parameter(torch.empty(1, *dims))

         if init_data is not None:
             self.initialize(init_data)
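The effect of this change, sketched for an image-shaped input (illustration, not part of the commit): the old code always produced scale and shift parameters of shape (1, C), while the new code matches the full per-sample shape, e.g. (1, C, H, W).

import torch
import torch.nn as nn

# Sketch only: parameter shapes before and after the change.
dims_in = [(3, 8, 8)]                                 # one input of shape (C, H, W)

dim = next(iter(dims_in))[0]                          # old: channels only
old_log_scale = nn.Parameter(torch.empty(1, dim))     # shape (1, 3)

dims = next(iter(dims_in))                            # new: full shape
new_log_scale = nn.Parameter(torch.empty(1, *dims))   # shape (1, 3, 8, 8)
print(old_log_scale.shape, new_log_scale.shape)
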
27 changes: 25 additions & 2 deletions FrEIA/modules/splines/binned.py
@@ -64,7 +64,7 @@ class BinnedSplineBase(InvertibleModule):

     def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None,
                  min_bin_sizes: Tuple[float] = (0.1, 0.1), default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0),
-                 identity_tails: bool = False) -> None:
+                 identity_tails: bool = False, domain_clamping: float = None) -> None:
         """
         Args:
             bins: number of bins to use
@@ -75,6 +75,8 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
             default_domain: tuple of (left, right, bottom, top) default spline domain values
                             these values will be used as the starting domain (when the network outputs zero)
             identity_tails: whether to use identity tails for the spline
+            domain_clamping: clamping value for the domain, if float,
+                             clamp spline width and height to (-domain_clamping, domain_clamping)
         """
         if dims_c is None:
             dims_c = []
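A hypothetical usage sketch for the new keyword (not part of the commit; it assumes domain_clamping is forwarded from a concrete spline block such as Fm.RationalQuadraticSpline down to BinnedSplineBase):

import torch
import FrEIA.framework as Ff
import FrEIA.modules as Fm

# Sketch only: an INN whose spline domains are softly clamped to (-4, 4).
def subnet(dims_in, dims_out):
    return torch.nn.Sequential(torch.nn.Linear(dims_in, 64),
                               torch.nn.ReLU(),
                               torch.nn.Linear(64, dims_out))

inn = Ff.SequenceINN(8)
inn.append(Fm.RationalQuadraticSpline, subnet_constructor=subnet,
           bins=10, domain_clamping=4.0)
z, log_jac_det = inn(torch.randn(16, 8))
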
@@ -98,6 +100,8 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
         self.register_buffer("identity_tails", torch.tensor(identity_tails, dtype=torch.bool))
         self.register_buffer("default_width", torch.as_tensor(default_domain[1] - default_domain[0], dtype=torch.float32))

+        self.domain_clamping = domain_clamping
+
         # The default parameters are
         #    parameter                    constraints         count
         # 1. the leftmost bin edge        -                   1
@@ -131,6 +135,15 @@ def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str

         return dict(zip(keys, values))

+    def clamp_domain(self, domain: torch.Tensor) -> torch.Tensor:
+        """
+        Clamp domain to the a size between (-domain_clamping, domain_clamping)
+        """
+        if self.domain_clamping is None:
+            return domain
+        else:
+            return self.domain_clamping * torch.tanh(domain / self.domain_clamping)
+
     def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Constrain Parameters to meet certain conditions (e.g. positivity)
@@ -143,6 +156,7 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str,
             total_width = parameters["total_width"]
             shift = np.log(np.e - 1)
             total_width = self.default_width * F.softplus(total_width + shift)
+            total_width = self.clamp_domain(total_width)
             parameters["left"] = -total_width / 2
             parameters["bottom"] = -total_width / 2

@@ -161,7 +175,16 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str,

         parameters["widths"] = self.min_bin_sizes[0] + F.softplus(parameters["widths"] + xshift)
         parameters["heights"] = self.min_bin_sizes[1] + F.softplus(parameters["heights"] + yshift)

+        domain_width = torch.sum(parameters["widths"], dim=-1, keepdim=True)
+        domain_height = torch.sum(parameters["heights"], dim=-1, keepdim=True)
+        width_resize = self.clamp_domain(domain_width) / domain_width
+        height_resize = self.clamp_domain(domain_height) / domain_height
+
+        parameters["widths"] = parameters["widths"] * width_resize
+        parameters["heights"] = parameters["heights"] * height_resize
+        parameters["left"] = parameters["left"] * width_resize
+        parameters["bottom"] = parameters["bottom"] * height_resize

         return parameters
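Numerically, the resize factors shrink the bin widths and heights so that their sums stay inside the clamped domain. A standalone sketch of the same arithmetic (not from the commit):

import torch

# Sketch only: rescale bin widths so their sum stays below the clamp.
domain_clamping = 4.0

def clamp_domain(domain):
    return domain_clamping * torch.tanh(domain / domain_clamping)

widths = torch.tensor([[1.5, 2.0, 3.0]])                 # sums to 6.5
domain_width = torch.sum(widths, dim=-1, keepdim=True)
width_resize = clamp_domain(domain_width) / domain_width
print((widths * width_resize).sum(dim=-1))               # about 3.70, below 4.0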

3 changes: 2 additions & 1 deletion FrEIA/modules/splines/rational_quadratic.py
@@ -164,7 +164,8 @@ def rational_quadratic_spline(x: torch.Tensor,

     # Eq 29 in the appendix of the paper
     discriminant = b ** 2 - 4 * a * c
-    assert torch.all(discriminant >= 0)
+    if not torch.all(discriminant >= 0):
+        raise(RuntimeError(f"Discriminant must be positive, but is violated by {torch.min(discriminant)}"))

     xi = 2 * c / (-b - torch.sqrt(discriminant))
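For context (illustration, not from the commit): the inverse pass solves the quadratic a*xi**2 + b*xi + c = 0 referenced above, and xi = 2c / (-b - sqrt(b**2 - 4ac)) is an algebraically equivalent, numerically stabler form of the usual root. It is only defined for a non-negative discriminant, which the new check reports with a descriptive error instead of a bare assert (asserts are skipped under python -O).

import torch

# Sketch only: the stable root form agrees with the textbook formula
# whenever the discriminant is non-negative.
a, b, c = torch.tensor(1.0), torch.tensor(-3.0), torch.tensor(2.0)

discriminant = b ** 2 - 4 * a * c
if not torch.all(discriminant >= 0):
    raise RuntimeError(f"Discriminant must be non-negative, got {discriminant}")

xi = 2 * c / (-b - torch.sqrt(discriminant))
standard = (-b + torch.sqrt(discriminant)) / (2 * a)
assert torch.allclose(xi, standard)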
