diff --git a/FrEIA/framework/graph_inn/nodes.py b/FrEIA/framework/graph_inn/nodes.py index e3c52e9..5ff11a8 100644 --- a/FrEIA/framework/graph_inn/nodes.py +++ b/FrEIA/framework/graph_inn/nodes.py @@ -206,7 +206,7 @@ def forward(self, x_or_z: Iterable[Tensor], f"variables, but should return " f"{len(self.inputs if rev else self.outputs)}.") - if not torch.is_tensor(mod_jac) or mod_jac.shape[0] != out[0].shape[0]: + if not torch.is_tensor(mod_jac) or mod_jac.shape == torch.Size([]) or mod_jac.shape[0] != out[0].shape[0]: if isinstance(mod_jac, (float, int, torch.Tensor)): mod_jac = torch.zeros(out[0].shape[0]).to(out[0].device) \ + mod_jac