From 472979cb49b2e1a70f88de1f7fe02fd6dbeebc97 Mon Sep 17 00:00:00 2001
From: Marton Havasi
Date: Wed, 18 Dec 2024 14:42:28 +0000
Subject: [PATCH] internal sync

---
 flow_matching/solver/ode_solver.py            | 29 ++++----
 flow_matching/solver/riemannian_ode_solver.py | 53 +++++++------
 tests/solver/test_ode_solver.py               | 74 ++++++++++++++++++-
 tests/solver/test_riemannian_ode_solver.py    | 58 +++++++++++++++
 4 files changed, 175 insertions(+), 39 deletions(-)

diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py
index d2c1040..8997506 100644
--- a/flow_matching/solver/ode_solver.py
+++ b/flow_matching/solver/ode_solver.py
@@ -27,7 +27,6 @@ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
         super().__init__()
         self.velocity_model = velocity_model
 
-    @torch.no_grad()
     def sample(
         self,
         x_init: Tensor,
@@ -37,6 +36,7 @@ def sample(
         rtol: float = 1e-5,
         time_grid: Tensor = torch.tensor([0.0, 1.0]),
         return_intermediates: bool = False,
+        enable_grad: bool = False,
         **model_extras,
     ) -> Union[Tensor, Sequence[Tensor]]:
         r"""Solve the ODE with the velocity field.
@@ -72,6 +72,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
             rtol (float): Relative tolerance, used for adaptive step solvers.
             time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
+            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
             **model_extras: Additional input for the model.
 
         Returns:
@@ -85,23 +86,23 @@ def ode_func(t, x):
 
         ode_opts = {"step_size": step_size} if step_size is not None else {}
 
-        # Approximate ODE solution with numerical ODE solver
-        sol = odeint(
-            ode_func,
-            x_init,
-            time_grid,
-            method=method,
-            options=ode_opts,
-            atol=atol,
-            rtol=rtol,
-        )
+        with torch.set_grad_enabled(enable_grad):
+            # Approximate ODE solution with numerical ODE solver
+            sol = odeint(
+                ode_func,
+                x_init,
+                time_grid,
+                method=method,
+                options=ode_opts,
+                atol=atol,
+                rtol=rtol,
+            )
 
         if return_intermediates:
             return sol
         else:
             return sol[-1]
 
-    @torch.no_grad()
     def compute_likelihood(
         self,
         x_1: Tensor,
@@ -113,6 +114,7 @@ def compute_likelihood(
         time_grid: Tensor = torch.tensor([1.0, 0.0]),
         return_intermediates: bool = False,
         exact_divergence: bool = False,
+        enable_grad: bool = False,
         **model_extras,
     ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
         r"""Solve for log likelihood given a target sample at :math:`t=0`.
@@ -130,6 +132,7 @@ def compute_likelihood(
             time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
             exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
+            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
             **model_extras: Additional input for the model.
 
         Returns:
@@ -174,7 +177,7 @@ def dynamics_func(t, states):
         y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
         ode_opts = {"step_size": step_size} if step_size is not None else {}
 
-        with torch.no_grad():
+        with torch.set_grad_enabled(enable_grad):
             sol, log_det = odeint(
                 dynamics_func,
                 y_init,
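Note on the ode_solver.py change: replacing the blanket @torch.no_grad() decorators with torch.set_grad_enabled(enable_grad) keeps the no-grad default but lets callers backpropagate through the ODE solve on request. A minimal sketch of the resulting usage, assuming a trivial learnable velocity field; the LinearVelocity class below is illustrative, not part of this patch or the library:

    import torch
    from flow_matching.solver import ODESolver
    from flow_matching.utils import ModelWrapper

    # Illustrative velocity field (assumption): a learnable linear map
    # of the state that ignores t, just to make a gradient path visible.
    class LinearVelocity(ModelWrapper):
        def __init__(self):
            super().__init__(torch.nn.Linear(2, 2))

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
            return self.model(x)

    velocity = LinearVelocity()
    solver = ODESolver(velocity_model=velocity)

    x_init = torch.randn(8, 2)
    # enable_grad=True keeps the autograd graph through the integration,
    # so a loss on the samples can update the velocity model's weights.
    x_1 = solver.sample(x_init=x_init, step_size=0.01, method="euler", enable_grad=True)
    x_1.pow(2).sum().backward()
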
diff --git a/flow_matching/solver/riemannian_ode_solver.py b/flow_matching/solver/riemannian_ode_solver.py
index 6eb3e5e..d851e8f 100644
--- a/flow_matching/solver/riemannian_ode_solver.py
+++ b/flow_matching/solver/riemannian_ode_solver.py
@@ -31,7 +31,6 @@ def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
         self.manifold = manifold
         self.velocity_model = velocity_model
 
-    @torch.no_grad()
     def sample(
         self,
         x_init: Tensor,
@@ -42,6 +41,7 @@ def sample(
         time_grid: Tensor = torch.tensor([0.0, 1.0]),
         return_intermediates: bool = False,
         verbose: bool = False,
+        enable_grad: bool = False,
         **model_extras,
     ) -> Tensor:
         r"""Solve the ODE with the `velocity_field` on the manifold.
@@ -55,6 +55,7 @@ def sample(
             time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
             verbose (bool, optional): Whether to print progress bars. Defaults to False.
+            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
             **model_extras: Additional input for the model.
 
         Returns:
@@ -68,6 +69,9 @@ def sample(
         assert method in step_fns.keys(), f"Unknown method {method}"
         step_fn = step_fns[method]
 
+        def velocity_func(x, t):
+            return self.velocity_model(x=x, t=t, **model_extras)
+
         time_grid = torch.sort(time_grid.to(device=x_init.device)).values
@@ -98,29 +102,30 @@ def sample(
             xts = []
             i_ret = 0
 
-        xt = x_init
-        for t0, t1 in zip(t0s, t_discretization[1:]):
-            dt = t1 - t0
-            xt_next = step_fn(
-                self.velocity_model,
-                xt,
-                t0,
-                dt,
-                manifold=self.manifold,
-                projx=projx,
-                proju=proju,
-            )
-            if return_intermediates:
-                while (
-                    i_ret < len(time_grid)
-                    and t0 <= time_grid[i_ret]
-                    and time_grid[i_ret] <= t1
-                ):
-                    xts.append(
-                        interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
-                    )
-                    i_ret += 1
-            xt = xt_next
+        with torch.set_grad_enabled(enable_grad):
+            xt = x_init
+            for t0, t1 in zip(t0s, t_discretization[1:]):
+                dt = t1 - t0
+                xt_next = step_fn(
+                    velocity_func,
+                    xt,
+                    t0,
+                    dt,
+                    manifold=self.manifold,
+                    projx=projx,
+                    proju=proju,
+                )
+                if return_intermediates:
+                    while (
+                        i_ret < len(time_grid)
+                        and t0 <= time_grid[i_ret]
+                        and time_grid[i_ret] <= t1
+                    ):
+                        xts.append(
+                            interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
+                        )
+                        i_ret += 1
+                xt = xt_next
 
         if return_intermediates:
             return torch.stack(xts, dim=0)
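Note on the riemannian_ode_solver.py change: two things happen here. The step functions now receive a velocity_func closure that forwards **model_extras to the model (previously the raw model was passed and extras were dropped), and the whole integration loop runs under torch.set_grad_enabled(enable_grad). A sketch of differentiating through a solve on the sphere; the ScaledVelocity model below is an assumption for illustration, not part of the patch:

    import torch
    from flow_matching.solver import RiemannianODESolver
    from flow_matching.utils import ModelWrapper
    from flow_matching.utils.manifolds import Sphere

    # Illustrative model (assumption): a constant ambient field with a
    # learnable scale, so the gradient path to a parameter is visible.
    # The solver's default proju=True projects it onto the tangent space.
    class ScaledVelocity(ModelWrapper):
        def __init__(self):
            super().__init__(None)
            self.scale = torch.nn.Parameter(torch.tensor(1.0))

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
            return self.scale * torch.ones_like(x)

    manifold = Sphere()
    model = ScaledVelocity()
    solver = RiemannianODESolver(manifold, model)

    x_init = manifold.projx(torch.randn(4, 3))
    x_1 = solver.sample(x_init, step_size=0.01, method="euler", enable_grad=True)
    x_1.sum().backward()   # reaches model.scale through every Euler step
    print(model.scale.grad)
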
diff --git a/tests/solver/test_ode_solver.py b/tests/solver/test_ode_solver.py
index fbbefd2..85259fd 100644
--- a/tests/solver/test_ode_solver.py
+++ b/tests/solver/test_ode_solver.py
@@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
 class ConstantVelocityModel(ModelWrapper):
     def __init__(self):
         super().__init__(None)
+        self.a = torch.nn.Parameter(torch.tensor(1.0))
 
     def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
-        return x * 0.0 + 1.0
+        return x * 0.0 + self.a
 
 
 class TestODESolver(unittest.TestCase):
@@ -75,6 +76,45 @@ def test_sample_with_different_methods(self):
             "The solution to the velocity field 3t^3 from 0 to 1 is incorrect.",
         )
 
+    def test_gradients(self):
+        x_init = torch.tensor([1.0, 0.0])
+        step_size = 0.001
+        time_grid = torch.tensor([0.0, 1.0])
+
+        for method in ["euler", "dopri5", "midpoint", "heun3"]:
+            with self.subTest(method=method):
+                self.constant_velocity_model.zero_grad()
+                result = self.constant_velocity_solver.sample(
+                    x_init=x_init,
+                    step_size=step_size if method != "dopri5" else None,
+                    time_grid=time_grid,
+                    method=method,
+                    enable_grad=True,
+                )
+                loss = result.sum()
+                loss.backward()
+                self.assertAlmostEqual(
+                    self.constant_velocity_model.a.grad, 2.0, delta=1e-4
+                )
+
+    def test_no_gradients(self):
+        x_init = torch.tensor([1.0, 0.0], requires_grad=True)
+        step_size = 0.001
+        time_grid = torch.tensor([0.0, 1.0])
+
+        method = "euler"
+        self.constant_velocity_model.zero_grad()
+        result = self.constant_velocity_solver.sample(
+            x_init=x_init,
+            step_size=step_size,
+            time_grid=time_grid,
+            method=method,
+        )
+        loss = result.sum()
+
+        with self.assertRaises(RuntimeError):
+            loss.backward()
+
     def test_compute_likelihood(self):
         x_1 = torch.tensor([[0.0, 0.0]])
         step_size = 0.1
@@ -92,8 +132,33 @@ def dummy_log_p(x: Tensor) -> Tensor:
         self.assertIsInstance(log_likelihood, Tensor)
         self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
 
+        with self.assertRaises(RuntimeError):
+            log_likelihood.backward()
+
+    def test_compute_likelihood_gradients_non_zero(self):
+        x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
+        step_size = 0.1
+
+        # Define a dummy log probability function
+        def dummy_log_p(x: Tensor) -> Tensor:
+            return -0.5 * torch.sum(x**2, dim=1)
+
+        _, log_likelihood = self.dummy_solver.compute_likelihood(
+            x_1=x_1,
+            log_p0=dummy_log_p,
+            step_size=step_size,
+            exact_divergence=False,
+            enable_grad=True,
+        )
+
+        log_likelihood.backward()
+        # The gradient is hard to compute analytically, but if the gradients of the flow were 0.0,
+        # then the gradients of x_1 would be 1.0, which would be incorrect.
+        self.assertFalse(
+            torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
+        )
+
     def test_compute_likelihood_exact_divergence(self):
-        x_1 = torch.tensor([[0.0, 0.0]])
+        x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
         step_size = 0.1
 
         # Define a dummy log probability function
@@ -105,6 +170,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
             log_p0=dummy_log_p,
             step_size=step_size,
             exact_divergence=True,
+            enable_grad=True,
         )
         self.assertIsInstance(log_likelihood, Tensor)
         self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
@@ -114,6 +180,10 @@ def dummy_log_p(x: Tensor) -> Tensor:
         self.assertTrue(
             torch.allclose(x_1 - 1.0, x_0, atol=1e-2),
         )
+        log_likelihood.backward()
+        self.assertTrue(
+            torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
+        )
 
 
 if __name__ == "__main__":
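Note on test_gradients: the expected gradient of 2.0 is analytic. With the constant field v(x, t) = a, the flow is x(1) = x_init + a, the loss sums the two components of x(1), so d(loss)/da = 2. A standalone check of that arithmetic, independent of the solver:

    import torch

    # Closed-form counterpart of test_gradients: integrating the constant
    # field v(x, t) = a from t=0 to t=1 gives x(1) = x_init + a.
    a = torch.tensor(1.0, requires_grad=True)
    x_init = torch.tensor([1.0, 0.0])
    loss = (x_init + a).sum()   # sum over two components
    loss.backward()
    print(a.grad)               # tensor(2.), the value the test expects
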
diff --git a/tests/solver/test_riemannian_ode_solver.py b/tests/solver/test_riemannian_ode_solver.py
index 1acf838..0ee5a8b 100644
--- a/tests/solver/test_riemannian_ode_solver.py
+++ b/tests/solver/test_riemannian_ode_solver.py
@@ -26,11 +26,22 @@ def forward(self, x, t):
         return torch.zeros_like(x)
 
 
+class ExtraModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, t, must_be_true=False):
+        assert must_be_true
+        return torch.zeros_like(x)
+
+
 class TestRiemannianODESolver(unittest.TestCase):
     def setUp(self):
         self.manifold = Sphere()
         self.velocity_model = HundredVelocityModel()
         self.solver = RiemannianODESolver(self.manifold, self.velocity_model)
+        self.extra_model = ExtraModel()
+        self.extra_solver = RiemannianODESolver(self.manifold, self.extra_model)
 
     def test_init(self):
         self.assertEqual(self.solver.manifold, self.manifold)
@@ -152,6 +163,53 @@ def test_sample_return_intermediates_euler(self):
         )
         self.assertEqual(result.shape, (3, 1, 3))  # Two intermediate points
 
+    def test_model_extras(self):
+        x_init = self.manifold.projx(torch.randn(1, 3))
+        step_size = 0.01
+        time_grid = torch.tensor([0.0, 0.5, 1.0])
+        result = self.extra_solver.sample(
+            x_init,
+            step_size,
+            method="euler",
+            time_grid=time_grid,
+            return_intermediates=True,
+            must_be_true=True,
+        )
+        self.assertEqual(result.shape, (3, 1, 3))
+
+        with self.assertRaises(AssertionError):
+            result = self.extra_solver.sample(
+                x_init,
+                step_size,
+                method="euler",
+                time_grid=time_grid,
+                return_intermediates=True,
+            )
+
+    def test_gradient(self):
+        x_init = torch.tensor(
+            self.manifold.projx(torch.randn(1, 3)), requires_grad=True
+        )
+        step_size = 0.01
+        time_grid = torch.tensor([0.0, 1.0])
+        result = self.solver.sample(
+            x_init, step_size, method="euler", time_grid=time_grid, enable_grad=True
+        )
+        result.sum().backward()
+        self.assertIsInstance(x_init.grad, torch.Tensor)
+
+    def test_no_gradient(self):
+        x_init = torch.tensor(
+            self.manifold.projx(torch.randn(1, 3)), requires_grad=True
+        )
+        step_size = 0.01
+        time_grid = torch.tensor([0.0, 1.0])
+        result = self.solver.sample(
+            x_init, step_size, method="euler", time_grid=time_grid, enable_grad=False
+        )
+        with self.assertRaises(RuntimeError):
+            result.sum().backward()
+
 
 if __name__ == "__main__":
     unittest.main()
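Closing note: with enable_grad, compute_likelihood also becomes usable as a training objective, since the returned log-likelihood now carries a graph back to the model and to x_1 (the exact-divergence test above checks that the x_1 gradient is exactly [1.0, 1.0] for the constant field). A sketch of a negative-log-likelihood loss built on it; the standard-normal base density and the step size of 0.05 are illustrative choices, and `solver` is any ODESolver as in the first note:

    import torch
    from torch import Tensor

    def nll(solver, batch: Tensor) -> Tensor:
        # Standard-normal base log-density, up to an additive constant.
        def log_p0(x: Tensor) -> Tensor:
            return -0.5 * torch.sum(x**2, dim=1)

        _, log_likelihood = solver.compute_likelihood(
            x_1=batch,
            log_p0=log_p0,
            step_size=0.05,
            exact_divergence=False,  # Hutchinson estimator; cheaper in high dim
            enable_grad=True,        # keep the graph so the loss can backprop
        )
        return -log_likelihood.mean()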