Skip to content

Commit

Permalink
toggle_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
Marton Havasi committed Dec 17, 2024
1 parent 14f99df commit 68f0f96
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 16 deletions.
3 changes: 2 additions & 1 deletion flow_matching/solver/discrete_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flow_matching.path import MixtureDiscreteProbPath

from flow_matching.solver.solver import Solver
from flow_matching.solver.utils import toggle_grad
from flow_matching.utils import categorical, ModelWrapper
from .utils import get_nearest_times

Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(

self.source_distribution_p = source_distribution_p

@torch.no_grad()
@toggle_grad
def sample(
self,
x_init: Tensor,
Expand Down
24 changes: 12 additions & 12 deletions flow_matching/solver/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchdiffeq import odeint

from flow_matching.solver.solver import Solver
from flow_matching.solver.utils import toggle_grad
from flow_matching.utils import gradient, ModelWrapper


Expand All @@ -27,7 +28,7 @@ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
super().__init__()
self.velocity_model = velocity_model

@torch.no_grad()
@toggle_grad
def sample(
self,
x_init: Tensor,
Expand Down Expand Up @@ -101,7 +102,7 @@ def ode_func(t, x):
else:
return sol[-1]

@torch.no_grad()
@toggle_grad
def compute_likelihood(
self,
x_1: Tensor,
Expand Down Expand Up @@ -174,16 +175,15 @@ def dynamics_func(t, states):
y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
ode_opts = {"step_size": step_size} if step_size is not None else {}

with torch.no_grad():
sol, log_det = odeint(
dynamics_func,
y_init,
time_grid,
method=method,
options=ode_opts,
atol=atol,
rtol=rtol,
)
sol, log_det = odeint(
dynamics_func,
y_init,
time_grid,
method=method,
options=ode_opts,
atol=atol,
rtol=rtol,
)

x_source = sol[-1]
source_log_p = log_p0(x_source)
Expand Down
3 changes: 2 additions & 1 deletion flow_matching/solver/riemannian_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tqdm import tqdm

from flow_matching.solver.solver import Solver
from flow_matching.solver.utils import toggle_grad
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import geodesic, Manifold

Expand All @@ -31,7 +32,7 @@ def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
self.manifold = manifold
self.velocity_model = velocity_model

@torch.no_grad()
@toggle_grad
def sample(
self,
x_init: Tensor,
Expand Down
11 changes: 11 additions & 0 deletions flow_matching/solver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from functools import wraps

import torch
from torch import Tensor

Expand All @@ -17,3 +19,12 @@ def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
nearest_indices = distances.argmin(dim=1)

return t_discretization[nearest_indices]


def toggle_grad(func):
@wraps(func)
def wrapper(*args, enable_grad=False, **kwargs):
with torch.set_grad_enabled(enable_grad):
return func(*args, **kwargs)

return wrapper
31 changes: 29 additions & 2 deletions tests/solver/test_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
class ConstantVelocityModel(ModelWrapper):
def __init__(self):
super().__init__(None)
self.a = torch.nn.Parameter(torch.tensor(1.0))

def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
return x * 0.0 + 1.0
return x * 0.0 + self.a


class TestODESolver(unittest.TestCase):
Expand Down Expand Up @@ -75,6 +76,27 @@ def test_sample_with_different_methods(self):
"The solution to the velocity field 3t^3 from 0 to 1 is incorrect.",
)

def test_gradients(self):
x_init = torch.tensor([1.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

for method in ["euler", "dopri5", "midpoint", "heun3"]:
with self.subTest(method=method):
self.constant_velocity_model.zero_grad()
result = self.constant_velocity_solver.sample(
x_init=x_init,
step_size=step_size if method != "dopri5" else None,
time_grid=time_grid,
method=method,
enable_grad=True,
)
loss = result.sum()
loss.backward()
self.assertAlmostEqual(
self.constant_velocity_model.a.grad, 2.0, delta=1e-4
)

def test_compute_likelihood(self):
x_1 = torch.tensor([[0.0, 0.0]])
step_size = 0.1
Expand All @@ -93,7 +115,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
self.assertEqual(x_1.shape[0], log_likelihood.shape[0])

def test_compute_likelihood_exact_divergence(self):
x_1 = torch.tensor([[0.0, 0.0]])
x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
step_size = 0.1

# Define a dummy log probability function
Expand All @@ -105,6 +127,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=True,
enable_grad=True,
)
self.assertIsInstance(log_likelihood, Tensor)
self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
Expand All @@ -114,6 +137,10 @@ def dummy_log_p(x: Tensor) -> Tensor:
self.assertTrue(
torch.allclose(x_1 - 1.0, x_0, atol=1e-2),
)
log_likelihood.backward()
self.assertTrue(
torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
)


if __name__ == "__main__":
Expand Down

0 comments on commit 68f0f96

Please sign in to comment.