
Connect with AdvancedVI.jl #36

Open
Red-Portal opened this issue Jun 24, 2024 · 6 comments
@Red-Portal (Member)

Hi all,

Now that [email protected] is shaping up, shall we consider connecting it with NormalizingFlows? I think most of the work would be about setting up documentation and tests matching AdvancedVI. Is there anything else that needs to be done? Also, what are the current experiences with GPUs?

@Red-Portal Red-Portal changed the title Combine with AdvancedVI.jl Connect with AdvancedVI.jl Jun 24, 2024
@zuhengxu (Member)

Hi @Red-Portal, thanks for bringing this up, and great work on AdvancedVI.jl! I'm thinking of first syncing the API of NormalizingFlows.jl with AdvancedVI.jl (e.g., enabling closed-form potentials, sticking-the-landing, etc.). The rest would just be adding some examples. What do you think?

I can't comment too much on the GPU side. As long as Distributions.jl and Bijectors.jl play nicely with `CuArray`, we should be good to go. This was not the case a year ago, but it could have changed by now. @torfjelde, what is your current opinion on it?

@Red-Portal (Member, Author)

Sounds all good to me! I vaguely remember that you didn't see much of a performance gain from using GPUs. Any idea what the bottleneck was?

@zuhengxu (Member)

> Sounds all good to me! I vaguely remember that you didn't see much of a performance gain from using GPUs. Any idea what the bottleneck was?

Yes! I didn't do a proper investigation of the bottleneck, so I can't comment too much on it. If I had to take an educated guess: the current implementation of the ELBO estimator is not "batched"; it uses the `map` function to apply `elbo_single_sample` to a collection of samples (see this line). From my perspective, this could be one of the biggest performance losses when working with GPUs.

If I recall correctly, the reason I chose this approach is that many Bijectors.jl layers did not support batched input. (The situation might have changed by now.)
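To illustrate the difference being discussed, here is a rough Julia sketch. All names here are stand-ins, not the NormalizingFlows.jl API: `flow.reference`, `flow.transform`, `elbo_single_sample`, and the helper `logabsdetjac_batch` are hypothetical, and the batched variant assumes the transformation accepts a d × n matrix as a batch of column vectors.

```julia
using Random, Statistics, Distributions

# Current style: map a per-sample estimator over n independent draws.
# Each draw triggers its own pass through the flow, so a GPU sees n tiny
# kernels instead of one large fused one.
elbo_map(rng, logp, flow, n) =
    mean(_ -> elbo_single_sample(rng, logp, flow), 1:n)

# Batched style: draw all samples as a d × n matrix and push the whole
# batch through the transformation in a single call.
function elbo_batched(rng, logp, flow, n)
    x0s = rand(rng, flow.reference, n)       # d × n matrix of base samples
    ys  = flow.transform(x0s)                # assumes b accepts a batch
    # Column-wise: log q(y) = log q0(x0) - log|det J_b(x0)|
    logq = logpdf(flow.reference, x0s) .- logabsdetjac_batch(flow.transform, x0s)
    return mean(logp.(eachcol(ys)) .- logq)
end
```

The batched version is only possible if every layer of `b` (and, for density evaluation, its inverse) handles matrix input, which is exactly the Bijectors.jl limitation mentioned above.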

@Red-Portal (Member, Author) commented Jun 26, 2024

That makes a lot of sense. I wonder if it would be hard to bypass Bijectors and maintain a batched version locally until upstream catches up, or maybe reuse as many upstream components as possible while building a local batched interface. I could then make a NormalizingFlows.jl extension in AdvancedVI that matches it.

@Red-Portal (Member, Author)

Actually, that isn't even necessary. It would suffice to define `rand(rng, q_flow, n_samples)` here without even going through the Bijectors interface.

@zuhengxu (Member) commented Jun 26, 2024

> Actually, that isn't even necessary. It would suffice to define `rand(rng, q_flow, n_samples)` here without even going through the Bijectors interface.

Ahh, I think I didn't explain this part clearly. The `rand` or `logpdf` interface of Bijectors.jl is not the issue here; it is more about whether the bijective transformation `b` is compatible with batchwise operation. To do `rand(rng, q_flow, n_samples)`, ideally we want to `rand` from the reference distribution to get a batch of `x0`s and transform them using `b(x0s)`, which I believe was not fully supported. (Similarly for density evaluation, which requires the batched operation on `b⁻¹`.)
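For concreteness, the batched `rand` overload being discussed might look like the following sketch. The field names (`q_flow.reference`, `q_flow.transform`) and the `MyFlow` type are placeholders, and it assumes the bijector accepts a d × n matrix as a batch of column vectors:

```julia
using Random, Distributions

# Hypothetical flow type: a base distribution plus a batch-capable bijector b.
struct MyFlow{D,B}
    reference::D   # reference distribution q0
    transform::B   # bijection b, assumed to accept a d × n matrix
end

function Base.rand(rng::Random.AbstractRNG, q_flow::MyFlow, n_samples::Int)
    x0s = rand(rng, q_flow.reference, n_samples)  # one call to the reference: d × n
    return q_flow.transform(x0s)                  # b(x0s): one batched transform
end
```

This sidesteps the per-sample Bijectors path for sampling, but batched `logpdf` would still need `b⁻¹` applied column-wise, so it only solves half of the problem.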
