Removed no_grad from solver #19

Closed · wants to merge 2 commits
flow_matching/solver/discrete_solver.py (2 additions, 1 deletion)
@@ -17,6 +17,7 @@
 from flow_matching.path import MixtureDiscreteProbPath

 from flow_matching.solver.solver import Solver
+from flow_matching.solver.utils import toggle_grad
 from flow_matching.utils import categorical, ModelWrapper
 from .utils import get_nearest_times

@@ -82,7 +83,7 @@ def __init__(

         self.source_distribution_p = source_distribution_p

-    @torch.no_grad()
+    @toggle_grad
     def sample(
         self,
         x_init: Tensor,
flow_matching/solver/ode_solver.py (12 additions, 12 deletions)
@@ -11,6 +11,7 @@
 from torchdiffeq import odeint

 from flow_matching.solver.solver import Solver
+from flow_matching.solver.utils import toggle_grad
 from flow_matching.utils import gradient, ModelWrapper


@@ -27,7 +28,7 @@ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
         super().__init__()
         self.velocity_model = velocity_model

-    @torch.no_grad()
+    @toggle_grad
     def sample(
         self,
         x_init: Tensor,
@@ -101,7 +102,7 @@ def ode_func(t, x):
         else:
             return sol[-1]

-    @torch.no_grad()
+    @toggle_grad
     def compute_likelihood(
         self,
         x_1: Tensor,
@@ -174,16 +175,15 @@ def dynamics_func(t, states):
         y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
         ode_opts = {"step_size": step_size} if step_size is not None else {}

-        with torch.no_grad():
Reviewer: Was this no_grad unnecessary previously?

Author: This was unnecessary yes.

(A short sketch of why the nested no_grad was redundant, and why keeping it would now block opt-in gradients, follows this file's diff.)
-            sol, log_det = odeint(
-                dynamics_func,
-                y_init,
-                time_grid,
-                method=method,
-                options=ode_opts,
-                atol=atol,
-                rtol=rtol,
-            )
+        sol, log_det = odeint(
+            dynamics_func,
+            y_init,
+            time_grid,
+            method=method,
+            options=ode_opts,
+            atol=atol,
+            rtol=rtol,
+        )

         x_source = sol[-1]
         source_log_p = log_p0(x_source)
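As the thread above notes, the nested `with torch.no_grad():` was already redundant under the outer no-grad context, and keeping it would block gradients whenever a caller passes `enable_grad=True` through the new decorator. A minimal, self-contained illustration in plain PyTorch (independent of the solver code):

```python
import torch

x = torch.ones(2, requires_grad=True)

# enable_grad=False path: autograd is already off, so an inner no_grad() adds nothing.
with torch.set_grad_enabled(False):
    with torch.no_grad():
        y = (x * 2.0).sum()
print(y.requires_grad)  # False, with or without the inner no_grad

# enable_grad=True path: a leftover inner no_grad() would still cut the graph.
with torch.set_grad_enabled(True):
    with torch.no_grad():
        y_blocked = (x * 2.0).sum()
    y_tracked = (x * 2.0).sum()
print(y_blocked.requires_grad)  # False: the inner no_grad wins
print(y_tracked.requires_grad)  # True: the graph is built once it is removed
```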
flow_matching/solver/riemannian_ode_solver.py (2 additions, 1 deletion)
@@ -12,6 +12,7 @@
 from tqdm import tqdm

 from flow_matching.solver.solver import Solver
+from flow_matching.solver.utils import toggle_grad
 from flow_matching.utils import ModelWrapper
 from flow_matching.utils.manifolds import geodesic, Manifold

@@ -31,7 +32,7 @@ def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
         self.manifold = manifold
         self.velocity_model = velocity_model

-    @torch.no_grad()
+    @toggle_grad
     def sample(
         self,
         x_init: Tensor,
flow_matching/solver/utils.py (11 additions, 0 deletions)
@@ -4,6 +4,8 @@
 # This source code is licensed under the CC-by-NC license found in the
 # LICENSE file in the root directory of this source tree.

+from functools import wraps
+
 import torch
 from torch import Tensor

@@ -17,3 +19,12 @@ def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
     nearest_indices = distances.argmin(dim=1)

     return t_discretization[nearest_indices]
+
+
+def toggle_grad(func):
+    @wraps(func)
+    def wrapper(*args, enable_grad=False, **kwargs):
+        with torch.set_grad_enabled(enable_grad):
+            return func(*args, **kwargs)
+
+    return wrapper
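With `toggle_grad` in place, each decorated `sample`/`compute_likelihood` method gains an `enable_grad` keyword that defaults to False, so the old no-grad behavior is kept unless a caller opts in. A rough usage sketch (the import path for `ODESolver` is assumed here; the toy velocity model mirrors the one in the tests below):

```python
import torch

from flow_matching.solver import ODESolver  # assumed import path
from flow_matching.utils import ModelWrapper


class ConstantVelocity(ModelWrapper):
    # Toy velocity field with a learnable constant, like the test model below.
    def __init__(self):
        super().__init__(None)
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return x * 0.0 + self.a


model = ConstantVelocity()
solver = ODESolver(velocity_model=model)
x_init = torch.tensor([1.0, 0.0])
time_grid = torch.tensor([0.0, 1.0])

# Default call: the wrapper runs under torch.set_grad_enabled(False),
# matching the previous @torch.no_grad() behavior.
samples = solver.sample(x_init=x_init, step_size=0.01, method="euler", time_grid=time_grid)

# Opt-in call: gradients flow back to the model's parameters.
samples = solver.sample(
    x_init=x_init, step_size=0.01, method="euler", time_grid=time_grid, enable_grad=True
)
samples.sum().backward()
print(model.a.grad)  # 2.0 for this constant field integrated from t=0 to t=1
```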
tests/solver/test_ode_solver.py (50 additions, 2 deletions)
@@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
 class ConstantVelocityModel(ModelWrapper):
     def __init__(self):
         super().__init__(None)
+        self.a = torch.nn.Parameter(torch.tensor(1.0))

     def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
-        return x * 0.0 + 1.0
+        return x * 0.0 + self.a


 class TestODESolver(unittest.TestCase):
@@ -75,6 +76,45 @@ def test_sample_with_different_methods(self):
             "The solution to the velocity field 3t^3 from 0 to 1 is incorrect.",
         )

+    def test_gradients(self):
+        x_init = torch.tensor([1.0, 0.0])
+        step_size = 0.001
+        time_grid = torch.tensor([0.0, 1.0])
+
+        for method in ["euler", "dopri5", "midpoint", "heun3"]:
+            with self.subTest(method=method):
+                self.constant_velocity_model.zero_grad()
+                result = self.constant_velocity_solver.sample(
+                    x_init=x_init,
+                    step_size=step_size if method != "dopri5" else None,
+                    time_grid=time_grid,
+                    method=method,
+                    enable_grad=True,
Reviewer: Check grads are not computed without this?

Author: Added test.
+                )
+                loss = result.sum()
+                loss.backward()
+                self.assertAlmostEqual(
+                    self.constant_velocity_model.a.grad, 2.0, delta=1e-4
+                )

+    def test_no_gradients(self):
+        x_init = torch.tensor([1.0, 0.0])
+        step_size = 0.001
+        time_grid = torch.tensor([0.0, 1.0])
+
+        method = "euler"
+        self.constant_velocity_model.zero_grad()
+        result = self.constant_velocity_solver.sample(
+            x_init=x_init,
+            step_size=step_size,
+            time_grid=time_grid,
+            method=method,
+        )
+        loss = result.sum()
+
+        with self.assertRaises(RuntimeError):
+            loss.backward()

     def test_compute_likelihood(self):
         x_1 = torch.tensor([[0.0, 0.0]])
         step_size = 0.1
@@ -92,8 +132,11 @@ def dummy_log_p(x: Tensor) -> Tensor:
         self.assertIsInstance(log_likelihood, Tensor)
         self.assertEqual(x_1.shape[0], log_likelihood.shape[0])

+        with self.assertRaises(RuntimeError):
+            log_likelihood.backward()

     def test_compute_likelihood_exact_divergence(self):
-        x_1 = torch.tensor([[0.0, 0.0]])
+        x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
         step_size = 0.1

         # Define a dummy log probability function

@@ -105,6 +148,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
             log_p0=dummy_log_p,
             step_size=step_size,
             exact_divergence=True,
+            enable_grad=True,
Reviewer: Check grads not computed without this?

Author: Added test.
         )
         self.assertIsInstance(log_likelihood, Tensor)
         self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
@@ -114,6 +158,10 @@ def dummy_log_p(x: Tensor) -> Tensor:
         self.assertTrue(
             torch.allclose(x_1 - 1.0, x_0, atol=1e-2),
         )
+        log_likelihood.backward()
+        self.assertTrue(
+            torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
+        )


if __name__ == "__main__":