Enhance wrapped distributions #414

Open
wants to merge 10 commits into
base: master

Contributor

@alyst alyst commented Jun 25, 2022

Add a basic WrappedDistribution type for NoDist and NamedDist and teach them a few tricks such as length() and bijector().
I discovered that these methods were missing when trying to call

DynamicPPL.tilde_assume!!(context, NoDist(prior), @varname(v), varinfo)

where prior was a multivariate Product distribution. With the changes implemented in this PR, the call works.

@devmotion
Member

I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?

In general, both distributions are only used internally in DynamicPPL and hence only the parts of the Distributions API relevant for DynamicPPL are implemented. What exactly was missing? Did you actually try to call tilde_assume!! directly?

@alyst
Contributor Author

alyst commented Jun 25, 2022

> I'm slightly worried about the additional complexity introduced by the new abstract type and functions such as wrapped_dist and wrapped_dist_type. Can't we just add whatever definition was missing?

It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). On the other hand, it avoids duplicating method definitions such as length(). I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I got errors about length() and bijector() missing for NoDist, and I can see how more methods from the Distributions API might be required in the future, so this PR makes it easier to add them.
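The forwarding pattern under discussion can be sketched roughly as follows. This is only an illustration using toy stand-in types: ToyProduct, ToyWrappedDist, ToyNoDist, ToyNamedDist and wrapped() are hypothetical names invented for this sketch and are not the DynamicPPL or Distributions.jl definitions, and the PR's actual code may differ.

```julia
# Toy stand-in for a multivariate distribution (hypothetical, not Distributions.jl).
struct ToyProduct
    n::Int              # number of components
end
Base.length(d::ToyProduct) = d.n

# One abstract supertype shared by every wrapper, plus a single accessor.
abstract type ToyWrappedDist end
wrapped(d::ToyWrappedDist) = d.dist   # assumes each wrapper stores a `dist` field

# Generic forwarding: written once, inherited by all wrapper subtypes.
Base.length(d::ToyWrappedDist) = length(wrapped(d))

struct ToyNoDist{D} <: ToyWrappedDist
    dist::D
end

struct ToyNamedDist{D} <: ToyWrappedDist
    dist::D
    name::Symbol
end

n_nodist = length(ToyNoDist(ToyProduct(5)))          # forwarded to ToyProduct
n_named = length(ToyNamedDist(ToyProduct(3), :x))    # same method, different wrapper
```

The alternative raised in the review is to skip the abstract type and define each missing method directly on each wrapper, which duplicates a line or two per method but adds no new type.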

> Did you actually try to call tilde_assume!! directly?

Yes. I'm not using the @model macro; I'm calling DynamicPPL directly to have more control and flexibility when generating statistical models.

@devmotion
Member

> It's just one abstract type and a few standard boilerplate definitions around it (wrapped_distr() etc.). On the other hand, it avoids duplicating method definitions such as length(). I see your point, but I think both approaches have advantages in terms of maintenance. Before this patch I got errors about length() and bijector() missing for NoDist, and I can see how more methods from the Distributions API might be required in the future, so this PR makes it easier to add them.

I can see that point, but I'm probably biased here towards not adding additional types and things that are potentially useful at some point in the future due to the history of DynamicPPL, and VarInfo in particular: At this point it is really unclear what methods in varinfo.jl are needed, useful or should be removed. That even motivated a complete refactor and rewrite but it is still messy.

So my suggestion would be

  • add a MWE to the tests that is currently failing
  • and add only the missing definitions that make the test pass.

> Did you actually try to call tilde_assume!! directly?

It would be interesting to know if that can be reproduced with a regular @model as well, or if there is some problem with how tilde_assume!! was called.

Member

@yebai yebai left a comment

It looks sensible to me. I agree with @devmotion that we want to be careful about introducing additional types and functions into DynamicPPL in general. It does seem that this PR only adds an internal type that fixes some known issues.

For the future, we probably want to move distribution_wrappers into src/contrib so it is clear they are not part of the official DynamicPPL API.

@devmotion
Member

Can we add at least tests for every new function and type and fix the CI errors?

And I think it would be nice to see as well what actually went wrong and what has to be fixed.

@devmotion
Member

Oh it seems maybe @torfjelde has already fixed the problems in 0f9765b?

@alyst
Contributor Author

alyst commented Jun 27, 2022

> Oh it seems maybe @torfjelde has already fixed the problems in ...

It doesn't define the bijector for NoDist though.

I've added an MWE to the tests.

Comment on lines +74 to +77
x ~ NoDist(Product(fill(Uniform(-20, 20), 5)))
for i in eachindex(x)
    x[i] ~ Normal(0, 1)
end
Member

This seems quite surprising, I have never seen anyone using NoDist in a model. I'm also not sure, why would you want to do that? When would such a model as the example here be useful?

Contributor Author

> This seems quite surprising, I have never seen anyone using NoDist in a model. I'm also not sure, why would you want to do that? When would such a model as the example here be useful?

a) This is an MWE.
b) In the real use case the variable has ~500 elements. When I use x[i] ~ ... (or dot_tilde_assume()), profiling indicates that, with the current state of DynamicPPL, ~50% of the time is spent on indexing individual elements. That's why I switched to a multivariate distribution, which resolves the indexing overhead.
c) In the real use case the prior is logpdf.(Ref(Normal(mean(x), sigma)), x) |> sum |> addlogp!!, so NoDist helps to declare x and its domain (see also d).
d) In the real use case I'm switching between evolutionary programming (BlackBoxOptim.jl) and gradient-based methods to get the MAP estimates. So while the model allows an alternative parametrization, e.g. xmean ~ Normal(0, 1), xdelta .~ Normal(0, sigma), x = xmean .+ xdelta, it would be suboptimal for crossover operations; it would also introduce one extra degree of freedom.
e) I appreciate your concerns regarding the usefulness of the MWE, but I think the problem of wrapped distributions not supporting all of the necessary Distributions.jl API is real, and the tests do cover that.

Member

NoDist is an internal workaround/implementation detail and, like NamedDist, it is not a "proper" user-facing distribution. Therefore it was never meant to be used in a model directly, and it is neither tested nor implemented to support such use cases.

More generally, your workarounds and use of internal functionality (also addlogp!! is somewhat internal, the user-facing alternative is @addlogprob! which is still somewhat dangerous - IIRC in some cases it leads to incorrect or at least surprising results) make me wonder if there is some other functionality missing or some part of DynamicPPL that should be changed. I don't think the best solution is to start promoting and supporting such workarounds but rather we should better support the actual use cases and models in the first place. I think ideally you just implement your model in the most natural way and it works.

One thing is still not clear to me (also in your real use case): Why do you want to declare x with a NoDist?

Contributor Author

> rather we should better support the actual use cases and models in the first place

I guess what I'm trying to achieve here with NoDist() is to declare x first, and define its prior later.

> Why do you want to declare x with a NoDist?

It's not necessary, but I wanted to avoid calculating Uniform priors, both for performance and for having meaningful probabilities.

Member

> I guess what I'm trying to achieve here with NoDist() is to declare x first, and define its prior later.

But what I don't understand is why you add a statement with NoDist first. You could just provide x as data to the model (if it is not sampled) or sample it from the actual priors (and here just preallocate the array first).

Having different statements for x where one is basically wrong seems a bit strange.

> It's not necessary, but I wanted to avoid calculating Uniform priors, both for performance and for having meaningful probabilities.

But if x has a uniform prior, you should use it properly, shouldn't you? If you don't want to include the prior in your log density calculations you could condition on x or only evaluate the loglikelihood (you can even just do it for a subset of parameters).

Member

> Thanks! Would it work properly if I declare a truncated(Flat(), a, b) distribution?

Yeah that should work.

Member

> @devmotion I'm a bit confused as to whether you're saying that the fact that @alyst has to do this to achieve the desired performance is undesirable, or whether you're suggesting that he can achieve the same performance by writing it in a for-loop and pre-allocating? Because if you're saying the former, I think we're all on the same page.

Yes, I meant that it's undesirable that apparently workarounds such as two tilde statements for the same variable are needed to achieve performance.

Maybe we should add an official way for declaring a variable in the model (i.e., registering it without distribution)? Possibly an official macro (similar to @addlogprob!) that would then make sure that it ends up in the variable structure. I just don't know how it would be implemented exactly. Maybe it would be easiest to only support SimpleVarInfo? I assume it could be useful in cases where you would like to loop but don't want to end up with n different variables x[1], ..., x[n] in the resulting named tuple or dictionary. Alternatively, maybe we could add something like a (arguably also a bit hacky) For/Map distribution that would allow one to write something like

@model function ...
    ...
    x ~ For(1:n) do i
        f(i)
    end
    ...
end

The main difference to the existing possibilities would be that 1) it does not require preallocating an array etc. (such as .~), 2) it does not create n different variables x[1], ..., x[n] (such as a regular for loop), 3) it does not require allocating an array of distributions (such as arraydist/product_distribution) but only create the individual distributions on the fly.

Maybe the better approach would be to not introduce a new distribution but just support something like arraydist(f, xs).

I guess one of the main challenges would be to figure out what the type of arraydist(f, xs) should be. I assume that, in general, it might not be possible to infer whether it is a MultivariateDistribution, a MatrixDistribution, etc.
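To make the idea of creating the individual distributions on the fly concrete, here is a rough sketch with toy types. ToyNormal, LazyProduct and toy_logpdf are hypothetical names made up for this illustration; nothing here is an existing arraydist(f, xs) method or part of Distributions.jl, and the sketch sidesteps the type-inference question by only handling the vector case.

```julia
# Toy univariate normal (hypothetical stand-in, not Distributions.Normal).
struct ToyNormal
    mu::Float64
    sigma::Float64
end
toy_logpdf(d::ToyNormal, x::Real) =
    -0.5 * ((x - d.mu) / d.sigma)^2 - log(d.sigma) - 0.5 * log(2π)

# Stores the constructor `f` and the index set, not an array of distributions.
struct LazyProduct{F,I}
    f::F
    xs::I
end
Base.length(d::LazyProduct) = length(d.xs)

# Each component distribution is built when needed and then discarded.
function toy_logpdf(d::LazyProduct, x::AbstractVector)
    length(x) == length(d) || throw(DimensionMismatch("length mismatch"))
    return sum(toy_logpdf(d.f(i), xi) for (i, xi) in zip(d.xs, x))
end

d = LazyProduct(i -> ToyNormal(float(i), 1.0), 1:3)
lp = toy_logpdf(d, [1.0, 2.0, 3.0])   # every component is evaluated at its own mean
```

Under these assumptions x stays a single vector-valued variable, and no array of component distributions is ever allocated.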

Contributor Author

@alyst alyst Jun 28, 2022

> Would it work properly if I declare a truncated(Flat(), a, b) distribution?

> Yeah that should work.

Actually, Flat() doesn't define cdf(), which is required for truncated(). But even if we define cdf(d::Flat, x) = one(x), then P(a <= d <= b) would be zero. So it would trigger an error in truncated(), and most likely in many other places.
One can define the new FlatBounded(a, b) pseudodistribution, but it looks very similar to NoDist(Uniform(a, b)) to me (except the transformation).
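The arithmetic behind the zero-mass problem can be checked with a small stand-alone sketch. ToyFlat, toy_cdf and interval_mass are hypothetical names used only here; they are not Turing or Distributions.jl definitions, and the sketch only mirrors the general fact that truncation normalizes by the probability mass of [a, b].

```julia
struct ToyFlat end                      # hypothetical improper flat "distribution"
toy_cdf(::ToyFlat, x::Real) = one(x)    # the cdf definition considered above

# Probability mass of [a, b]; truncation divides the density by this quantity.
interval_mass(d, a, b) = toy_cdf(d, b) - toy_cdf(d, a)

mass = interval_mass(ToyFlat(), -20.0, 20.0)   # 1.0 - 1.0
log_normalizer = log(mass)                      # log of a zero mass
```

With a zero mass the normalizing constant is degenerate, so any code that divides by it (or takes its logarithm) cannot produce a usable density.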

Member

> Actually, Flat() doesn't define cdf(), which is required for truncated(). But even if we define cdf(d::Flat, x) = one(x), then P(a <= d <= b) would be zero. So it would trigger an error in truncated(), and most likely in many other places.
> One can define the new FlatBounded(a, b) pseudodistribution, but it looks very similar to NoDist(Uniform(a, b)) to me (except the transformation).

Ah I guess this is why we have the FlatPositive rather than just using truncated. But yes, it ends up being very similar to NoDist but not quite: the logpdf_with_transform is going to be different. For NoDist we want no correction, but for something like FlatPositive we do want the correction.
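As a rough illustration of that difference, consider a positive variable x represented in unconstrained space by y = log(x). The function names below are hypothetical and chosen only for this sketch; they are not the Bijectors or DynamicPPL implementations.

```julia
# A constrained value x > 0 is represented by the unconstrained y = log(x).
to_constrained(y) = exp(y)
logabsdetjac(y) = y           # log|d exp(y) / dy| = log(exp(y)) = y

flat_logpdf(x) = 0.0          # flat, unnormalized log density on x > 0

# "FlatPositive-like": the Jacobian term is included in unconstrained space.
logp_with_correction(y) = flat_logpdf(to_constrained(y)) + logabsdetjac(y)

# "NoDist-like": the variable contributes nothing, in either space.
logp_without_correction(y) = 0.0

y = log(2.5)
with_corr = logp_with_correction(y)        # equals y, i.e. log(2.5)
without_corr = logp_without_correction(y)  # stays at zero
```

So even though both are "flat" in the constrained space, the two give different log densities once the variable is transformed.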

Member

So I've added bijector for NoDist in #415 now because it's useful for the new getindex(vi, vn, dist) methods introduced (also found a pretty significant bug when combining NoDist + transformed VarInfo) 👍

But, as I said previously, this will produce different results than something like FlatPositive which will, unlike NoDist, also include the log-absdet-jacobian correction.

@torfjelde
Member

torfjelde commented Jun 28, 2022

> It doesn't define the bijector for NoDist though.

I deliberately didn't do this because I'm uncertain if we ever want to hit this. NoDist should represent "don't do anything with this variable", but if we at some point hit bijector(nodist), then this indicates that we might be trying to compute the logabsdetjac correction, which actually shouldn't be included in the log-joint computation 😕

So are we certain adding this implementation isn't doing something silently incorrect?

EDIT: See #414 (comment)

@devmotion
Member

I was just looking at https://github.com/TuringLang/Turing.jl/blob/master/src/stdlib/distributions.jl for completely unrelated reasons, and discovered some definitions of Bijectors.logpdf_with_trans(::NoDist, x, t) 😮

Regardless of whether they are useful etc., this seems like one of the worst places to hide them 😄

@torfjelde
Member

> I was just looking at https://github.com/TuringLang/Turing.jl/blob/master/src/stdlib/distributions.jl for completely unrelated reasons, and discovered some definitions of Bijectors.logpdf_with_trans(::NoDist, x, t) 😮

> Regardless of whether they are useful etc., this seems like one of the worst places to hide them 😄

Those shouldn't be there 😳

@alyst
Contributor Author

alyst commented Oct 1, 2022

bors try

@bors
Contributor

bors bot commented Oct 1, 2022

🔒 Permission denied

@ParadaCarleton
Member

bors try

bors bot added a commit that referenced this pull request Dec 19, 2022
@bors
Contributor

bors bot commented Dec 19, 2022

try

Build failed:

@ParadaCarleton
Member

bors try

bors bot added a commit that referenced this pull request Dec 19, 2022
@ParadaCarleton
Member

@alyst Very sorry for the delay; looks like tests aren't passing ATM.

@bors
Contributor

bors bot commented Dec 19, 2022

try

Build failed:

@devmotion
Member

Maybe I missed something (haven't checked this PR for a while) but I think @torfjelde's and my concerns above are still valid?

@alyst alyst force-pushed the enhance_wrapped_distr branch 2 times, most recently from dd7b80b to 0fbe51f Compare March 21, 2023 19:31
@yebai yebai mentioned this pull request Oct 24, 2024
2 tasks
5 participants