diff --git a/zuko/flows.py b/zuko/flows.py index ff78042..1566adf 100644 --- a/zuko/flows.py +++ b/zuko/flows.py @@ -869,7 +869,7 @@ def forward(self, y: Tensor = None) -> Transform: return FreeFormJacobianTransform( f=partial(self.f, y=y), time=self.time, - phi=(y, *self.ode.parameters()), + phi=self.ode.parameters() if y is None else (y, *self.ode.parameters()), exact=self.exact, )