Skip to content

Commit

Permalink
Merge pull request #173 from ju-w/fix#170
Browse files Browse the repository at this point in the history
Check for 0-dim tensors
  • Loading branch information
fdraxler authored Mar 19, 2024
2 parents e3618d8 + b3e367e commit 802d840
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion FrEIA/framework/graph_inn/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def forward(self, x_or_z: Iterable[Tensor],
f"variables, but should return "
f"{len(self.inputs if rev else self.outputs)}.")

if not torch.is_tensor(mod_jac) or mod_jac.shape[0] != out[0].shape[0]:
if not torch.is_tensor(mod_jac) or mod_jac.shape == torch.Size([]) or mod_jac.shape[0] != out[0].shape[0]:
if isinstance(mod_jac, (float, int, torch.Tensor)):
mod_jac = torch.zeros(out[0].shape[0]).to(out[0].device) \
+ mod_jac
Expand Down

0 comments on commit 802d840

Please sign in to comment.