
Removed no_grad from solver #19

Closed
wants to merge 2 commits into from

Conversation

@mhavasi (Contributor) commented Dec 17, 2024

It was suggested that no part of the core library, such as the sampler and likelihood computation, should be wrapped in no_grad. Since the library is meant to be integrated into other people's projects, it should not enforce no_grad; the user should decide whether to track the computation graph. Users can always add their own no_grad, but they cannot remove one that the library has already applied.

This PR removes no_grad from the library.

I tested it with the example notebooks, and I also added a unit test to make sure we can differentiate through the ODE solver and the likelihood computation.
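
As a rough illustration of the point above (not part of the PR diff, with `solve_no_grad` and `solve` standing in for a library entry point such as the sampler), this is the difference the change makes for downstream users:

```python
import torch

# Hypothetical library function that hard-codes no_grad: the caller can
# never differentiate through it, no matter what they do at the call site.
def solve_no_grad(x):
    with torch.no_grad():
        return x * 2.0  # stands in for an ODE integration

# Same computation with grad tracking left to the caller.
def solve(x):
    return x * 2.0

x = torch.ones(3, requires_grad=True)

y = solve_no_grad(x)
print(y.requires_grad)   # False: the graph was discarded inside the library

y = solve(x)
y.sum().backward()       # gradients flow through the solver
print(x.grad)            # tensor([2., 2., 2.])

with torch.no_grad():    # users who don't need gradients can still opt out
    y = solve(x)
print(y.requires_grad)   # False
```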

@mhavasi mhavasi requested review from rtqichen and itaigat December 17, 2024 10:59
@mhavasi mhavasi marked this pull request as ready for review December 17, 2024 11:14
step_size=step_size if method != "dopri5" else None,
time_grid=time_grid,
method=method,
enable_grad=True,
Check grads are not computed without this?

@mhavasi (Contributor, Author)

Added test
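
A sketch of the kind of check being asked for, with `dummy_solve` standing in for the library's solver call (the actual test presumably exercises the solver with and without the `enable_grad` flag shown in the diff):

```python
import torch

# Stand-in for a solver call that gates graph construction on enable_grad.
def dummy_solve(x, enable_grad=False):
    with torch.set_grad_enabled(enable_grad):
        return (x * 3.0).sum()

x = torch.randn(4, requires_grad=True)

out = dummy_solve(x, enable_grad=False)
assert not out.requires_grad          # no graph is built without the flag

out = dummy_solve(x, enable_grad=True)
out.backward()
assert x.grad is not None             # gradients reach the solver input
```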

@@ -105,6 +127,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=True,
enable_grad=True,

Check grads not computed without this?

@mhavasi (Contributor, Author)

Added test
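
For the likelihood path the differentiability question is a bit stronger, because exact divergence already uses autograd inside the dynamics. A stand-in sketch (not the library's likelihood code) of why that computation must not be wrapped in no_grad if users are to differentiate the result:

```python
import torch

# Stand-in vector field with a learnable parameter.
def velocity(x, theta):
    return torch.tanh(theta * x)

# Exact divergence: sum of du_i/dx_i, computed with autograd and kept in
# the graph (create_graph=True) so the result itself stays differentiable.
def divergence(x, theta):
    x = x.requires_grad_(True)
    u = velocity(x, theta)
    return torch.autograd.grad(u.sum(), x, create_graph=True)[0].sum()

theta = torch.tensor(0.5, requires_grad=True)
x = torch.randn(4)

log_like_term = -divergence(x, theta)  # one step of a likelihood update
log_like_term.backward()               # second backward pass through the divergence
print(theta.grad)                      # gradient w.r.t. the model parameter
```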

@@ -174,16 +175,15 @@ def dynamics_func(t, states):
y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
ode_opts = {"step_size": step_size} if step_size is not None else {}

with torch.no_grad():

Was this no_grad unnecessary previously?

@mhavasi (Contributor, Author)

Yes, this no_grad was unnecessary.

@rtqichen (Contributor)

Do the docs for the affected methods get updated with the enable_grad argument?

@mhavasi (Contributor, Author) commented Dec 18, 2024

Moving this PR to internal

@mhavasi mhavasi closed this Dec 18, 2024
@mhavasi mhavasi deleted the marton/no_no_grad branch December 18, 2024 14:44