Upstream flowjax changes #152
Labels: `good first issue` (Good for newcomers), `help wanted` (Extra attention is needed), `normalizing-flows` (Needed for adaptation through normalizing-flows)
https://github.com/aseyboldt/flowjax contains a couple of changes that we probably want to upstream.
It mostly adds a method `inverse_and_gradient` to the bijections, which moves a pair (position, gradient of logp at position) through the bijection. This could also be improved: for coupling flows in particular, the current implementation evaluates the neural network twice, when I think once should be enough.
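To illustrate what such a method does, here is a minimal sketch for a scalar affine bijection `y = scale * x + shift`. The function name and signature are illustrative only, not the actual flowjax API; the base density and parameters are made up for the example.

```python
import jax
import jax.numpy as jnp

def logp_y(y):
    # Standard normal log-density in the transformed space.
    return -0.5 * (y ** 2 + jnp.log(2 * jnp.pi))

def inverse_and_gradient(y, grad_y, scale, shift):
    """Pull (position, d logp / d position) back through the bijection.

    By the change of variables, logp_x(x) = logp_y(f(x)) + log|scale|,
    so d logp_x / dx = scale * d logp_y / dy; the log-det term is a
    constant for an affine map and drops out of the gradient.
    """
    x = (y - shift) / scale
    grad_x = scale * grad_y
    return x, grad_x

scale, shift = 2.0, 1.0
y = jnp.array(3.0)
g_y = jax.grad(logp_y)(y)
x, g_x = inverse_and_gradient(y, g_y, scale, shift)

# Cross-check against direct autodiff of the pulled-back log-density.
logp_x = lambda x: logp_y(scale * x + shift) + jnp.log(scale)
assert jnp.allclose(g_x, jax.grad(logp_x)(x))
```

For a coupling flow, the transformed half of the coordinates and both pieces of the gradient depend on the same conditioner network output, which is why a single network evaluation should suffice.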
A lot of the sampling time with the transformed adaptation is currently spent in the optimizer here:
https://github.com/aseyboldt/flowjax/blob/main/flowjax/train/data_fit.py#L22
This could be improved by jitting the inner loop over batches. It also seems that JAX recompiles things when the training function is called repeatedly, for reasons that aren't clear yet. Figuring out why that happens and preventing it would help a lot.
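One way to jit the inner loop is to roll the batch loop into a single compiled program with `jax.lax.scan`, so JAX traces one step rather than re-entering Python per batch. This is a sketch under stated assumptions: a toy quadratic loss and plain gradient descent stand in for the flow's loss and optimizer, and it assumes all batches share a fixed shape (pad the last batch if needed).

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy stand-in for the flow's negative log-likelihood loss.
    return jnp.mean((batch - params) ** 2)

@jax.jit
def train_epoch(params, batches, lr=0.1):
    # lax.scan compiles the whole loop over batches into one XLA
    # program instead of dispatching a jitted step per batch.
    def step(p, batch):
        g = jax.grad(loss_fn)(p, batch)
        return p - lr * g, None

    params, _ = jax.lax.scan(step, params, batches)
    return params

batches = jnp.ones((10, 4))  # 10 fixed-shape batches of 4 points
params = jnp.array(0.0)
params = train_epoch(params, batches)
```

On the recompilation question: `jax.jit` re-traces whenever argument shapes, dtypes, or static arguments change between calls, so likely suspects are a differently-shaped final batch each call or fresh Python objects (e.g. a new loss closure or optimizer instance) being passed into the jitted function every time it runs.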