Skip to content

Commit

Permalink
made clamp_domain a class method instead of lambda, changed discrimin…
Browse files Browse the repository at this point in the history
…ant assertion in rational quadratic splines to a runtime error, raise a runtime error in the all-in-one block if x does not match dims
  • Loading branch information
RussellALA committed Aug 18, 2023
1 parent 612d7c6 commit 5896c1b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
3 changes: 3 additions & 0 deletions FrEIA/modules/all_in_one_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def _affine(self, x, a, rev=False):

def forward(self, x, c=[], rev=False, jac=True):
'''See base class docstring'''
if x.shape[1:] != self.dims_in[0][1:]:
raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, "
f"got {x.shape}.")
if self.householder:
self.w_perm = self._construct_householder_permutation()
if rev or self.reverse_pre_permute:
Expand Down
18 changes: 11 additions & 7 deletions FrEIA/modules/splines/binned.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,14 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
assert default_domain[3] - default_domain[2] >= min_bin_sizes[1] * bins, \
"{bins} bins of size {min_bin_sizes[1]} are too large for domain {default_domain[2]} to {default_domain[3]}"

if domain_clamping is not None:
self.clamp_domain = lambda domain: domain_clamping * torch.tanh(
domain / domain_clamping
)
else:
self.clamp_domain = lambda domain: domain

self.register_buffer("bins", torch.tensor(bins, dtype=torch.int32))
self.register_buffer("min_bin_sizes", torch.as_tensor(min_bin_sizes, dtype=torch.float32))
self.register_buffer("default_domain", torch.as_tensor(default_domain, dtype=torch.float32))
self.register_buffer("identity_tails", torch.tensor(identity_tails, dtype=torch.bool))
self.register_buffer("default_width", torch.as_tensor(default_domain[1] - default_domain[0], dtype=torch.float32))

self.domain_clamping = domain_clamping

# The default parameters are
# parameter constraints count
# 1. the leftmost bin edge - 1
Expand Down Expand Up @@ -140,6 +135,15 @@ def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str

return dict(zip(keys, values))

def clamp_domain(self, domain: torch.Tensor) -> torch.Tensor:
"""
Clamp domain to the a size between (-domain_clamping, domain_clamping)
"""
if self.domain_clamping is None:
return domain
else:
return self.domain_clamping * torch.tanh(domain / self.domain_clamping)

def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Constrain Parameters to meet certain conditions (e.g. positivity)
Expand Down
3 changes: 2 additions & 1 deletion FrEIA/modules/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def rational_quadratic_spline(x: torch.Tensor,

# Eq 29 in the appendix of the paper
discriminant = b ** 2 - 4 * a * c
assert torch.all(discriminant >= 0), f"Discriminant must be positive, but is violated by {torch.min(discriminant)}"
if not torch.all(discriminant >= 0):
raise(RuntimeError(f"Discriminant must be positive, but is violated by {torch.min(discriminant)}"))

xi = 2 * c / (-b - torch.sqrt(discriminant))

Expand Down

0 comments on commit 5896c1b

Please sign in to comment.