From b3e367efa1fe659dcbc6cb00c64a3ad45553cd86 Mon Sep 17 00:00:00 2001 From: ju-w <22564375+ju-w@users.noreply.github.com> Date: Mon, 18 Mar 2024 15:23:27 +0100 Subject: [PATCH] Check for 0-dim tensors Fixes #170 --- FrEIA/framework/graph_inn/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FrEIA/framework/graph_inn/nodes.py b/FrEIA/framework/graph_inn/nodes.py index e3c52e9..5ff11a8 100644 --- a/FrEIA/framework/graph_inn/nodes.py +++ b/FrEIA/framework/graph_inn/nodes.py @@ -206,7 +206,7 @@ def forward(self, x_or_z: Iterable[Tensor], f"variables, but should return " f"{len(self.inputs if rev else self.outputs)}.") - if not torch.is_tensor(mod_jac) or mod_jac.shape[0] != out[0].shape[0]: + if not torch.is_tensor(mod_jac) or mod_jac.shape == torch.Size([]) or mod_jac.shape[0] != out[0].shape[0]: if isinstance(mod_jac, (float, int, torch.Tensor)): mod_jac = torch.zeros(out[0].shape[0]).to(out[0].device) \ + mod_jac