Enhance wrapped distributions #414
Open: alyst wants to merge 10 commits into TuringLang:master from alyst:enhance_wrapped_distr
Commits:

- `f90056b` enhance wrapped distributions (alyst)
- `7f79c23` distr_wrappers: add tests for multivariate distrs (alyst)
- `bcca942` add tests for model with multivariate NoDist (alyst)
- `ec07ed5` fix commented out tests (alyst)
- `e8710b3` fix reviewdog formatting issues (alyst)
- `2976bbf` 2nd round of reviewdog fixes (alyst)
- `ece33c6` refer WrappedDist and NoDist from API docs (alyst)
- `490257a` export WrappedDist to make docs happy (alyst)
- `71f3304` 3rd round of trying to make the format doggy happy (alyst)
- `017fa67` Merge branch 'master' into enhance_wrapped_distr (yebai)
    @@ -1,13 +1,33 @@
     @testset "distribution_wrappers.jl" begin
    -    d = Normal()
    -    nd = DynamicPPL.NoDist(d)
    +    @testset "univariate" begin
    +        d = Normal()
    +        nd = DynamicPPL.NoDist(d)
     
    -    # Smoke test
    -    rand(nd)
    +        # Smoke test
    +        rand(nd)
     
    -    # Actual tests
    -    @test minimum(nd) == -Inf
    -    @test maximum(nd) == Inf
    -    @test logpdf(nd, 15.0) == 0
    -    @test Bijectors.logpdf_with_trans(nd, 30, true) == 0
    +        # Actual tests
    +        @test minimum(nd) == -Inf
    +        @test maximum(nd) == Inf
    +        @test logpdf(nd, 15.0) == 0
    +        @test Bijectors.logpdf_with_trans(nd, 30, true) == 0
    +        @test Bijectors.bijector(nd) == Bijectors.bijector(d)
    +    end
    +
    +    @testset "multivariate" begin
    +        d = Product([Normal(), Uniform()])
    +        nd = DynamicPPL.NoDist(d)
    +
    +        # Smoke test
    +        @test length(rand(nd)) == 2
    +
    +        # Actual tests
    +        @test length(nd) == 2
    +        @test size(nd) == (2,)
    +        @test minimum(nd) == [-Inf, 0.0]
    +        @test maximum(nd) == [Inf, 1.0]
    +        @test logpdf(nd, [15.0, 0.5]) == 0
    +        @test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0
    +        @test Bijectors.bijector(nd) == Bijectors.bijector(d)
    +    end
     end
This seems quite surprising, I have never seen anyone using `NoDist` in a model. I'm also not sure why you would want to do that. When would a model such as the example here be useful?
a) This is a MWE.

b) In the real use case the variable has ~500 elements. When I'm using `x[i] ~ ...` (or `dot_tilde_assume()`), profiling indicates that with the current state of DynamicPPL ~50% of the time is spent on indexing individual elements. That's why I've switched to a multivariate distribution, which resolves the indexing overhead.

c) In the real use case the prior is `logpdf.(Ref(Normal(mean(x), sigma)), x) |> sum |> addlogp!!`, so `NoDist` helps to declare `x` and its domain (also see d).

d) In the real use case I'm switching between evolutionary programming (BlackBoxOptim.jl) and gradient-based methods to get the MAP estimates. So while the model allows an alternative parametrization, e.g. `xmean ~ Normal(0, 1)`, `xdelta .~ Normal(0, sigma)`, `x = xmean .+ xdelta`, it would be suboptimal for crossover operations; it would also introduce one extra degree of freedom.

e) I appreciate your concerns regarding the usability of the MWE, but I think the problem of wrapped distributions not supporting all of the necessary Distributions.jl API is there, and the tests do cover that.
`NoDist` is an internal workaround/implementation detail but, like `NamedDist`, it's not a "proper" user-facing distribution. Therefore it was not supposed to be used in a model directly, and it was not tested or implemented to support such use cases.

More generally, your workarounds and use of internal functionality (`addlogp!!` is also somewhat internal; the user-facing alternative is `@addlogprob!`, which is still somewhat dangerous: IIRC in some cases it leads to incorrect or at least surprising results) make me wonder if there is some other functionality missing or some part of DynamicPPL that should be changed. I don't think the best solution is to start promoting and supporting such workarounds; rather, we should better support the actual use cases and models in the first place. I think ideally you just implement your model in the most natural way and it works.

One thing is still not clear to me (also in your real use case): why do you want to declare `x` with a `NoDist`?
I guess what I'm trying to achieve here with `NoDist()` is to declare `x` first, and define its prior later.

It's not necessary, but I wanted to avoid calculating Uniform priors, both for performance and for having meaningful probabilities.
But what I don't understand is why do you add a statement with NoDist first? You could just provide x as data to the model (if it is not sampled) or sample it from the actual priors (and here just preallocate the array first).
Having different statements for x where one is basically wrong seems a bit strange.
But if x has a uniform prior, you should use it properly, shouldn't you? If you don't want to include the prior in your log density calculations you could condition on x or only evaluate the loglikelihood (you can even just do it for a subset of parameters).
Yeah that should work.
Yes, I meant that it's undesirable that apparently workarounds such as two tilde statements for the same variable are needed to achieve performance.

Maybe we should add an official way of declaring a variable in the model (i.e., registering it without a distribution)? Possibly an official macro (similar to `@addlogprob!`) that would then make sure that it ends up in the variable structure. I just don't know how it would be implemented exactly. Maybe it would be easiest to only support SimpleVarInfo? I assume it could be useful in cases where you would like to loop but don't want to end up with n different variables x[1], ..., x[n] in the resulting named tuple or dictionary. Alternatively, maybe we could add something like a (arguably also a bit hacky) `For`/`Map` distribution.

The main difference to the existing possibilities would be that 1) it does not require preallocating an array etc. (such as `.~`), 2) it does not create `n` different variables `x[1]`, ..., `x[n]` (such as a regular `for` loop), and 3) it does not require allocating an array of distributions (such as `arraydist`/`product_distribution`) but only creates the individual distributions on the fly.

Maybe the better approach would be to not introduce a new distribution but just support something like `arraydist(f, xs)`.

I guess one of the main challenges would be to figure out what the type of `arraydist(f, xs)` should be. I assume it might not be possible to infer in general whether it is a `MultivariateDistribution`, `MatrixDistribution`, etc.
Actually, `Flat()` doesn't define `cdf()`, which is required for `truncated()`. But even if we define `cdf(d::Flat, x) = one(x)`, then `P(a <= d <= b)` would be zero. So it would trigger an error in `truncated()`, and most likely in many other places.

One can define a new `FlatBounded(a, b)` pseudodistribution, but it looks very similar to `NoDist(Uniform(a, b))` to me (except for the transformation).
Ah I guess this is why we have the `FlatPositive` rather than just using `truncated`. But yes, it ends up being very similar to `NoDist` but not quite: the `logpdf_with_trans` is going to be different. For `NoDist` we want no correction, but for something like `FlatPositive` we do want the correction.
So I've added `bijector` for `NoDist` in #415 now because it's useful for the new `getindex(vi, vn, dist)` methods introduced (also found a pretty significant bug when combining `NoDist` + transformed `VarInfo`) 👍

But, as I said previously, this will produce different results than something like `FlatPositive`, which will, unlike `NoDist`, also include the log-absdet-jacobian correction.
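The difference can be sketched in plain Julia for a positive variable `x` that is sampled in the unconstrained space as `y = log(x)`. All function names below are hypothetical and chosen for illustration; this is not the DynamicPPL or Turing implementation. It assumes a flat log density of 0 on the constrained space and the standard change-of-variables term `log|dx/dy| = y` for `x = exp(y)`.

```julia
# Unconstrained y maps back to the positive value x = exp(y).
to_constrained(y::Real) = exp(y)

# log|dx/dy| for x = exp(y) is log(exp(y)) = y.
logabsdetjac_inverse(y::Real) = y

# Flat log density on the constrained space (shared by both sketches).
flat_logpdf(x::Real) = zero(float(x))

# NoDist-like behaviour: contributes nothing, even in the transformed space.
nodist_like_logp(y::Real) = flat_logpdf(to_constrained(y))

# FlatPositive-like behaviour: flat density plus the Jacobian correction.
flatpositive_like_logp(y::Real) =
    flat_logpdf(to_constrained(y)) + logabsdetjac_inverse(y)

y = 2.0
println(nodist_like_logp(y))        # 0.0
println(flatpositive_like_logp(y))  # 2.0
```

Both variants share the same support and the same transformation; they differ only in whether the Jacobian term enters the accumulated log density.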