diff --git a/flow_matching/solver/discrete_solver.py b/flow_matching/solver/discrete_solver.py index 282c2a0..f3b79fd 100644 --- a/flow_matching/solver/discrete_solver.py +++ b/flow_matching/solver/discrete_solver.py @@ -17,6 +17,7 @@ from flow_matching.path import MixtureDiscreteProbPath from flow_matching.solver.solver import Solver +from flow_matching.solver.utils import toggle_grad from flow_matching.utils import categorical, ModelWrapper from .utils import get_nearest_times @@ -82,7 +83,7 @@ def __init__( self.source_distribution_p = source_distribution_p - @torch.no_grad() + @toggle_grad def sample( self, x_init: Tensor, diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py index d2c1040..85685cf 100644 --- a/flow_matching/solver/ode_solver.py +++ b/flow_matching/solver/ode_solver.py @@ -11,6 +11,7 @@ from torchdiffeq import odeint from flow_matching.solver.solver import Solver +from flow_matching.solver.utils import toggle_grad from flow_matching.utils import gradient, ModelWrapper @@ -27,7 +28,7 @@ def __init__(self, velocity_model: Union[ModelWrapper, Callable]): super().__init__() self.velocity_model = velocity_model - @torch.no_grad() + @toggle_grad def sample( self, x_init: Tensor, @@ -101,7 +102,7 @@ def ode_func(t, x): else: return sol[-1] - @torch.no_grad() + @toggle_grad def compute_likelihood( self, x_1: Tensor, @@ -174,16 +175,15 @@ def dynamics_func(t, states): y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device)) ode_opts = {"step_size": step_size} if step_size is not None else {} - with torch.no_grad(): - sol, log_det = odeint( - dynamics_func, - y_init, - time_grid, - method=method, - options=ode_opts, - atol=atol, - rtol=rtol, - ) + sol, log_det = odeint( + dynamics_func, + y_init, + time_grid, + method=method, + options=ode_opts, + atol=atol, + rtol=rtol, + ) x_source = sol[-1] source_log_p = log_p0(x_source) diff --git a/flow_matching/solver/riemannian_ode_solver.py 
def toggle_grad(func):
    """Decorate ``func`` so each call chooses whether autograd is enabled.

    The returned wrapper accepts one extra keyword-only argument,
    ``enable_grad`` (default ``False``). The wrapper consumes it and never
    forwards it; every other positional and keyword argument is passed to
    ``func`` unchanged. The call runs inside
    ``torch.set_grad_enabled(enable_grad)``, so with the default the wrapped
    callable behaves as if it were decorated with ``@torch.no_grad()``.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) whose
        reported signature additionally lists ``enable_grad``.
    """
    import inspect

    @wraps(func)
    def wrapper(*args, enable_grad: bool = False, **kwargs):
        with torch.set_grad_enabled(enable_grad):
            return func(*args, **kwargs)

    # ``wraps`` sets ``__wrapped__``, so ``inspect.signature`` / ``help()``
    # would otherwise show the undecorated signature and hide ``enable_grad``.
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Some callables cannot be introspected; the wrapper still works.
        return wrapper

    if "enable_grad" in signature.parameters:
        # Already advertised by ``func`` itself; nothing to add.
        return wrapper

    params = list(signature.parameters.values())
    # A keyword-only parameter must come before ``**kwargs`` when present.
    position = len(params)
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        position -= 1
    params.insert(
        position,
        inspect.Parameter(
            "enable_grad",
            inspect.Parameter.KEYWORD_ONLY,
            default=False,
            annotation=bool,
        ),
    )
    wrapper.__signature__ = signature.replace(parameters=params)
    return wrapper
def test_gradients(self):
    """Check that ``sample(..., enable_grad=True)`` backpropagates to the model.

    For each integration method the solver output is summed and
    differentiated, and the gradient stored on the model parameter ``a`` is
    compared with the expected value.

    NOTE(review): assumes the fixture binds ``self.constant_velocity_solver``
    to a solver over ``self.constant_velocity_model``, whose velocity is the
    constant parameter ``a`` - confirm against ``setUp``.
    """
    x_init = torch.tensor([1.0, 0.0])
    step_size = 0.001
    time_grid = torch.tensor([0.0, 1.0])
    # With a constant velocity ``a`` integrated from t=0 to t=1 every element
    # becomes ``x_init + a``, so d(sum)/da equals the number of elements.
    expected_grad = float(x_init.numel())

    for method in ["euler", "dopri5", "midpoint", "heun3"]:
        with self.subTest(method=method):
            # Clear gradients left over from the previous method.
            self.constant_velocity_model.zero_grad()
            result = self.constant_velocity_solver.sample(
                x_init=x_init,
                # dopri5 is run without a fixed step size (presumably
                # because it selects its own steps - confirm).
                step_size=step_size if method != "dopri5" else None,
                time_grid=time_grid,
                method=method,
                enable_grad=True,
            )
            loss = result.sum()
            loss.backward()

            grad = self.constant_velocity_model.a.grad
            # Fail with a readable message rather than a TypeError when no
            # gradient reached the parameter.
            self.assertIsNotNone(
                grad,
                f"No gradient reached the model parameter with method={method}.",
            )
            # Compare Python floats instead of relying on tensor truthiness.
            self.assertAlmostEqual(grad.item(), expected_grad, delta=1e-4)