Upstream flowjax changes #152

Open
aseyboldt opened this issue Oct 17, 2024 · 0 comments
Labels: good first issue (Good for newcomers), help wanted (Extra attention is needed), normalizing-flows (Needed for adaptation through normalizing-flows)

aseyboldt (Member) commented Oct 17, 2024

https://github.com/aseyboldt/flowjax contains a couple of changes that we probably want to upstream.

It mostly adds a method `inverse_and_gradient` to the bijections that moves a pair (position, gradient of logp at that position) through the bijection.

This could also be improved: for coupling flows in particular, the current implementation evaluates the conditioner neural network twice, when I think once should be enough.
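For reference, here is a minimal sketch of what such a method can look like when written generically against the bijection interface. It is illustrative only, not the implementation in the fork above, and it assumes the flowjax-style method names `inverse` and `transform_and_log_det`:

```python
import jax
import jax.numpy as jnp


def inverse_and_gradient(bijection, y, y_grad):
    """Pull a point y and the gradient of logp at y back to the base space.

    Illustrative sketch: if x = f^{-1}(y), then
        logp_x(x) = logp_y(f(x)) + log|det J_f(x)|,
    so the gradient at x is J_f(x)^T y_grad + grad_x log|det J_f(x)|,
    which a single vjp through the forward transform gives us.
    """
    # Preimage of y under the bijection.
    x = bijection.inverse(y)

    # vjp of x -> (f(x), log|det J_f(x)|) with cotangents (y_grad, 1.0)
    # yields J_f(x)^T y_grad + grad_x log|det J_f(x)| in one pass.
    (_, log_det), vjp_fn = jax.vjp(bijection.transform_and_log_det, x)
    (x_grad,) = vjp_fn((y_grad, jnp.ones_like(log_det)))
    return x, x_grad
```

Note that this generic form still runs a coupling layer's conditioner network twice (once inside `inverse`, once inside the vjp of the forward pass) even though the conditioner sees the same input both times, which is exactly the redundancy mentioned above.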

A lot of the sampling time with the transformed adaptation is currently spent in the optimizer here:

https://github.com/aseyboldt/flowjax/blob/main/flowjax/train/data_fit.py#L22

This could be improved by jitting the inner loop over batches. It also seems that jax recompiles things when the training function is called repeatedly, for some reason; figuring out why that happens and preventing it would help a lot.
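As a rough sketch of what jitting the inner loop could look like, here is an epoch function that scans over the stacked batches inside a single jitted computation. The names (`loss_fn`, `optimizer`, `params`) are placeholders rather than the actual flowjax training code, and a real version would need equinox-aware filtering of the model:

```python
import jax
import optax


def make_epoch_fn(loss_fn, optimizer):
    # Build once and reuse across epochs and calls, so jit traces only once.
    # Re-creating the jitted function on every training call is one common
    # cause of repeated recompilation.
    @jax.jit
    def run_epoch(params, opt_state, batches):
        def step(carry, batch):
            params, opt_state = carry
            loss, grads = jax.value_and_grad(loss_fn)(params, batch)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return (params, opt_state), loss

        # batches has shape (n_batches, batch_size, ...); lax.scan keeps the
        # whole loop over batches inside one compiled computation instead of
        # dispatching a jitted step from Python once per batch.
        (params, opt_state), losses = jax.lax.scan(step, (params, opt_state), batches)
        return params, opt_state, losses

    return run_epoch
```

On the recompilation question, the usual suspects are re-creating the jitted step/epoch function on every call to the training routine, or feeding it batches whose shapes change between calls; keeping one compiled function alive across calls and fixing the batch shapes should avoid the retracing.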

aseyboldt added the help wanted, normalizing-flows, and good first issue labels on Oct 17, 2024