diff --git a/test/test_algos.py b/test/test_algos.py index 0ea89f14..e55b3147 100644 --- a/test/test_algos.py +++ b/test/test_algos.py @@ -169,10 +169,10 @@ def test_trainable_recon(algorithm): psf = torch.rand(1, 32, 64, 3, dtype=torch_type) data = torch.rand(2, 1, 32, 64, 3, dtype=torch_type) - def pre_process(x, noise): + def pre_process(x, param): return x - def post_process(x, noise): + def post_process(x, param, residual): return x recon = algorithm(