From e4b1253592476464e1ec03741f23f8abe4148067 Mon Sep 17 00:00:00 2001 From: Armand Date: Fri, 18 Aug 2023 17:11:40 +0200 Subject: [PATCH] unwrap input in dimension check AllInOne block --- FrEIA/modules/all_in_one_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py index 50a6ee6..3e6fa18 100644 --- a/FrEIA/modules/all_in_one_block.py +++ b/FrEIA/modules/all_in_one_block.py @@ -239,7 +239,7 @@ def _affine(self, x, a, rev=False): def forward(self, x, c=[], rev=False, jac=True): '''See base class docstring''' - if x.shape[1:] != self.dims_in[0][1:]: + if x.shape[0][1:] != self.dims_in[0][1:]: raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, " f"got {x.shape}.") if self.householder: