Commit b983024

proper unwrap in AIO block, changed from torch.FloatTensor(numpy) to torch.from_numpy(numpy) and tensor.view() to tensor.view().contiguous()
RussellALA committed Aug 18, 2023
1 parent e4b1253 commit b983024
Showing 1 changed file with 4 additions and 4 deletions.
FrEIA/modules/all_in_one_block.py (4 additions, 4 deletions)
@@ -155,11 +155,11 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
             self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
             self.w_perm = None
             self.w_perm_inv = None
-            self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
+            self.w_0 = nn.Parameter(torch.from_numpy(w), requires_grad=False)
         elif permute_soft:
-            self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)),
+            self.w_perm = nn.Parameter(torch.from_numpy(w).view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                        requires_grad=False)
-            self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)),
+            self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                            requires_grad=False)
         else:
             self.w_perm = nn.Parameter(w_index, requires_grad=False)
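
Why this hunk matters: torch.FloatTensor(w) always copies a NumPy array and silently casts it to float32, whereas torch.from_numpy(w) wraps the array's memory and keeps its dtype (upstream in this module, w is a float64 NumPy matrix). A transposed NumPy array also becomes a non-contiguous tensor once wrapped, which is what the trailing .contiguous() addresses. A minimal sketch, not part of the commit; the random 4x4 matrix stands in for w:

    import numpy as np
    import torch

    w = np.random.randn(4, 4)              # stand-in for the float64 permutation matrix

    a = torch.FloatTensor(w)               # copies and silently casts to float32
    b = torch.from_numpy(w)                # zero-copy wrap, dtype preserved
    print(a.dtype, b.dtype)                # torch.float32 torch.float64
    print(b.data_ptr() == w.ctypes.data)   # True: the tensor shares the array's memory

    # w.T is a strided view in NumPy, so the wrapped tensor is non-contiguous;
    # .contiguous() after .view() materializes a compact copy for the Parameter.
    c = torch.from_numpy(w.T).view(4, 4, 1)
    print(c.is_contiguous())               # False
    print(c.contiguous().is_contiguous())  # True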
@@ -239,7 +239,7 @@ def _affine(self, x, a, rev=False):

     def forward(self, x, c=[], rev=False, jac=True):
         '''See base class docstring'''
-        if x.shape[0][1:] != self.dims_in[0][1:]:
+        if x[0].shape[1:] != self.dims_in[0][1:]:
             raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, "
                                f"got {x.shape}.")
         if self.householder:
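
The forward() fix reflects that FrEIA modules receive their inputs as a tuple of tensors rather than a bare tensor: a tuple has no .shape attribute, and even on a tensor, x.shape[0] is an int that cannot be sliced. A minimal sketch of the corrected check, with hypothetical shapes not taken from the repository:

    import torch

    dims_in = [(8,)]            # FrEIA-style: one shape tuple per input, without the batch dim
    x = (torch.randn(16, 8),)   # forward() receives its inputs as a tuple of tensors

    # Old expression x.shape[0][1:] fails twice over: tuples have no .shape,
    # and a tensor's x.shape[0] is an int, so [1:] raises a TypeError.
    # New expression: index the tuple first, then drop the batch dimension.
    assert x[0].shape[1:] == dims_in[0]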
