Commit 472979c: internal sync

Marton Havasi committed Dec 18, 2024
1 parent 14f99df commit 472979c
Showing 4 changed files with 175 additions and 39 deletions.
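
In short, this commit drops the blanket @torch.no_grad() decorators from both solvers and threads a new enable_grad flag through sample() and compute_likelihood(), so callers can differentiate through the ODE integration. A minimal usage sketch, assuming ODESolver and ModelWrapper are importable from flow_matching.solver and flow_matching.utils (the toy model mirrors ConstantVelocityModel from the tests below):

import torch
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper

class ConstantVelocity(ModelWrapper):
    # Toy velocity field v(x, t) = a, with a learnable scalar a.
    def __init__(self):
        super().__init__(None)
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return torch.zeros_like(x) + self.a

model = ConstantVelocity()
solver = ODESolver(velocity_model=model)

# enable_grad=True runs the integration under torch.set_grad_enabled(True),
# so the autograd graph survives the solver loop.
x_1 = solver.sample(
    x_init=torch.zeros(4, 2),
    step_size=0.01,
    method="euler",
    enable_grad=True,
)
x_1.sum().backward()
print(model.a.grad)  # populated, because sampling kept the graph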
29 changes: 16 additions & 13 deletions flow_matching/solver/ode_solver.py
@@ -27,7 +27,6 @@ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
super().__init__()
self.velocity_model = velocity_model

@torch.no_grad()
def sample(
self,
x_init: Tensor,
@@ -37,6 +36,7 @@ def sample(
rtol: float = 1e-5,
time_grid: Tensor = torch.tensor([0.0, 1.0]),
return_intermediates: bool = False,
enable_grad: bool = False,
**model_extras,
) -> Union[Tensor, Sequence[Tensor]]:
r"""Solve the ODE with the velocity field.
@@ -72,6 +72,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
rtol (float): Relative tolerance, used for adaptive step solvers.
time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
**model_extras: Additional input for the model.
Returns:
@@ -85,23 +86,23 @@ def ode_func(t, x):

ode_opts = {"step_size": step_size} if step_size is not None else {}

# Approximate ODE solution with numerical ODE solver
sol = odeint(
ode_func,
x_init,
time_grid,
method=method,
options=ode_opts,
atol=atol,
rtol=rtol,
)
with torch.set_grad_enabled(enable_grad):
# Approximate ODE solution with numerical ODE solver
sol = odeint(
ode_func,
x_init,
time_grid,
method=method,
options=ode_opts,
atol=atol,
rtol=rtol,
)

if return_intermediates:
return sol
else:
return sol[-1]

@torch.no_grad()
def compute_likelihood(
self,
x_1: Tensor,
@@ -113,6 +114,7 @@ def compute_likelihood(
time_grid: Tensor = torch.tensor([1.0, 0.0]),
return_intermediates: bool = False,
exact_divergence: bool = False,
enable_grad: bool = False,
**model_extras,
) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
r"""Solve for log likelihood given a target sample at :math:`t=0`.
@@ -130,6 +132,7 @@
time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
**model_extras: Additional input for the model.
Returns:
@@ -174,7 +177,7 @@ def dynamics_func(t, states):
y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
ode_opts = {"step_size": step_size} if step_size is not None else {}

with torch.no_grad():
with torch.set_grad_enabled(enable_grad):
sol, log_det = odeint(
dynamics_func,
y_init,
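For context on the compute_likelihood hunk above: the method integrates the standard continuous-flow change-of-variables identity backward from t = 1 to t = 0, which is why the gradient toggle must wrap the entire odeint call. Schematically (a textbook identity, not text from this diff), for $\dot{x}_t = v_t(x_t)$:

    \frac{d}{dt} \log p_t(x_t) = -\nabla \cdot v_t(x_t)
    \quad\Longrightarrow\quad
    \log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla \cdot v_t(x_t)\, dt

With exact_divergence=False, the divergence is replaced by the Hutchinson estimator $\mathbb{E}_z[z^\top (\partial v / \partial x)\, z]$, which is the stochastic path the new gradient tests exercise.
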
53 changes: 29 additions & 24 deletions flow_matching/solver/riemannian_ode_solver.py
@@ -31,7 +31,6 @@ def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
self.manifold = manifold
self.velocity_model = velocity_model

@torch.no_grad()
def sample(
self,
x_init: Tensor,
@@ -42,6 +41,7 @@ def sample(
time_grid: Tensor = torch.tensor([0.0, 1.0]),
return_intermediates: bool = False,
verbose: bool = False,
enable_grad: bool = False,
**model_extras,
) -> Tensor:
r"""Solve the ODE with the `velocity_field` on the manifold.
@@ -55,6 +55,7 @@
time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
verbose (bool, optional): Whether to print progress bars. Defaults to False.
enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
**model_extras: Additional input for the model.
Returns:
@@ -68,6 +69,9 @@
assert method in step_fns.keys(), f"Unknown method {method}"
step_fn = step_fns[method]

def velocity_func(x, t):
return self.velocity_model(x=x, t=t, **model_extras)

# --- Factor this out.
time_grid = torch.sort(time_grid.to(device=x_init.device)).values

@@ -98,29 +102,30 @@ def sample(
xts = []
i_ret = 0

xt = x_init
for t0, t1 in zip(t0s, t_discretization[1:]):
dt = t1 - t0
xt_next = step_fn(
self.velocity_model,
xt,
t0,
dt,
manifold=self.manifold,
projx=projx,
proju=proju,
)
if return_intermediates:
while (
i_ret < len(time_grid)
and t0 <= time_grid[i_ret]
and time_grid[i_ret] <= t1
):
xts.append(
interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
)
i_ret += 1
xt = xt_next
with torch.set_grad_enabled(enable_grad):
xt = x_init
for t0, t1 in zip(t0s, t_discretization[1:]):
dt = t1 - t0
xt_next = step_fn(
velocity_func,
xt,
t0,
dt,
manifold=self.manifold,
projx=projx,
proju=proju,
)
if return_intermediates:
while (
i_ret < len(time_grid)
and t0 <= time_grid[i_ret]
and time_grid[i_ret] <= t1
):
xts.append(
interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
)
i_ret += 1
xt = xt_next

if return_intermediates:
return torch.stack(xts, dim=0)
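One detail worth flagging in the hunk above: the new velocity_func closure is what lets keyword arguments passed to sample() reach the model on every integrator step, since step_fn only sees a plain (x, t) callable. A self-contained sketch, with the caveat that the import paths below are assumed (they are not shown in this diff):

import torch
from flow_matching.solver import RiemannianODESolver  # assumed export
from flow_matching.utils.manifolds import Sphere  # assumed path

class ExtraModel(torch.nn.Module):
    def forward(self, x, t, must_be_true=False):
        assert must_be_true  # proves the extra kwarg reached the model
        return torch.zeros_like(x)

manifold = Sphere()
solver = RiemannianODESolver(manifold, ExtraModel())
x_init = manifold.projx(torch.randn(1, 3))
# must_be_true is collected into **model_extras and forwarded by velocity_func.
out = solver.sample(x_init, 0.01, method="euler", must_be_true=True)
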
74 changes: 72 additions & 2 deletions tests/solver/test_ode_solver.py
@@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
class ConstantVelocityModel(ModelWrapper):
def __init__(self):
super().__init__(None)
self.a = torch.nn.Parameter(torch.tensor(1.0))

def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
return x * 0.0 + 1.0
return x * 0.0 + self.a


class TestODESolver(unittest.TestCase):
@@ -75,6 +76,45 @@ def test_sample_with_different_methods(self):
"The solution to the velocity field 3t^3 from 0 to 1 is incorrect.",
)

def test_gradients(self):
x_init = torch.tensor([1.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

for method in ["euler", "dopri5", "midpoint", "heun3"]:
with self.subTest(method=method):
self.constant_velocity_model.zero_grad()
result = self.constant_velocity_solver.sample(
x_init=x_init,
step_size=step_size if method != "dopri5" else None,
time_grid=time_grid,
method=method,
enable_grad=True,
)
loss = result.sum()
loss.backward()
self.assertAlmostEqual(
self.constant_velocity_model.a.grad, 2.0, delta=1e-4
)

def test_no_gradients(self):
x_init = torch.tensor([1.0, 0.0], requires_grad=True)
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

method = "euler"
self.constant_velocity_model.zero_grad()
result = self.constant_velocity_solver.sample(
x_init=x_init,
step_size=step_size,
time_grid=time_grid,
method=method,
)
loss = result.sum()

with self.assertRaises(RuntimeError):
loss.backward()

def test_compute_likelihood(self):
x_1 = torch.tensor([[0.0, 0.0]])
step_size = 0.1
@@ -92,8 +132,33 @@ def dummy_log_p(x: Tensor) -> Tensor:
self.assertIsInstance(log_likelihood, Tensor)
self.assertEqual(x_1.shape[0], log_likelihood.shape[0])

with self.assertRaises(RuntimeError):
log_likelihood.backward()

def test_compute_likelihood_gradients_non_zero(self):
x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
step_size = 0.1

# Define a dummy log probability function
def dummy_log_p(x: Tensor) -> Tensor:
return -0.5 * torch.sum(x**2, dim=1)

_, log_likelihood = self.dummy_solver.compute_likelihood(
x_1=x_1,
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=False,
enable_grad=True,
)
log_likelihood.backward()
# The gradient is hard to compute analytically, but if the gradients of the flow were 0.0,
# then the gradients of x_1 would be 1.0, which would be incorrect.
self.assertFalse(
torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
)

def test_compute_likelihood_exact_divergence(self):
x_1 = torch.tensor([[0.0, 0.0]])
x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
step_size = 0.1

# Define a dummy log probability function
@@ -105,6 +170,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=True,
enable_grad=True,
)
self.assertIsInstance(log_likelihood, Tensor)
self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
@@ -114,6 +180,10 @@ def dummy_log_p(x: Tensor) -> Tensor:
self.assertTrue(
torch.allclose(x_1 - 1.0, x_0, atol=1e-2),
)
log_likelihood.backward()
self.assertTrue(
torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
)


if __name__ == "__main__":
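For the record, the expected value of 2.0 in test_gradients above checks out by hand: with the constant field v(x, t) = a, the flow is x(1) = x_init + a, so for x_init = (1, 0),

    \mathcal{L} = \sum_i x_i(1) = (1 + a) + (0 + a) = 1 + 2a, \qquad \frac{d\mathcal{L}}{da} = 2,

independent of the solver method, which is why all four integrators are expected to agree to within the tolerance.
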
58 changes: 58 additions & 0 deletions tests/solver/test_riemannian_ode_solver.py
@@ -26,11 +26,22 @@ def forward(self, x, t):
return torch.zeros_like(x)


class ExtraModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, t, must_be_true=False):
assert must_be_true
return torch.zeros_like(x)


class TestRiemannianODESolver(unittest.TestCase):
def setUp(self):
self.manifold = Sphere()
self.velocity_model = HundredVelocityModel()
self.solver = RiemannianODESolver(self.manifold, self.velocity_model)
self.extra_model = ExtraModel()
self.extra_solver = RiemannianODESolver(self.manifold, self.extra_model)

def test_init(self):
self.assertEqual(self.solver.manifold, self.manifold)
@@ -152,6 +163,53 @@ def test_sample_return_intermediates_euler(self):
)
self.assertEqual(result.shape, (3, 1, 3)) # Two intermediate points

def test_model_extras(self):
x_init = self.manifold.projx(torch.randn(1, 3))
step_size = 0.01
time_grid = torch.tensor([0.0, 0.5, 1.0])
result = self.extra_solver.sample(
x_init,
step_size,
method="euler",
time_grid=time_grid,
return_intermediates=True,
must_be_true=True,
)
self.assertEqual(result.shape, (3, 1, 3))

with self.assertRaises(AssertionError):
result = self.extra_solver.sample(
x_init,
step_size,
method="euler",
time_grid=time_grid,
return_intermediates=True,
)

def test_gradient(self):
x_init = torch.tensor(
self.manifold.projx(torch.randn(1, 3)), requires_grad=True
)
step_size = 0.01
time_grid = torch.tensor([0.0, 1.0])
result = self.solver.sample(
x_init, step_size, method="euler", time_grid=time_grid, enable_grad=True
)
result.sum().backward()
self.assertIsInstance(x_init.grad, torch.Tensor)

def test_no_gradient(self):
x_init = torch.tensor(
self.manifold.projx(torch.randn(1, 3)), requires_grad=True
)
step_size = 0.01
time_grid = torch.tensor([0.0, 1.0])
result = self.solver.sample(
x_init, step_size, method="euler", time_grid=time_grid, enable_grad=False
)
with self.assertRaises(RuntimeError):
result.sum().backward()


if __name__ == "__main__":
unittest.main()
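
Why the RuntimeError assertions in test_no_gradients and test_no_gradient hold: with enable_grad=False (the default), integration runs under torch.set_grad_enabled(False), so the solver output carries no grad_fn, and calling backward() on it raises. The same PyTorch behavior in isolation, independent of this library:

import torch

x = torch.tensor([1.0], requires_grad=True)
with torch.set_grad_enabled(False):
    y = x * 2.0  # nothing recorded: y.requires_grad is False, y.grad_fn is None
# y.sum().backward()  # would raise RuntimeError: tensor does not require grad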
