Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Linearization/flattening of SimpleVarInfo #417

Closed
wants to merge 245 commits into from
Closed
Show file tree
Hide file tree
Changes from 173 commits
Commits
Show all changes
245 commits
Select commit Hold shift + click to select a range
14594d6
performing linking in assume rather than implicitly in getindex
torfjelde Jan 7, 2022
0bc279f
added istrans to SimpleVarInfo
torfjelde Jan 7, 2022
81ee12e
Apply suggestions from code review
torfjelde Jan 7, 2022
d39f87d
added a comment
torfjelde Jan 7, 2022
23f34cc
bump patch version
torfjelde Jan 7, 2022
d3ec108
introduced settrans!!
torfjelde Jan 9, 2022
81782c9
added istrans(vi) and renamed all occurences of trans! to trans!!
torfjelde Jan 9, 2022
12bfb42
exclusively use settrans!! to set the istrans for SimpleVarInfo
torfjelde Jan 10, 2022
c2c2417
removed usage of deprecated method in turing tests
torfjelde Jan 10, 2022
2e2cb5c
added docstring to settrans!!
torfjelde Jan 13, 2022
f6c3fc4
include istrans flag in type of SimpleVarInfo instead
torfjelde Jan 27, 2022
d643a78
deprecated settrans! in favour of settrans!!
torfjelde Jan 27, 2022
3cab7d9
added some tests specifically for istrans
torfjelde Jan 27, 2022
8b870dc
formatting
torfjelde Jan 27, 2022
0b304db
fixed bugs for ThreadSafeVarInfo
torfjelde Jan 27, 2022
b146b11
additional constructor for SimpleVarInfo
torfjelde Jan 27, 2022
d170d92
Update src/DynamicPPL.jl
torfjelde Feb 9, 2022
f46183b
added ConstructionBase.jl as dep
torfjelde Feb 9, 2022
7a78eec
added constraint types and doctests
torfjelde Feb 9, 2022
70b3b70
added DocStringExtensions as a dep
torfjelde Feb 9, 2022
a03e8cf
formatting
torfjelde Feb 9, 2022
793c931
remove redundant maybe_link
torfjelde Feb 9, 2022
a9b12fd
fixed typo
torfjelde Feb 9, 2022
b1d7f9a
Merge branch 'master' into tor/link-improvements
yebai Feb 9, 2022
be98961
moved a docstring
torfjelde Feb 11, 2022
3e1588b
fixed bug in tets
torfjelde Feb 11, 2022
7697fce
Merge branch 'tor/link-improvements' of github.com:TuringLang/Dynamic…
torfjelde Feb 11, 2022
3139c62
version bump
torfjelde Feb 11, 2022
3610658
added missing istrans impl
torfjelde Feb 12, 2022
27171ad
fixed bug with istrans
torfjelde Feb 13, 2022
cd2d9d6
fixed issue with getindex_raw for VarInfo
torfjelde Feb 13, 2022
d948cb9
Update src/varinfo.jl
torfjelde Feb 13, 2022
6a3e18f
Merge branch 'master' into tor/link-improvements
torfjelde Jun 10, 2022
d674478
Merge branch 'master' into tor/link-improvements
torfjelde Jun 17, 2022
26d2dbb
getindex of varinfo implementations now optionally takes a Distributi…
torfjelde Jun 22, 2022
3fcba56
use get_index_raw with dist argument
torfjelde Jun 22, 2022
83a9448
added missing assume implementations for SimpleVarInfo
torfjelde Jun 22, 2022
356fa9c
fixed settrans!! for VarInfo
torfjelde Jun 22, 2022
13f037f
formatting
torfjelde Jun 24, 2022
c7544e0
fixed bug where constrained/unconstrained wasn't preserved in setinde…
torfjelde Jun 24, 2022
d1dccf1
hack to avoid type-instabilities for dot_assume with MultivariateDist…
torfjelde Jun 24, 2022
ff7ff4a
style
torfjelde Jun 26, 2022
2f1a2ff
added keys implementations for the models in TestUtils to make testin…
torfjelde Jun 26, 2022
d6311b7
added additional test model which uses dot-assume on MultivariateDist…
torfjelde Jun 26, 2022
ed2fa69
updated tests for SimpleVarInfo
torfjelde Jun 26, 2022
a82be56
added a no-op reconstruct for UnivariateDistribution
torfjelde Jun 26, 2022
7aacee5
fixed tests for loglikelihoods
torfjelde Jun 27, 2022
96f128f
fixed dot_tilde_assume for LikelihoodContext
torfjelde Jun 27, 2022
2e88d08
removed some now redundant explicit calls to maybe_invlink
torfjelde Jun 27, 2022
0f9765b
added impls of size and length for the wrapper distributions so they …
torfjelde Jun 27, 2022
116c95c
bumped version
torfjelde Jun 28, 2022
d797e99
removed redunant explict call to maybe_invlink
torfjelde Jun 28, 2022
44b2f66
added test model with array on RHS of a .~ statement
torfjelde Jun 29, 2022
81cd881
improved some of the default implementations of dot_assume
torfjelde Jun 29, 2022
2e14abd
removed unnecessary code in tests
torfjelde Jun 29, 2022
12adc83
improved linking usage in assumes for SimpleVarInfo
torfjelde Jun 29, 2022
af3e6ba
Merge branch 'master' into tor/link-improvements
yebai Jun 29, 2022
f7501df
added model for testing dynamic constraints
torfjelde Jun 30, 2022
abcabf4
added logjoint_true_with_logabsdet_jacobian to TestUtils
torfjelde Jun 30, 2022
fdee509
added test for dynamic constraints for SimpleVarInfo
torfjelde Jun 30, 2022
e974c83
fixed keys implementation of SimpleVarInfo
torfjelde Jun 30, 2022
6c6d5f5
reverted unintended change
torfjelde Jun 30, 2022
5d5bc88
added example_values and posterior_mean_values methods to models in T…
torfjelde Jun 30, 2022
0498336
demo models in TestUtils are now a bit more complex, including constr…
torfjelde Jun 30, 2022
f86f264
added logprior_true_with_logabsdet_jacobian for demo models
torfjelde Jul 1, 2022
0d31137
fixed mistakes in a couple of models in TestUtils
torfjelde Jul 1, 2022
c52630b
moved varnames method which creates iterator of leaf varnames into Te…
torfjelde Jul 1, 2022
fff060c
updated docstring for test_sampler_demo_models
torfjelde Jul 1, 2022
e21958c
renamed varnames to varname_leaves and renamed keys(model) to varname…
torfjelde Jul 1, 2022
9669345
added test_sampler_on_models as a generalization of test_sampler_demo…
torfjelde Jul 1, 2022
7e02735
updated docs
torfjelde Jul 1, 2022
a412029
added docs for TestUtils.DEMO_MODELS
torfjelde Jul 1, 2022
f3818c3
updated some tests
torfjelde Jul 1, 2022
8b799a4
fixed docstrings
torfjelde Jul 1, 2022
93cb298
fixed docstrings
torfjelde Jul 1, 2022
ba5852b
imprvoed docstring
torfjelde Jul 1, 2022
328f713
improved docstrings
torfjelde Jul 1, 2022
801bd4c
renamed Base.keys(model) to varnames(model) in TestUtils
torfjelde Jul 1, 2022
46f6f4c
added default implementation and docstring for TestUtils.varnames
torfjelde Jul 1, 2022
bcb767b
replace handwritten by DocStringExtensions
torfjelde Jul 1, 2022
c5be1c2
Apply suggestions from @devmotion
torfjelde Jul 1, 2022
f266929
Update src/context_implementations.jl
torfjelde Jul 1, 2022
c2dbbaf
removed some asserts and use broadcast instead of map
torfjelde Jul 1, 2022
1abb46c
replace map with broadcasting to ensure consistent behavior
torfjelde Jul 1, 2022
1086c6c
Update src/simple_varinfo.jl
torfjelde Jul 1, 2022
f2fb4a5
added a method nodist to allow broadcasting NoDist constructor
torfjelde Jul 1, 2022
490d24e
updated some tests
torfjelde Jul 1, 2022
6350ccd
renamed AbstractConstraint to AbstractTransformation and its subtypes
torfjelde Jul 1, 2022
951e4c3
updated tests
torfjelde Jul 1, 2022
dcd92c9
fixed nodist usage
torfjelde Jul 1, 2022
2922ffa
fixed implementation of nodist
torfjelde Jul 1, 2022
5266a4b
fixed typo
torfjelde Jul 1, 2022
3c38710
formatting
torfjelde Jul 1, 2022
ba92f3f
bump patch version
torfjelde Jul 1, 2022
70c864c
fixed ThreadsafeVarInfo
torfjelde Jul 1, 2022
8b6b440
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 1, 2022
5843699
fixed tests of pointwise_loglikelihoods
torfjelde Jul 1, 2022
66f41a9
Apply suggestions from code review
torfjelde Jul 1, 2022
eb2d6b5
allow type-stable settrans!! for SimpleVarInfo
torfjelde Jul 1, 2022
e8cdb91
use maybe_invlink in getindex for VarInfo
torfjelde Jul 1, 2022
359d384
added comment to warn about buggy behavior
torfjelde Jul 1, 2022
0b20f09
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 1, 2022
ab0a99b
Update src/context_implementations.jl
torfjelde Jul 1, 2022
dd10913
just fix potential bug in getindex for VarInfo
torfjelde Jul 1, 2022
18d28cc
revert previous change because it likely introduces bugs
torfjelde Jul 1, 2022
32b7aab
elaborate in comment regarding potential bug
torfjelde Jul 1, 2022
fb86231
Merge branch 'tor/link-improvements' of github.com:TuringLang/Dynamic…
torfjelde Jul 1, 2022
f782fe2
added error message to dot_assume
torfjelde Jul 1, 2022
7d3493d
added error message to dot_assume again
torfjelde Jul 1, 2022
2b1893c
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 1, 2022
912d7f8
Apply suggestions from code review
torfjelde Jul 2, 2022
a276e4a
renamed posterior_mean_values to posterior_mean
torfjelde Jul 2, 2022
626eea2
made demo models a bit more complex, now including different observat…
torfjelde Jul 2, 2022
1558924
Update docs/src/api.md
torfjelde Jul 2, 2022
a62c881
reduce number of method definitions by defining some useful type unio…
torfjelde Jul 2, 2022
5cc195a
removed unnecessary method
torfjelde Jul 2, 2022
ea5a7a4
Merge branch 'tor/test-utils-improvements' of github.com:TuringLang/D…
torfjelde Jul 2, 2022
702f2ff
fixed a couple of loglikelihood_true definitions
torfjelde Jul 2, 2022
d8f4970
style
torfjelde Jul 2, 2022
56f30bc
added tests for logprior and loglikelihood computation for SimpleVarInfo
torfjelde Jul 2, 2022
2eaef02
fixed implementation of logpdf_with_trans for NoDist
torfjelde Jul 2, 2022
f0f981b
added _protect_dists method to help with broadcasting of NoDist
torfjelde Jul 2, 2022
8063d1e
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 2, 2022
1e0b946
simplified show for SimpleVarInfo
torfjelde Jul 2, 2022
faa0e42
styling
torfjelde Jul 2, 2022
2935bde
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 2, 2022
78f22e1
removed unused variable
torfjelde Jul 3, 2022
025a4d4
added test for transformed values for the logprior_true and loglikeli…
torfjelde Jul 3, 2022
9e7f493
fixed bug in show for SimpleVarInfo
torfjelde Jul 3, 2022
a72e9b8
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 3, 2022
0a9383b
Revert "added _protect_dists method to help with broadcasting of NoDist"
torfjelde Jul 3, 2022
c057080
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 3, 2022
f5c60ae
renamed test_sampler_on_models to test_sampler
torfjelde Jul 3, 2022
d8b0a75
fixed getindex with vector of varnames for AbstractVarInfo
torfjelde Jul 3, 2022
7149c02
Merge branch 'tor/link-improvements' into tor/test-utils-improvements
torfjelde Jul 3, 2022
25f05de
updated docs
torfjelde Jul 3, 2022
e05fa29
share implementation of example_values
torfjelde Jul 3, 2022
431664d
Apply suggestions from code review
torfjelde Jul 4, 2022
363ebae
added unflatten and values_as for Vector
torfjelde Jul 6, 2022
66424f8
added getindex for AbstractVarInfo with Colon
torfjelde Jul 6, 2022
3bc27f8
added unflatten to VarInfo
torfjelde Jul 7, 2022
ca5b080
added make_default_varinfo allowing specification of how to initialize
torfjelde Jul 7, 2022
cb05fc9
added unflatten also taking sampler for SimpleVarInfo
torfjelde Jul 10, 2022
45445cf
added tonamedtuple impl for SimpleVarInfo
torfjelde Jul 10, 2022
ea8f844
fixed implementation of unflatten for arrays
torfjelde Jul 10, 2022
52274ba
added default impl of unflatten taking sampler
torfjelde Jul 10, 2022
7da0ee9
improved tonamedtuple for SimpleVarInfo with Dict
torfjelde Jul 12, 2022
b3499a3
added marginal_mean_of_samples according to suggestions
torfjelde Jul 12, 2022
2bd5dcd
removed example_values in favour of rand with NamedTuple
torfjelde Jul 13, 2022
61a594c
updated docs
torfjelde Jul 13, 2022
ce5f6e4
Merge branch 'tor/test-utils-improvements' into tor/simple-varinfo-li…
torfjelde Jul 13, 2022
6c941bd
fixed method ambiguity error
torfjelde Jul 18, 2022
5e92e56
added islinked for SimpleVarInfo
torfjelde Jul 18, 2022
aabc45a
formatting
torfjelde Jul 18, 2022
939540c
added link!! and invlink!! as BangBang alternatives to link! and invl…
torfjelde Jul 18, 2022
aecf97f
added specialized implementation for NamedBijector and SimpleVarInfo
torfjelde Jul 18, 2022
9dcefdb
use inverse instead of inv
torfjelde Jul 19, 2022
fd0796b
preserve DefaultTransformation
torfjelde Jul 19, 2022
2ef1f59
Merge branch 'master' into tor/simple-varinfo-linearization
torfjelde Jul 23, 2022
94e5d48
removed duplicated defs
torfjelde Jul 23, 2022
15fdf19
style
torfjelde Jul 23, 2022
48dfb9c
fixed empty!! and added isempty for SimpleVarInfo
torfjelde Jul 23, 2022
70ba82d
added setindex!! for sampler with SimpleVarInfo
torfjelde Jul 23, 2022
d3bff26
made values_as compatible with empty SimpleVarInfo
torfjelde Jul 23, 2022
b14e9cf
added tests for base functionality for SimpleVarInfo too
torfjelde Jul 23, 2022
9f106fa
renamed bijectors.jl to transforming.jl
torfjelde Jul 23, 2022
bf34356
fixed update of logp after initialize_parameters!!
torfjelde Jul 23, 2022
3e5f763
remove now-redundant todo
torfjelde Jul 23, 2022
e649f37
improved the initial step
torfjelde Jul 23, 2022
5941270
fixed bug with initialize_parameters!! introduced in previous commit
torfjelde Jul 23, 2022
f30b875
Update src/sampler.jl
torfjelde Jul 23, 2022
cb3e1f4
add some comments on tonamedtuple
torfjelde Jul 23, 2022
f79fab4
Apply suggestions from code review
torfjelde Jul 24, 2022
58c2550
Update src/transforming.jl
torfjelde Jul 24, 2022
c316e70
renamed make_default_varinfo to default_varinfo
torfjelde Jul 24, 2022
998fcf4
simplified impls of getindex
torfjelde Jul 24, 2022
0913a24
Apply suggestions from code review
torfjelde Jul 24, 2022
9af2638
made impls of default getindex for VarInfo a bit more sensible
torfjelde Jul 24, 2022
482ade7
removed unnecessary namespace specification
torfjelde Jul 24, 2022
b79bf28
use isempty(vi) instead of checking its values
torfjelde Jul 24, 2022
15087c5
fix values_as for certain combinations
torfjelde Jul 24, 2022
88dbdca
added deprecation warnings for link! and invlink!
torfjelde Jul 25, 2022
9b3c40f
add logabsdet-jacobian term in link! and invlink!
torfjelde Jul 27, 2022
57d321c
use context to implement link!! and invlink!!
torfjelde Jul 27, 2022
4409149
added tests for link!! and invlink!!
torfjelde Jul 27, 2022
c34f257
added a note comment
torfjelde Jul 27, 2022
73765e7
renamed DefaultTransformation to LazyTransformation and
torfjelde Jul 27, 2022
0c0c393
added maybe_invlink_before_eval!! allowing invlinking once
torfjelde Jul 27, 2022
5e51755
formatting
torfjelde Jul 27, 2022
3dbc7a9
use OrderedDict instead of Dict for SimpleVarInfo as it preserves the
torfjelde Jul 28, 2022
809de9a
added compat entry for OrderedCollections
torfjelde Jul 29, 2022
656175f
added compat entry for OrderedCollections
torfjelde Jul 29, 2022
a225978
use OrderedDict instead of Dict for SimpleVarInfo as it preserves the
torfjelde Jul 28, 2022
fce67ee
improvements to values_as
torfjelde Jul 30, 2022
b165b35
export values_as
torfjelde Jul 30, 2022
af9c520
added values_as to docs
torfjelde Jul 30, 2022
47c30e3
added proper testing for values_as
torfjelde Jul 30, 2022
40477a4
bump patch version
torfjelde Jul 30, 2022
a972f8e
Apply suggestions from code review
torfjelde Jul 30, 2022
ab2a8b5
use ConstructionBase explicitly
torfjelde Aug 18, 2022
ee7fcd6
use OrderedDict in rand instead of NamedTuple as it supports arbitrar…
torfjelde Aug 18, 2022
a44e712
Merge branch 'tor/minor-varinfo-improvements' into tor/simple-varinfo…
torfjelde Aug 19, 2022
3e869ff
Merge branch 'master' into tor/simple-varinfo-linearization
torfjelde Sep 7, 2022
08de024
Merge branch 'master' into tor/simple-varinfo-linearization
torfjelde Sep 9, 2022
0082505
properly deprecate link! and invlink!
torfjelde Sep 9, 2022
597dfda
added transformation impls for ThreadSafeVarInfo
torfjelde Sep 9, 2022
be7ae6c
added missing impl of values_as for VarInfo and Vector
torfjelde Sep 9, 2022
8dfc7c0
use inverse instead of deprecated inv
torfjelde Sep 9, 2022
e475c87
Update src/transforming.jl
torfjelde Sep 9, 2022
1e6b0a9
renamed nested_haskey and defined common method called getvalue and
torfjelde Sep 9, 2022
ce5757a
Merge branch 'tor/simple-varinfo-linearization' of github.com:TuringL…
torfjelde Sep 9, 2022
e23763b
minor version bump
torfjelde Sep 9, 2022
43b034d
removed unnecessary and confusing constructor
torfjelde Oct 6, 2022
5c7df84
added TODO comment to deprecate
torfjelde Oct 19, 2022
5c7163d
renamed LazyTransformation to DynamicTransformation
torfjelde Oct 24, 2022
5f21dd7
updated docs and docstrings
torfjelde Oct 24, 2022
e044676
Update docs/make.jl
yebai Oct 25, 2022
effec2b
increased tolerance in one of the tests
torfjelde Oct 25, 2022
c7240ab
Merge branch 'tor/simple-varinfo-linearization' of github.com:TuringL…
torfjelde Oct 25, 2022
5fabd07
increase tolerance of tests
torfjelde Oct 25, 2022
2e0fe49
updated docs for varinfos, in particular the shared interface
torfjelde Oct 25, 2022
965fcf5
big refactoring of the varinfo related implementations and docs
torfjelde Oct 27, 2022
200a886
Update src/abstract_varinfo.jl
yebai Oct 27, 2022
5a0296e
Update src/abstract_varinfo.jl
yebai Oct 27, 2022
a2e332e
Update src/varinfo.jl
yebai Oct 27, 2022
d950635
Update varinfo.jl
yebai Oct 29, 2022
e0907c1
fixed bugs with linking
torfjelde Oct 31, 2022
6aceac4
Merge branch 'tor/simple-varinfo-linearization' of github.com:TuringL…
torfjelde Oct 31, 2022
e5d8984
fixed threadsafevarinfo issues
torfjelde Oct 31, 2022
63b3638
added tests for StaticBijector
torfjelde Nov 1, 2022
3ffeef1
added impl of maybe_invlink_before_eval!! for VarInfo
torfjelde Nov 1, 2022
8ed91ec
fixed bug in invlink!! for StaticBijector
torfjelde Nov 1, 2022
8ddfb4c
added maybe_invlink_before_eval!! impl for ThreadSafeVarInfo
torfjelde Nov 1, 2022
3246cf4
fixed bug in doctests
torfjelde Nov 1, 2022
74b5d93
relax constraint on istrans
torfjelde Nov 2, 2022
c6264e5
fixed unflatten for Dict to respect the original type
torfjelde Nov 2, 2022
da04c7b
suggest using OrderedDict instead of Dict in docstrings
torfjelde Nov 2, 2022
41fd89a
fixed doctest
torfjelde Nov 2, 2022
e7b8b10
added docs for unflatten for varinfos
torfjelde Nov 2, 2022
732e94b
Merge branch 'master' into tor/simple-varinfo-linearization
yebai Nov 2, 2022
1c1b6ed
added comment to explain settrans!! for VarInfo
torfjelde Nov 3, 2022
fbf9e0a
renamed MaybeThreadSafeVarInfo
torfjelde Nov 3, 2022
a427983
added comment on maybe_inlink_before_eval!!
torfjelde Nov 3, 2022
6b126b8
removed unnecessary defs in tests
torfjelde Nov 3, 2022
d715b0c
formatting
torfjelde Nov 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ export AbstractVarInfo,
setorder!,
istrans,
link!,
link!!,
invlink!,
invlink!!,
tonamedtuple,
# VarName (reexport from AbstractPPL)
VarName,
Expand Down Expand Up @@ -150,5 +152,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")

end # module
37 changes: 23 additions & 14 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ function AbstractMCMC.step(
return vi, nothing
end

function make_default_varinfo(
rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler
)
return make_default_varinfo(rng, model, sampler, DefaultContext())
end
function make_default_varinfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler,
context::AbstractContext,
)
init_sampler = initialsampler(sampler)
return VarInfo(rng, model, init_sampler, context)
end
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG,
Expand All @@ -82,23 +97,17 @@ function AbstractMCMC.step(
end

# Sample initial values.
_spl = initialsampler(spl)
vi = VarInfo(rng, model, _spl)
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if init_params !== nothing
vi = initialize_parameters!!(vi, init_params, spl)
vi = initialize_parameters!!(vi, init_params, spl, model)

# Update joint log probability.
# TODO: fix properly by using sampler and evaluation contexts
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# to avoid that existing variables are resampled
if _spl isa SampleFromUniform
model(rng, vi, SampleFromPrior())
else
model(rng, vi, _spl)
end
vi = last(evaluate!!(model, vi, DefaultContext()))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
Expand All @@ -121,7 +130,9 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler)
function initialize_parameters!!(
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
)
@debug "Using passed-in initial variable values" init_params

# Flatten parameters.
Expand All @@ -132,8 +143,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler)
# Get all values.
linked = islinked(vi, spl)
if linked
# TODO: Make work with immutable `vi`.
invlink!(vi, spl)
vi = invlink!!(vi, spl, model)
end
theta = vi[spl]
length(theta) == length(init_theta) ||
Expand All @@ -150,8 +160,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler)
# Update in `vi`.
vi = setindex!!(vi, theta, spl)
if linked
# TODO: Make work with immutable `vi`.
link!(vi, spl)
vi = link!!(vi, spl, model)
end

return vi
Expand Down
105 changes: 94 additions & 11 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
abstract type AbstractTransformation end

struct NoTransformation <: AbstractTransformation end
struct DefaultTransformation <: AbstractTransformation end

"""
$(TYPEDEF)

Expand Down Expand Up @@ -197,6 +192,8 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
transformation::C
end

transformation(vi::SimpleVarInfo) = vi.transformation

SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation())

function SimpleVarInfo{T}(θ) where {T<:Real}
Expand Down Expand Up @@ -227,9 +224,17 @@ function SimpleVarInfo{T}(
return SimpleVarInfo(values, convert(T, getlogp(vi)))
end

SimpleVarInfo(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? Seems a bit like introducing some of the surprising VarInfo constructors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, good catch!


unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x)
yebai marked this conversation as resolved.
Show resolved Hide resolved
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
return Setfield.@set svi.values = unflatten(svi.values, x)
end

function BangBang.empty!!(vi::SimpleVarInfo)
Setfield.@set resetlogp!!(vi).values = empty!!(vi.values)
return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values))
yebai marked this conversation as resolved.
Show resolved Hide resolved
end
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)

getlogp(vi::SimpleVarInfo) = vi.logp
setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp
Expand Down Expand Up @@ -308,11 +313,8 @@ end
# HACK: Needed to disambiguiate.
Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns)

Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values
Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values

# TODO: Should we do better?
Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values
Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector)
Base.getindex(svi::SimpleVarInfo, ::Sampler) = svi[:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be

Suggested change
Base.getindex(svi::SimpleVarInfo, ::Sampler) = svi[:]
Base.getindex(svi::SimpleVarInfo, ::AbstractSampler) = svi[:]

in line with setindex!!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause method ambiguity; there are definitions for ::AbstractVarInfo in src/varinfo.jl.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could one generalize the definitons in src/varinfo.jl as well to fix those ambiguities?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could just remove the two impls for SampleFromPrior and SampleFromUniform, and instead have AbstractSampler hit getindex(vi, ::Colon). But I given that this hasn't already been done, I decided against it. I'm happy to make it though! Don't think this should have any unforseen consequences.


# Since we don't perform any transformations in `getindex` for `SimpleVarInfo`
# we simply call `getindex` in `getindex_raw`.
Expand Down Expand Up @@ -365,6 +367,10 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
return Setfield.@set vi.values = set!!(vi.values, vn, val)
end

function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler)
return unflatten(vi, spl, val)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider depreciating the API or remove it? The spl argument is simply ignored which might be confusing downstream.

end

# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with
# same symbol and same type of, say, `IndexLens`, for improved `.~` performance.
function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName})
Expand Down Expand Up @@ -509,6 +515,45 @@ end

# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals.
increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing
setgid!(vi::SimpleOrThreadSafeSimple, gid::Selector, vn::VarName) = nothing
devmotion marked this conversation as resolved.
Show resolved Hide resolved

# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl.
# TODO: Move away from using these `tonamedtuple` methods.
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we unify the approaches for getting named tuples, vectors etc? E.g. by using values_as or convert?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So tonamedtuple is not what you think it is 😅 It creates a NT of the form (s = (s_values, s_vns), m = (m_values, m_vns)), etc. This is only used to construct the Chains in Turing.jl but we have it here for some reason.

IIRC it also causes insane performance issues for larger models when constructing the chains.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to open an issue for this.

nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

return NamedTuple{names}(nt_vals)
end

function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}()
for vn in keys(vi)
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))

# Initialize entry if not yet initialized.
if !haskey(syms_to_result, sym)
syms_to_result[sym] = (Real[], String[])
end

# Combine with old result.
old_vals, old_string_vns = syms_to_result[sym]
syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns)))
end

# Construct `NamedTuple`.
return NamedTuple(pairs(syms_to_result))
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
Expand All @@ -525,6 +570,8 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)

islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi)

"""
values_as(varinfo[, Type])

Expand All @@ -536,6 +583,10 @@ values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values))
values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values))
values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values
function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T}
isempty(vi.values) && return T[]
return mapreduce(v -> vec([v;]), vcat, values(vi.values))
end

"""
logjoint(model::Model, θ)
Expand Down Expand Up @@ -632,3 +683,35 @@ julia> # Truth.
```
"""
Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ))

# Allow usage of `NamedBijector` too.
function link!!(
t::BijectorTransformation{<:Bijectors.NamedBijector},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
)
# TODO: Make sure that `spl` is respected.
b = t.bijector
x = vi.values
y, logjac = with_logabsdet_jacobian(b, x)
lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new)
return settrans!!(vi_new, t)
end

function invlink!!(
t::BijectorTransformation{<:Bijectors.NamedBijector},
vi::SimpleVarInfo{<:NamedTuple},
spl::AbstractSampler,
model::Model,
)
# TODO: Make sure that `spl` is respected.
b = t.bijector
ib = inverse(b)
y = vi.values
x, logjac = with_logabsdet_jacobian(ib, y)
lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new)
return settrans!!(vi_new, NoTransformation())
end
98 changes: 98 additions & 0 deletions src/transforming.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
function Bijectors.Stacked(
model::Model,
::Val{sym2ranges}=Val(false);
varinfo::VarInfo=VarInfo(model),
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {sym2ranges}
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)

num_ranges = sum([
length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
])
ranges = Vector{UnitRange{Int}}(undef, num_ranges)
idx = 0
range_idx = 1

# ranges might be discontinuous => values are vectors of ranges rather than just ranges
sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
for sym in keys(varinfo.metadata)
sym_lookup[sym] = Vector{UnitRange{Int}}()
for r in varinfo.metadata[sym].ranges
ranges[range_idx] = idx .+ r
push!(sym_lookup[sym], ranges[range_idx])
range_idx += 1
end

idx += varinfo.metadata[sym].ranges[end][end]
end

b = Bijectors.Stacked(map(Bijectors.bijector, dists), ranges)
return sym2ranges ? (b, Dict(zip(keys(sym_lookup), values(sym_lookup)))) : b
end

link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model)
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, vi, SampleFromPrior(), model)
end
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link!!(default_transformation(model, vi), vi, spl, model)
end
function link!!(
t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
# TODO: Implement this properly, e.g. using a context or something.
# Fall back to `Bijectors.Stacked` but then we act like we're using
# the `DefaultTransformation` by setting the transformation accordingly.
return settrans!!(
link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), t
)
end
function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model)
# TODO: Implement this properly, e.g. using a context or something.
DynamicPPL.link!(vi, spl)
# TODO: Add `logabsdet_jacobian` correction to `logp`!
return vi
end
function link!!(
t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
b = t.bijector
x = vi[spl]
y, logjac = with_logabsdet_jacobian(b, x)

lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new)
return settrans!!(vi_new, t)
end

invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model)
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, vi, SampleFromPrior(), model)
end
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Here we extract the `transformation` from `vi` rather than using the default one.
return invlink!!(transformation(vi), vi, spl, model)
end
function invlink!!(
::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
# TODO: Implement this properly, e.g. using a context or something.
return invlink!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model)
end
function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model)
# TODO: Implement this properly, e.g. using a context or something.
DynamicPPL.invlink!(vi, spl)
return vi
end
function invlink!!(
t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
b = t.bijector
ib = inverse(b)
y = vi[spl]
x, logjac = with_logabsdet_jacobian(ib, y)
# TODO: Do we need this?
lp_new = getlogp(vi) - logjac
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new)
return settrans!!(vi_new, NoTransformation())
end
Loading