diff --git a/Project.toml b/Project.toml index b51099e78..785df1d38 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.20.2" +version = "0.21.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/make.jl b/docs/make.jl index 000b8dbae..6b88c18cd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,9 @@ using DynamicPPL using DynamicPPL: AbstractPPL # Doctest setup -DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) +DocMeta.setdocmeta!( + DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true +) makedocs(; sitename="DynamicPPL", diff --git a/docs/src/api.md b/docs/src/api.md index 46f60cf96..8be3697bd 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -156,6 +156,8 @@ AbstractVarInfo ### Common API +#### Accumulation of log-probabilities + ```@docs getlogp setlogp!! @@ -163,16 +165,47 @@ acclogp!! resetlogp!! ``` +#### Variables and their realizations + ```@docs +keys getindex +DynamicPPL.getindex_raw push!! empty!! +isempty ``` ```@docs values_as ``` +#### Transformations + +```@docs +DynamicPPL.AbstractTransformation +DynamicPPL.NoTransformation +DynamicPPL.DynamicTransformation +DynamicPPL.StaticTransformation +``` + +```@docs +DynamicPPL.istrans +DynamicPPL.settrans!! +DynamicPPL.transformation +DynamicPPL.link!! +DynamicPPL.invlink!! +DynamicPPL.default_transformation +DynamicPPL.maybe_invlink_before_eval!! +``` + +#### Utils + +```@docs +DynamicPPL.unflatten +DynamicPPL.tonamedtuple +``` + #### `SimpleVarInfo` ```@docs @@ -191,10 +224,8 @@ TypedVarInfo One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. ```@docs -tonamedtuple link! invlink! -istrans ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2ccef46f6..69b5821d1 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -59,7 +59,9 @@ export AbstractVarInfo, setorder!, istrans, link!, + link!!, invlink!, + invlink!!, tonamedtuple, values_as, # VarName (reexport from AbstractPPL) @@ -126,7 +128,6 @@ export loglikelihood # Used here and overloaded in Turing function getspace end -# Necessary forward declarations """ AbstractVarInfo @@ -134,10 +135,16 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref) +See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). """ abstract type AbstractVarInfo <: AbstractModelTrace end +const LEGACY_WARNING = """ +!!! warning + This method is considered legacy, and is likely to be deprecated in the future. +""" + +# Necessary forward declarations include("utils.jl") include("selector.jl") include("model.jl") @@ -145,8 +152,9 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") -include("varinfo.jl") +include("abstract_varinfo.jl") include("threadsafe.jl") +include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") @@ -155,5 +163,6 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") include("test_utils.jl") +include("transforming.jl") end # module diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl new file mode 100644 index 000000000..e77f03023 --- /dev/null +++ b/src/abstract_varinfo.jl @@ -0,0 +1,549 @@ +# Transformation related. +""" + $(TYPEDEF) + +Represents a transformation to be used in `link!!` and `invlink!!`, amongst others. + +A concrete implementation of this should implement the following methods: +- [`link!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the unconstrained space. +- [`invlink!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the constrained space. + +And potentially: +- [`maybe_invlink_before_eval!!`](@ref): hook to decide whether to transform _before_ + evaluating the model. + +See also: [`link!!`](@ref), [`invlink!!`](@ref), [`maybe_invlink_before_eval!!`](@ref). +""" +abstract type AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which applies the identity function. +""" +struct NoTransformation <: AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which transforms the variables on a per-need-basis +in the execution of a given `Model`. + +This is in constrast to `StaticTransformation` which transforms all variables +_before_ the execution of a given `Model`. + +See also: [`StaticTransformation`](@ref). +""" +struct DynamicTransformation <: AbstractTransformation end + +""" + $(TYPEDEF) + +Transformation which transforms all variables _before_ the execution of a given `Model`. + +This is done through the `maybe_invlink_before_eval!!` method. + +See also: [`DynamicTransformation`](@ref), [`maybe_invlink_before_eval!!`](@ref). + +# Fields +$(TYPEDFIELDS) +""" +struct StaticTransformation{F} <: AbstractTransformation + "The function, assumed to implement the `Bijectors` interface, to be applied to the variables" + bijector::F +end + +""" + default_transformation(model::Model[, vi::AbstractVarInfo]) + +Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. +""" +default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) +default_transformation(::Model) = DynamicTransformation() + +""" + transformation(vi::AbstractVarInfo) + +Return the `AbstractTransformation` related to `vi`. +""" +function transformation end + +# Accumulation of log-probabilities. +""" + getlogp(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters sampled in +`vi`. +""" +function getlogp end + +""" + setlogp!!(vi::AbstractVarInfo, logp) + +Set the log of the joint probability of the observed data and parameters sampled in +`vi` to `logp`, mutating if it makes sense. +""" +function setlogp!! end + +""" + acclogp!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the joint probability of the observed data and +parameters sampled in `vi`, mutating if it makes sense. +""" +function acclogp!! end + +""" + resetlogp!!(vi::AbstractVarInfo) + +Reset the value of the log of the joint probability of the observed data and parameters +sampled in `vi` to 0, mutating if it makes sense. +""" +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) + +# Variables and their realizations. +@doc """ + keys(vi::AbstractVarInfo) + +Return an iterator over all `vns` in `vi`. +""" Base.keys + +@doc """ + getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) + +Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) +distribution(s). + +If `dist` is specified, the value(s) will be reshaped accordingly. + +See also: [`getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +""" Base.getindex + +""" + getindex(vi::AbstractVarInfo, ::Colon) + getindex(vi::AbstractVarInfo, ::AbstractSampler) + +Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) +distribution(s) as a flattened `Vector`. + +The default implementation is to call [`values_as`](@ref) with `Vector` as the type-argument. + +See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +""" +Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) +Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] + +""" + getindex_raw(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) + getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) + +Return the current value(s) of `vn` (`vns`) in `vi`. + +If `dist` is specified, the value(s) will be reshaped accordingly. + +See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) + +!!! note + The difference between `getindex(vi, vn, dist)` and `getindex_raw` is that + `getindex` will also transform the value(s) to the support of the distribution(s). + This is _not_ the case for `getindex_raw`. + +""" +function getindex_raw end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + +Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to +the `VarInfo` `vi`, mutating if it makes sense. +""" +function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) +end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` +from a distribution `dist` to `VarInfo` `vi`, if it makes sense. + +The sampler is passed here to invalidate its cache where defined. + +$(LEGACY_WARNING) +""" +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler +) + return BangBang.push!!(vi, vn, r, dist, spl.selector) +end +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler +) + return BangBang.push!!(vi, vn, r, dist) +end + +""" + push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + +Push a new random variable `vn` with a sampled value `r` sampled with a sampler of +selector `gid` from a distribution `dist` to `VarInfo` `vi`. + +$(LEGACY_WARNING) +""" +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) + return BangBang.push!!(vi, vn, r, dist, Set([gid])) +end + +@doc """ + empty!!(vi::AbstractVarInfo) + +Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to +zeros. + +This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. +""" BangBang.empty!! + +@doc """ + isempty(vi::AbstractVarInfo) + +Return true if `vi` is empty and false otherwise. +""" Base.isempty + +""" + values_as(varinfo[, Type]) + +Return the values/realizations in `varinfo` as `Type`, if implemented. + +If no `Type` is provided, return values as stored in `varinfo`. + +# Examples + +`SimpleVarInfo` with `NamedTuple`: + +```jldoctest +julia> data = (x = 1.0, m = [2.0]); + +julia> values_as(SimpleVarInfo(data)) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`SimpleVarInfo` with `OrderedDict`: + +```jldoctest +julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); + +julia> values_as(SimpleVarInfo(data)) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), NamedTuple) +(x = 1.0, m = [2.0]) + +julia> values_as(SimpleVarInfo(data), OrderedDict) +OrderedDict{Any, Any} with 2 entries: + x => 1.0 + m => [2.0] + +julia> values_as(SimpleVarInfo(data), Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`TypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + md = values_as(vi); md.s isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Float64}: + 1.0 + 2.0 +``` + +`UntypedVarInfo`: + +```jldoctest +julia> # Just use an example model to construct the `VarInfo` because we're lazy. + vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + +julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; + +julia> # For the sake of brevity, let's just check the type. + values_as(vi) isa DynamicPPL.Metadata +true + +julia> values_as(vi, NamedTuple) +(s = 1.0, m = 2.0) + +julia> values_as(vi, OrderedDict) +OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: + s => 1.0 + m => 2.0 + +julia> values_as(vi, Vector) +2-element Vector{Real}: + 1.0 + 2.0 +``` +""" +function values_as end + +""" + eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior} + +Determine the default `eltype` of the values returned by `vi[spl]`. + +!!! warning + This should generally not be called explicitly, as it's only used in + [`matchingvalue`](@ref) to determine the default type to use in place of + type-parameters passed to the model. + + This method is considered legacy, and is likely to be deprecated in the future. +""" +function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) + return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) +end + +# Transformations +""" + istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) + +Return `true` if `vi` is working in unconstrained space, and `false` +if `vi` is assuming realizations to be in support of the corresponding distributions. + +If `vns` is provided, then only check if this/these varname(s) are transformed. + +!!! warning + Not all implementations of `AbstractVarInfo` support transforming only a subset of + the variables. +""" +istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) +function istrans(vi::AbstractVarInfo, vns::AbstractVector) + return !isempty(vns) && all(Base.Fix1(istrans, vi), vns) +end + +""" + settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) + +Return `vi` with `istrans(vi, vn)` evaluating to `true`. + +If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. +""" +function settrans!! end + +""" + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transforms the variables in `vi` to their linked space, using the transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`invlink!!`](@ref). +""" +link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, vi, SampleFromPrior(), model) +end +function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Use `default_transformation` to decide which transformation to use if none is specified. + return link!!(default_transformation(model, vi), vi, spl, model) +end + +""" + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transform the variables in `vi` to their constrained space, using the (inverse of) +transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`link!!`](@ref). +""" +invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) +function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, vi, SampleFromPrior(), model) +end +function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Here we extract the `transformation` from `vi` rather than using the default one. + return invlink!!(transformation(vi), vi, spl, model) +end + +# Vector-based ones. +function link!!( + t::StaticTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = inverse(t.bijector) + x = vi[spl] + y, logjac = with_logabsdet_jacobian(b, x) + + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + return settrans!!(vi_new, t) +end + +function invlink!!( + t::StaticTransformation{<:Bijectors.Bijector{1}}, + vi::AbstractVarInfo, + spl::AbstractSampler, + model::Model, +) + b = t.bijector + y = vi[spl] + x, logjac = with_logabsdet_jacobian(b, y) + + lp_new = getlogp(vi) + logjac + vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + return settrans!!(vi_new, NoTransformation()) +end + +""" + maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + +Return a possibly invlinked version of `vi`. + +This will be called prior to `model` evaluation, allowing one to perform a single +`invlink!!` _before_ evaluation rather than lazyily evaluating the transforms on as-we-need +basis as is done with [`DynamicTransformation`](@ref). + +See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref). + +# Examples +```julia-repl +julia> using DynamicPPL, Distributions, Bijectors + +julia> @model demo() = x ~ Normal() +demo (generic function with 2 methods) + +julia> # By subtyping `Bijector{1}`, we inherit the `(inv)link!!` defined for + # bijectors which acts on 1-dimensional arrays, i.e. vectors. + struct MyBijector <: Bijectors.Bijector{1} end + +julia> # Define some dummy `inverse` which will be used in the `link!!` call. + Bijectors.inverse(f::MyBijector) = identity + +julia> # We need to define `with_logabsdet_jacobian` for `MyBijector` + # (`identity` already has `with_logabsdet_jacobian` defined) + function Bijectors.with_logabsdet_jacobian(::MyBijector, x) + # Just using a large number of the logabsdet-jacobian term + # for demonstration purposes. + return (x, 1000) + end + +julia> # Change the `default_transformation` for our model to be a + # `StaticTransformation` using `MyBijector`. + function DynamicPPL.default_transformation(::Model{typeof(demo)}) + return DynamicPPL.StaticTransformation(MyBijector()) + end + +julia> model = demo(); + +julia> vi = SimpleVarInfo(x=1.0) +SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` + vi_linked = link!!(vi, model) +Transformed SimpleVarInfo((x = 1.0,), 0.0) + +julia> # Now performs a single `invlink!!` before model evaluation. + logjoint(model, vi_linked) +-1001.4189385332047 +``` +""" +function maybe_invlink_before_eval!!( + vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +end +function maybe_invlink_before_eval!!( + ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return vi +end +function maybe_invlink_before_eval!!( + ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. + return vi +end +function maybe_invlink_before_eval!!( + t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model +) + return invlink!!(t, vi, _default_sampler(context), model) +end + +function _default_sampler(context::AbstractContext) + return _default_sampler(NodeTrait(_default_sampler, context), context) +end +_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() +function _default_sampler(::IsParent, context::AbstractContext) + return _default_sampler(childcontext(context)) +end + +# Utilities +""" + unflatten(vi::AbstractVarInfo[, spl::AbstractSampler], x::AbstractVector) + +Return a new instance of `vi` with the values of `x` assigned to the variables. + +If `spl` is provided, `x` is assumed to be realizations only for variables related +to `spl`. +""" +function unflatten end + +""" + tonamedtuple(vi::AbstractVarInfo) + +Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and +indexing string of the variable. + +For example, a model that had a vector of vector-valued +variables `x` would return + +```julia +(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) +``` +""" +function tonamedtuple end + +# Legacy code that is currently overloaded for the sake of simplicity. +# TODO: Remove when possible. +increment_num_produce!(::AbstractVarInfo) = nothing +setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6271b5d8c..810600072 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -204,14 +204,14 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::AbstractVarInfo, + vi::VarInfoOrThreadSafeVarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) + BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -526,6 +526,8 @@ function get_and_set_val!( # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) + # NOTE: Need to add the correction. + acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else diff --git a/src/contexts.jl b/src/contexts.jl index bd8acf278..2c59cf68c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -298,34 +298,33 @@ childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ - hasvalue(context, vn) + hasvalue(context::AbstractContext, vn::VarName) Return `true` if `vn` is found in `context`. """ -hasvalue(context, vn) = false -hasvalue(context::ConditionContext, vn::VarName) = nested_haskey(context.values, vn) +hasvalue(context::AbstractContext, vn::VarName) = false +hasvalue(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(nested_haskey, context.values), vns) + return all(Base.Fix1(hasvalue, context.values), vns) end """ - getvalue(context, vn) + getvalue(context::AbstractContext, vn::VarName) Return value of `vn` in `context`. """ -function getvalue(context::AbstractContext, vn) +function getvalue(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getvalue(context::NamedConditionContext, vn) = get(context.values, vn) -getvalue(context::ConditionContext, vn) = nested_getindex(context.values, vn) +getvalue(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) """ hasvalue_nested(context, vn) Return `true` if `vn` is found in `context` or any of its descendants. -This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, -not recursively checking if `vn` is in any of its descendants. +This is contrast to [`hasvalue(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ function hasvalue_nested(context::AbstractContext, vn) return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) diff --git a/src/model.jl b/src/model.jl index 174df571c..b7e0984c5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -458,7 +458,7 @@ julia> conditioned(cm).var"a.m" 1.0 julia> keys(VarInfo(cm)) # <= no variables are sampled -Any[] +VarName[] ``` """ conditioned(model::Model) = conditioned(model.context) @@ -590,7 +590,16 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf context_new = setleafcontext( context, setleafcontext(model.context, leafcontext(context)) ) - model.f(model, varinfo, context_new, $(unwrap_args...)) + model.f( + model, + # Maybe perform `invlink!!` once prior to evaluation to avoid + # lazy `invlink`-ing of the parameters. This can be useful for + # speeding up computation. See docs for `maybe_invlink_before_eval!!` + # for more information. + maybe_invlink_before_eval!!(varinfo, context_new, model), + context_new, + $(unwrap_args...), + ) end end @@ -629,7 +638,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} SamplingContext(rng, SampleFromPrior(), DefaultContext()), ), ) - return DynamicPPL.values_as(x, T) + return values_as(x, T) end # Default RNG and type diff --git a/src/sampler.jl b/src/sampler.jl index 550c27642..3a4daf0b1 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,6 +67,19 @@ function AbstractMCMC.step( return vi, nothing end +function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) + return default_varinfo(rng, model, sampler, DefaultContext()) +end +function default_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler, + context::AbstractContext, +) + init_sampler = initialsampler(sampler) + return VarInfo(rng, model, init_sampler, context) +end + # initial step: general interface for resuming and function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -82,23 +95,17 @@ function AbstractMCMC.step( end # Sample initial values. - _spl = initialsampler(spl) - vi = VarInfo(rng, model, _spl) + vi = default_varinfo(rng, model, spl) # Update the parameters if provided. if init_params !== nothing - vi = initialize_parameters!!(vi, init_params, spl) + vi = initialize_parameters!!(vi, init_params, spl, model) # Update joint log probability. - # TODO: fix properly by using sampler and evaluation contexts # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - if _spl isa SampleFromUniform - model(rng, vi, SampleFromPrior()) - else - model(rng, vi, _spl) - end + vi = last(evaluate!!(model, vi, DefaultContext())) end return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...) @@ -121,7 +128,9 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) +function initialize_parameters!!( + vi::AbstractVarInfo, init_params, spl::Sampler, model::Model +) @debug "Using passed-in initial variable values" init_params # Flatten parameters. @@ -132,8 +141,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Get all values. linked = islinked(vi, spl) if linked - # TODO: Make work with immutable `vi`. - invlink!(vi, spl) + vi = invlink!!(vi, spl, model) end theta = vi[spl] length(theta) == length(init_theta) || @@ -150,8 +158,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) # Update in `vi`. vi = setindex!!(vi, theta, spl) if linked - # TODO: Make work with immutable `vi`. - link!(vi, spl) + vi = link!!(vi, spl, model) end return vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 8edf9e21e..d1d637d27 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,8 +1,3 @@ -abstract type AbstractTransformation end - -struct NoTransformation <: AbstractTransformation end -struct DefaultTransformation <: AbstractTransformation end - """ $(TYPEDEF) @@ -136,7 +131,7 @@ julia> # (✓) Positive probability mass on negative numbers! getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 -julia> # While if we forget to make indicate that it's transformed: +julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) SimpleVarInfo((x = -1.0,), 0.0) @@ -202,6 +197,8 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo transformation::C end +transformation(vi::SimpleVarInfo) = vi.transformation + # Makes things a bit more readable vs. putting `Float64` everywhere. const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 @@ -251,9 +248,15 @@ function SimpleVarInfo{T}( return SimpleVarInfo(values, convert(T, getlogp(vi))) end +unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) +function unflatten(svi::SimpleVarInfo, x::AbstractVector) + return Setfield.@set svi.values = unflatten(svi.values, x) +end + function BangBang.empty!!(vi::SimpleVarInfo) - Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) + return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values)) end +Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp @@ -311,11 +314,7 @@ end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values -Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values - -# TODO: Should we do better? -Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) # Since we don't perform any transformations in `getindex` for `SimpleVarInfo` # we simply call `getindex` in `getindex_raw`. @@ -330,13 +329,17 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut return reconstruct(dist, vals, length(vns)) end -Base.haskey(vi::SimpleVarInfo, vn::VarName) = nested_haskey(vi.values, vn) +Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. return Setfield.@set vi.values = set!!(vi.values, vn, val) end +function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) + return unflatten(vi, spl, val) +end + # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) @@ -479,12 +482,47 @@ function dot_assume( return value, lp, vi end -# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. -increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing +# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl. +# TODO: Move away from using these `tonamedtuple` methods. +function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} + nt_vals = map(keys(vi)) do vn + val = vi[vn] + vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) + vals = map(Base.Fix1(getindex, vi), vns) + (vals, map(string, vns)) + end + + return NamedTuple{names}(nt_vals) +end + +function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) + syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}() + for vn in keys(vi) + # Extract the leaf varnames and values. + val = vi[vn] + vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) + vals = map(Base.Fix1(getindex, vi), vns) + + # Determine the corresponding symbol. + sym = only(unique(map(getsym, vns))) + + # Initialize entry if not yet initialized. + if !haskey(syms_to_result, sym) + syms_to_result[sym] = (Real[], String[]) + end + + # Combine with old result. + old_vals, old_string_vns = syms_to_result[sym] + syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns))) + end + + # Construct `NamedTuple`. + return NamedTuple(pairs(syms_to_result)) +end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) - return settrans!!(vi, trans ? DefaultTransformation() : NoTransformation()) + return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Setfield.@set vi.transformation = transformation @@ -497,8 +535,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) + values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values +function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} + isempty(vi) && return T[] + return mapreduce(v -> vec([v;]), vcat, values(vi.values)) +end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) end @@ -527,8 +571,8 @@ julia> # Using a `NamedTuple`. logjoint(demo([1.0]), (m = 100.0, )) -9902.33787706641 -julia> # Using a `Dict`. - logjoint(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -9902.33787706641 julia> # Truth. @@ -559,8 +603,8 @@ julia> # Using a `NamedTuple`. logprior(demo([1.0]), (m = 100.0, )) -5000.918938533205 -julia> # Using a `Dict`. - logprior(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -5000.918938533205 julia> # Truth. @@ -591,8 +635,8 @@ julia> # Using a `NamedTuple`. loglikelihood(demo([1.0]), (m = 100.0, )) -4901.418938533205 -julia> # Using a `Dict`. - loglikelihood(demo([1.0]), Dict(@varname(m) => 100.0)) +julia> # Using a `OrderedDict`. + loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) -4901.418938533205 julia> # Truth. @@ -602,6 +646,37 @@ julia> # Truth. """ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) +# Allow usage of `NamedBijector` too. +function link!!( + t::StaticTransformation{<:Bijectors.NamedBijector}, + vi::SimpleVarInfo{<:NamedTuple}, + spl::AbstractSampler, + model::Model, +) + # TODO: Make sure that `spl` is respected. + b = inverse(t.bijector) + x = vi.values + y, logjac = with_logabsdet_jacobian(b, x) + lp_new = getlogp(vi) - logjac + vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new) + return settrans!!(vi_new, t) +end + +function invlink!!( + t::StaticTransformation{<:Bijectors.NamedBijector}, + vi::SimpleVarInfo{<:NamedTuple}, + spl::AbstractSampler, + model::Model, +) + # TODO: Make sure that `spl` is respected. + b = t.bijector + y = vi.values + x, logjac = with_logabsdet_jacobian(b, y) + lp_new = getlogp(vi) + logjac + vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) + return settrans!!(vi_new, NoTransformation()) +end + # Threadsafe stuff. # For `SimpleVarInfo` we don't really need `Ref` so let's not use it. function ThreadSafeVarInfo(vi::SimpleVarInfo) diff --git a/src/test_utils.jl b/src/test_utils.jl index bcc649675..9c9034ee5 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -644,6 +644,39 @@ const DEMO_MODELS = ( demo_dot_assume_matrix_dot_observe_matrix(), ) +# Model to test `StaticTransformation` with. +""" + demo_static_transformation() + +Simple model for which [`default_transformation`](@ref) returns a [`StaticTransformation`](@ref). +""" +@model function demo_static_transformation() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + + return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end + +function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) + b = Bijectors.stack(Bijectors.Exp{0}(), Bijectors.Identity{0}()) + return DynamicPPL.StaticTransformation(b) +end + +posterior_mean(::Model{typeof(demo_static_transformation)}) = (s=49 / 24, m=7 / 6) +function logprior_true(::Model{typeof(demo_static_transformation)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_static_transformation)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end + """ marginal_mean_of_samples(chain, varname) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 4dc37e4ea..85ad0e23e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -17,6 +17,8 @@ const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ V,<:AbstractArray{<:Ref} } +transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) + # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. function acclogp!!(vi::ThreadSafeVarInfo, logp) @@ -53,6 +55,12 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end +function BangBang.push!!( + vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) + return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) +end + get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) @@ -73,27 +81,55 @@ link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl) -getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl) +function link!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return link!!(t, vi.varinfo, spl, model) +end + +function invlink!!( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return invlink!!(t, vi.varinfo, spl, model) +end + +function maybe_invlink_before_eval!!( + vi::ThreadSafeVarInfo, context::AbstractContext, model::Model +) + # Defer to the wrapped `AbstractVarInfo` object. + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` + # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. + return Setfield.@set vi.varinfo = maybe_invlink_before_eval!!( + vi.varinfo, context, model + ) +end +# `getindex` +getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) +getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = getindex(vi.varinfo, vns) function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex(vi.varinfo, vn, dist) end -getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) -function getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) return getindex(vi.varinfo, vns, dist) end +getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) +getindex_raw(vi::ThreadSafeVarInfo, ::Colon) = getindex_raw(vi.varinfo, Colon()) getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) +function getindex_raw(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) + return getindex_raw(vi.varinfo, vns) +end function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) return getindex_raw(vi.varinfo, vn, dist) end -getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex_raw(vi.varinfo, vns) -function getindex_raw(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex_raw( + vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution +) return getindex_raw(vi.varinfo, vns, dist) end +getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) @@ -121,11 +157,7 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) end -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) - return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) -end +values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return unset_flag!(vi.varinfo, vn, flag) @@ -135,3 +167,14 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) end tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo) + +# Transformations. +function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) +end +function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) + return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) +end + +istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) diff --git a/src/transforming.jl b/src/transforming.jl new file mode 100644 index 000000000..f4b50b057 --- /dev/null +++ b/src/transforming.jl @@ -0,0 +1,96 @@ +struct DynamicTransformationContext{isinverse} <: AbstractContext end +NodeTrait(::DynamicTransformationContext) = IsLeaf() + +function tilde_assume( + ::DynamicTransformationContext{isinverse}, right, vn, vi +) where {isinverse} + r = vi[vn, right] + lp = Bijectors.logpdf_with_trans(right, r, !isinverse) + + if istrans(vi, vn) + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" + end + + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + r_transformed = isinverse ? r : bijector(right)(r) + return r, lp, setindex!!(vi, r_transformed, vn) +end + +function dot_tilde_assume( + ::DynamicTransformationContext{isinverse}, + dist::Distribution, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi, +) where {isinverse} + r = getindex.((vi,), vns, (dist,)) + b = bijector(dist) + + is_trans_uniques = unique(istrans.((vi,), vns)) + @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" + is_trans = first(is_trans_uniques) + if is_trans + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" + end + + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + r_transformed = isinverse ? r : b.(r) + lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) + return r, lp, setindex!!(vi, r_transformed, vns) +end + +function dot_tilde_assume( + ::DynamicTransformationContext{isinverse}, + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, +) where {isinverse} + @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" + r = vi[vns, dist] + + # Compute `logpdf` with logabsdet-jacobian correction. + lp = sum(zip(vns, eachcol(r))) do (vn, ri) + return Bijectors.logpdf_with_trans(dist, ri, !isinverse) + end + + # Transform _all_ values. + is_trans_uniques = unique(istrans.((vi,), vns)) + @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" + is_trans = first(is_trans_uniques) + if is_trans + @assert isinverse "Trying to link already transformed variables" + else + @assert !isinverse "Trying to invlink non-transformed variables" + end + + b = bijector(dist) + for (vn, ri) in zip(vns, eachcol(r)) + # Only transform if `!isinverse` since `vi[vn, right]` + # already performs the inverse transformation if it's transformed. + vi = DynamicPPL.setindex!!(vi, isinverse ? ri : b(ri), vn) + end + + return r, lp, vi +end + +function link!!( + t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) +end + +function invlink!!( + ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + return settrans!!( + last(evaluate!!(model, vi, DynamicTransformationContext{true}())), + NoTransformation(), + ) +end diff --git a/src/utils.jl b/src/utils.jl index 3779f6412..8f076efee 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -236,22 +236,26 @@ istransformable(::Transformable) = true # Single-sample initialisations # ################################# -inittrans(rng, dist::UnivariateDistribution) = invlink(dist, randrealuni(rng)) +inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) function inittrans(rng, dist::MultivariateDistribution) - return invlink(dist, randrealuni(rng, size(dist)[1])) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1])) +end +function inittrans(rng, dist::MatrixDistribution) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)...)) end -inittrans(rng, dist::MatrixDistribution) = invlink(dist, randrealuni(rng, size(dist)...)) ################################ # Multi-sample initialisations # ################################ -inittrans(rng, dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(rng, n)) +function inittrans(rng, dist::UnivariateDistribution, n::Int) + return Bijectors.invlink(dist, randrealuni(rng, n)) +end function inittrans(rng, dist::MultivariateDistribution, n::Int) - return invlink(dist, randrealuni(rng, size(dist)[1], n)) + return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) end function inittrans(rng, dist::MatrixDistribution, n::Int) - return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) + return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) end ####################### @@ -427,113 +431,195 @@ function BangBang.possible( promote_type(eltype(C), eltype(T)) <: eltype(C) end +# HACK(torfjelde): This makes it so it works on iterators, etc. by default. +# TODO(torfjelde): Do better. """ - nested_getindex(values::AbstractDict, vn::VarName) + unflatten(original, x::AbstractVector) -Return value corresponding to `vn` in `values` by also looking -in the the actual values of the dict. +Return instance of `original` constructed from `x`. +""" +function unflatten(original, x::AbstractVector) + lengths = map(length, original) + end_indices = cumsum(lengths) + return map(zip(original, lengths, end_indices)) do (v, l, end_idx) + start_idx = end_idx - l + 1 + return unflatten(v, @view(x[start_idx:end_idx])) + end +end + +unflatten(::Real, x::Real) = x +unflatten(::Real, x::AbstractVector) = only(x) +unflatten(::AbstractVector{<:Real}, x::Real) = vcat(x) +unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x +unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) + +function unflatten(original::Tuple, x::AbstractVector) + lengths = map(length, original) + end_indices = cumsum(lengths) + return ntuple(length(original)) do i + v = original[i] + l = lengths[i] + end_idx = end_indices[i] + start_idx = end_idx - l + 1 + return unflatten(v, @view(x[start_idx:end_idx])) + end +end +function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} + return NamedTuple{names}(unflatten(values(original), x)) +end +function unflatten(original::AbstractDict, x::AbstractVector) + D = ConstructionBase.constructorof(typeof(original)) + return D(zip(keys(original), unflatten(collect(values(original)), x))) +end + +# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. +""" + getvalue(vals, vn::VarName) + +Return the value(s) in `vals` represented by `vn`. + +Note that this method is different from `getindex`. See examples below. # Examples +For `NamedTuple`: + ```jldoctest -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x)) # same as `getindex` +julia> vals = (x = [1.0],); + +julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` 1-element Vector{Float64}: 1.0 -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[1])) # different from `getindex` +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` 1.0 -julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[2])) +julia> DynamicPPL.getvalue(vals, @varname(x[2])) ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] [...] ``` -""" -function nested_getindex(values::AbstractDict, vn::VarName) - maybeval = get(values, vn, nothing) - if maybeval !== nothing - return maybeval - end - # Split the lens into the key / `parent` and the extraction lens / `child`. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(values, VarName(vn, l)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent +For `AbstractDict`: - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end +```jldoctest +julia> vals = Dict(@varname(x) => [1.0]); - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? - value = values[VarName(vn, keylens)] - return get(value, child) -end +julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex` +1.0 + +julia> DynamicPPL.getvalue(vals, @varname(x[1][2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> DynamicPPL.getvalue(vals, @varname(x[2][1])) +ERROR: KeyError: key x[2][1] not found +[...] +``` +""" +getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) +getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) """ - nested_haskey(x, vn::VarName) + hasvalue(vals, vn::VarName) -Determine whether `x` has a mapping for a given `vn`. +Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref). # Examples With `x` as a `NamedTuple`: + ```jldoctest -julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x)) +julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x)) true -julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x[1])) +julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1])) false -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x)) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x)) true -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[1])) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1])) true -julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[2])) +julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2])) false ``` With `x` as a `AbstractDict`: + ```jldoctest -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x)) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) +true + +julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) +false + +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) +true + +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x[1])) +julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) false +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x)) +julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey` true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[1])) +julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey` true -julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[2])) +julia> DynamicPPL.hasvalue(vals, @varname(x[1][2])) +false + +julia> DynamicPPL.hasvalue(vals, @varname(x[2][1])) false ``` """ -function nested_haskey(nt::NamedTuple, vn::VarName{sym}) where {sym} +function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} # LHS: Ensure that `nt` indeed has the property we want. # RHS: Ensure that the lens can view into `nt`. - return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) + return haskey(vals, sym) && canview(getlens(vn), getproperty(vals, sym)) end # For `dictlike` we need to check wether `vn` is "immediately" present, or # if some ancestor of `vn` is present in `dictlike`. -function nested_haskey(dict::AbstractDict, vn::VarName) +function hasvalue(vals::AbstractDict, vn::VarName) # First we check if `vn` is present as is. - haskey(dict, vn) && return true + haskey(vals, vn) && return true # If `vn` is not present, we check any parent-varnames by attempting # to split the lens into the key / `parent` and the extraction lens / `child`. # If `issuccess` is `true`, we found such a split, and hence `vn` is present. parent, child, issuccess = splitlens(getlens(vn)) do lens l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(dict, VarName(vn, l)) + haskey(vals, VarName(vn, l)) end # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. keylens = parent === nothing ? Setfield.IdentityLens() : parent @@ -542,11 +628,43 @@ function nested_haskey(dict::AbstractDict, vn::VarName) issuccess || return false # At this point we just need to check that we `canview` the value. - value = dict[VarName(vn, keylens)] + value = vals[VarName(vn, keylens)] return canview(child, value) end +""" + nested_getindex(values::AbstractDict, vn::VarName) + +Return value corresponding to `vn` in `values` by also looking +in the the actual values of the dict. +""" +function nested_getindex(values::AbstractDict, vn::VarName) + maybeval = get(values, vn, nothing) + if maybeval !== nothing + return maybeval + end + + # Split the lens into the key / `parent` and the extraction lens / `child`. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(values, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # If we found a valid split, then we can extract the value. + if !issuccess + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) + end + + # TODO: Should we also check that we `canview` the extracted `value` + # rather than just let it fail upon `get` call? + value = values[VarName(vn, keylens)] + return get(value, child) +end + """ float_type_with_fallback(x) diff --git a/src/varinfo.jl b/src/varinfo.jl index 00b99162f..6107b869f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -103,6 +103,14 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo end const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} +const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ + VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} +} + +# NOTE: This is kind of weird, but it effectively preserves the "old" +# behavior where we're allowed to call `link!` on the same `VarInfo` +# multiple times. +transformation(vi::VarInfo) = DynamicTransformation() function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) new_vi = deepcopy(old_vi) @@ -129,6 +137,11 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) + +# TODO: deprecate. +unflatten(vi::VarInfo, spl, x::AbstractVector) = VarInfo(vi, spl, x) + # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) @@ -268,11 +281,11 @@ Return the index range of `vn` in the metadata of `vi`. getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] """ - getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) + getranges(vi::VarInfo, vns::Vector{<:VarName}) Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ -function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getranges(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) end @@ -308,7 +321,7 @@ Return the value(s) of `vns`. The values may or may not be transformed to Euclidean space. """ -function getval(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getval(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(vn -> getval(vi, vn), vcat, vns) end @@ -357,12 +370,7 @@ Return the set of sampler selectors associated with `vn` in `vi`. """ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] -""" - settrans!!(vi::VarInfo, trans::Bool, vn::VarName) - -Set the `trans` flag value of `vn` in `vi`. -""" -function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) +function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) if trans set_flag!(vi, vn, "trans") else @@ -372,11 +380,6 @@ function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) return vi end -""" - settrans!!(vi::AbstractVarInfo, trans) - -Return new instance of `vi` but with `istrans(vi, trans)` now evaluating to `true`. -""" function settrans!!(vi::VarInfo, trans::Bool) for vn in keys(vi) settrans!!(vi, trans, vn) @@ -385,6 +388,14 @@ function settrans!!(vi::VarInfo, trans::Bool) return vi end +settrans!!(vi::VarInfo, trans::NoTransformation) = settrans!!(vi, false) +# HACK: This is necessary to make something like `link!!(transformation, vi, model)` +# work properly, which will transform the variables according to `transformation` +# and then call `settrans!!(vi, transformation)`. An alternative would be to add +# the `transformation` to the `VarInfo` object, but at the moment doesn't seem +# worth it as `VarInfo` has its own way of handling transformations. +settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) + """ syms(vi::VarInfo) @@ -413,7 +424,7 @@ end end # Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::AbstractVarInfo, spl::Sampler) +@inline function _getidcs(vi::VarInfo, spl::Sampler) # NOTE: 0b00 is the sanity flag for # |\____ getidcs (mask = 0b10) # \_____ getranges (mask = 0b01) @@ -463,8 +474,8 @@ end end # Get all vns of variables belonging to spl -_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -function _getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) +_getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) +function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) return _getvns(vi, Selector(), Val(())) end function _getvns(vi::UntypedVarInfo, s::Selector, space) @@ -484,7 +495,7 @@ end end # Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::AbstractVarInfo, spl::Sampler) +@inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 @@ -497,7 +508,7 @@ end #end end # Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, space) +@inline function _getranges(vi::VarInfo, s::Selector, space) return _getranges(vi, _getidcs(vi, s, space)) end @inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) @@ -606,14 +617,6 @@ function TypedVarInfo(vi::UntypedVarInfo) end TypedVarInfo(vi::TypedVarInfo) = vi -""" - empty!!(vi::VarInfo) - -Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. - -This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. -""" function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) @@ -630,13 +633,11 @@ end end # Functions defined only for UntypedVarInfo -""" - keys(vi::AbstractVarInfo) - -Return an iterator over all `vns` in `vi`. -""" Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) +# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly +# on other methods in the codebase which requires `Vector{<:VarName}`. +Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] @generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) @@ -657,63 +658,20 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end -""" - istrans(vi::AbstractVarInfo) +istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -Return `true` if `vi` is working in unconstrained space, and `false` -if `vi` is assuming realizations to be in support of the corresponding distributions. -""" -istrans(vi::AbstractVarInfo) = false # `VarInfo` works in constrained space by default. - -""" - istrans(vi::VarInfo, vn::VarName) +getlogp(vi::VarInfo) = vi.logp[] -Return true if `vn`'s values in `vi` are transformed to Euclidean space, and false if -they are in the support of `vn`'s distribution. -""" -istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans") -function istrans(vi::AbstractVarInfo, vns::AbstractVector{<:VarName}) - return all(Base.Fix1(istrans, vi), vns) -end - -""" - getlogp(vi::VarInfo) - -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. -""" -getlogp(vi::AbstractVarInfo) = vi.logp[] - -""" - setlogp!!(vi::VarInfo, logp) - -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. -""" function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end -""" - acclogp!!(vi::VarInfo, logp) - -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. -""" function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end -""" - resetlogp!!(vi::AbstractVarInfo) - -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. -""" -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) - """ get_num_produce(vi::VarInfo) @@ -736,18 +694,13 @@ Add 1 to `num_produce` in `vi`. increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 """ - reset_num_produce!(vi::AbstractVarInfo) + reset_num_produce!(vi::VarInfo) Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!(vi::AbstractVarInfo) = set_num_produce!(vi, 0) +reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) -""" - isempty(vi::VarInfo) - -Return true if `vi` is empty and false otherwise. -""" isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) isempty(vi::TypedVarInfo) = _isempty(vi.metadata) @generated function _isempty(metadata::NamedTuple{names}) where {names} @@ -759,6 +712,12 @@ isempty(vi::TypedVarInfo) = _isempty(vi.metadata) end # X -> R for all variables associated with given sampler +function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # Call `_link!` instead of `link!` to avoid deprecation warning. + _link!(vi, spl) + return vi +end + """ link!(vi::VarInfo, spl::Sampler) @@ -766,7 +725,21 @@ Transform the values of the random variables sampled by `spl` in `vi` from the s of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!(vi::UntypedVarInfo, spl::Sampler) +function link!(vi::VarInfo, spl::AbstractSampler) + Base.depwarn( + "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) + return _link!(vi, spl) +end +function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) + Base.depwarn( + "`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) + return _link!(vi, spl, spaceval) +end +function _link!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -774,21 +747,25 @@ function link!(vi::UntypedVarInfo, spl::Sampler) @debug "X -> ℝ for $(vn)..." dist = getdist(vi, vn) # TODO: Use inplace versions to avoid allocations - setval!( - vi, - vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), - vn, - ) + b = bijector(dist) + x = reconstruct(dist, getval(vi, vn)) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - return link!(vi, spl, Val(getspace(spl))) +function _link!(vi::TypedVarInfo, spl::AbstractSampler) + Base.depwarn( + "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", + :link!, + ) + return _link!(vi, spl, Val(getspace(spl))) end -function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -807,14 +784,11 @@ end for vn in f_vns @debug "X -> R for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize( - dist, - Bijectors.link(dist, reconstruct(dist, getval(vi, vn))), - ), - vn, - ) + x = reconstruct(dist, getval(vi, vn)) + b = bijector(dist) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @@ -828,6 +802,20 @@ end end # R -> X for all variables associated with given sampler +function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, spl) + return vi +end + +function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) + # Because `VarInfo` does not contain any information about what the transformation + # other than whether or not it has actually been transformed, the best we can do + # is just assume that `default_transformation` is the correct one if `istrans(vi)`. + t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() + return maybe_invlink_before_eval!!(t, vi, context, model) +end + """ invlink!(vi::VarInfo, spl::AbstractSampler) @@ -835,27 +823,43 @@ Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!(vi::VarInfo, spl::AbstractSampler) + Base.depwarn( + "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", + :invlink!, + ) + return _invlink!(vi, spl) +end + +function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) + Base.depwarn( + "`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", + :invlink!, + ) + return _invlink!(vi, spl, spaceval) +end + +function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), - vn, - ) + y = reconstruct(dist, getval(vi, vn)) + b = inverse(bijector(dist)) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return invlink!(vi, spl, Val(getspace(spl))) +function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) + return _invlink!(vi, spl, Val(getspace(spl))) end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end @@ -874,16 +878,11 @@ end for vn in f_vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - setval!( - vi, - vectorize( - dist, - Bijectors.invlink( - dist, reconstruct(dist, getval(vi, vn)) - ), - ), - vn, - ) + y = reconstruct(dist, getval(vi, vn)) + b = inverse(bijector(dist)) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else @@ -896,8 +895,10 @@ end return expr end -maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.link(dist, val) : val -maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? Bijectors.invlink(dist, val) : val +link(vi, vn, dist, val) = Bijectors.link(dist, val) +invlink(vi, vn, dist, val) = Bijectors.invlink(dist, val) +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? link(vi, vn, dist, val) : val +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? invlink(vi, vn, dist, val) : val """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -927,23 +928,13 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type -""" - getindex(vi::VarInfo, vn::VarName) - getindex(vi::VarInfo, vns::Vector{<:VarName}) - -Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) -distribution(s). - -If the value(s) is (are) transformed to the Euclidean space, it is -(they are) transformed back. -""" -getindex(vi::AbstractVarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) -function getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution) +getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_raw(vi, vn, dist) return maybe_invlink(vi, vn, dist, val) end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getindex(vi::VarInfo, vns::Vector{<:VarName}) # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases # such as `x .~ [Normal(), Exponential()]`. # BUT we also can't fix this here because this will lead to "incorrect" @@ -951,7 +942,7 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) # where by "incorrect" we mean there exists pieces of code expecting this behavior. return getindex(vi, vns, getdist(vi, first(vns))) end -function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn getindex(vi, vn, dist) @@ -959,14 +950,14 @@ function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distributio return reconstruct(dist, vals_linked, length(vns)) end -getindex_raw(vi::AbstractVarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) -function getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution) +getindex_raw(vi::VarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) +function getindex_raw(vi::VarInfo, vn::VarName, dist::Distribution) return reconstruct(dist, getval(vi, vn)) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}) +function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}) return getindex_raw(vi, vns, getdist(vi, first(vns))) end -function getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}, dist::Distribution) +function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) return reconstruct(dist, getval(vi, vns), length(vns)) end @@ -977,8 +968,6 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi)) -getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi)) getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple @@ -1002,8 +991,8 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -function BangBang.setindex!!(vi::AbstractVarInfo, val, vn::VarName) +setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) +function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return (setindex!(vi, val, vn); return vi) end @@ -1014,7 +1003,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) +setindex!(vi::VarInfo, val, spl::SampleFromPrior) = setall!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` @@ -1023,7 +1012,7 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) return nothing end -function BangBang.setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) +function BangBang.setindex!!(vi::VarInfo, val, spl::AbstractSampler) setindex!(vi, val, spl) return vi end @@ -1044,19 +1033,6 @@ end return expr end -""" - tonamedtuple(vi::VarInfo) - -Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and -indexing string of the variable. - -For example, a model that had a vector of vector-valued -variables `x` would return - -```julia -(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) -``` -""" function tonamedtuple(vi::VarInfo) return tonamedtuple(vi.metadata, vi) end @@ -1079,10 +1055,6 @@ end return map(vn -> vi[vn], f_vns) end -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) -end - """ haskey(vi::VarInfo, vn::VarName) @@ -1143,46 +1115,6 @@ function Base.show(io::IO, vi::UntypedVarInfo) return print(io, ")") end -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`, if it makes sense. - -The sampler is passed here to invalidate its cache where defined. -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler -) - return BangBang.push!!(vi, vn, r, dist, spl.selector) -end -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler -) - return BangBang.push!!(vi, vn, r, dist) -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - -Push a new random variable `vn` with a sampled value `r` sampled with a sampler of -selector `gid` from a distribution `dist` to `VarInfo` `vi`. -""" -function BangBang.push!!( - vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector -) - return BangBang.push!!(vi, vn, r, dist, Set([gid])) -end function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) @@ -1308,7 +1240,7 @@ end Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked and `vn`'s symbol is in the space of `spl`. """ -function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) +function updategid!(vi::VarInfo, vn::VarName, spl::Sampler) if inspace(vn, getspace(spl)) setgid!(vi, spl.selector, vn) end @@ -1316,11 +1248,11 @@ end # TODO: Maybe rename or something? """ - _apply!(kernel!, vi::AbstractVarInfo, values, keys) + _apply!(kernel!, vi::VarInfo, values, keys) Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ -function _apply!(kernel!, vi::AbstractVarInfo, values, keys) +function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) keys_strings = map(string, collectmaybe(keys)) num_indices_seen = 0 @@ -1378,7 +1310,7 @@ end end end -function _find_missing_keys(vi::AbstractVarInfo, keys) +function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) string_vns = map(string, collectmaybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1389,9 +1321,9 @@ function _find_missing_keys(vi::AbstractVarInfo, keys) end """ - setval!(vi::AbstractVarInfo, x) - setval!(vi::AbstractVarInfo, values, keys) - setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) + setval!(vi::VarInfo, x) + setval!(vi::VarInfo, values, keys) + setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) Set the values in `vi` to the provided values and leave those which are not present in `x` or `chains` unchanged. @@ -1445,15 +1377,13 @@ julia> var_info[@varname(x[1])] # [✓] unchanged -0.22312984965118443 ``` """ -setval!(vi::AbstractVarInfo, x) = setval!(vi, values(x), keys(x)) -setval!(vi::AbstractVarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) +setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) +setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) +function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1465,9 +1395,9 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) end """ - setval_and_resample!(vi::AbstractVarInfo, x) - setval_and_resample!(vi::AbstractVarInfo, values, keys) - setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx) + setval_and_resample!(vi::VarInfo, x) + setval_and_resample!(vi::VarInfo, values, keys) + setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) Set the values in `vi` to the provided values and those which are not present in `x` or `chains` to *be* resampled. @@ -1522,19 +1452,21 @@ julia> var_info[@varname(x[1])] # [✓] changed ## See also - [`setval!`](@ref) """ -function setval_and_resample!(vi::AbstractVarInfo, x) +function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) return setval_and_resample!(vi, values(x), keys(x)) end -function setval_and_resample!(vi::AbstractVarInfo, values, keys) +function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) return _apply!(_setval_and_resample_kernel!, vi, values, keys) end function setval_and_resample!( - vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int + vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!( + vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys +) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1549,94 +1481,8 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) return indices end -""" - values_as(varinfo[, Type]) - -Return the values/realizations in `varinfo` as `Type`, if implemented. - -If no `Type` is provided, return values as stored in `varinfo`. - -# Examples - -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] -``` - -`TypedVarInfo`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa DynamicPPL.Metadata -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 -``` - -`UntypedVarInfo`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa DynamicPPL.Metadata -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 -``` -""" values_as(vi::VarInfo) = vi.metadata +values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) diff --git a/test/sampler.jl b/test/sampler.jl index 959ec3ccd..ba1a8a600 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -15,7 +15,7 @@ @test length(chains) == N # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 diff --git a/test/serialization.jl b/test/serialization.jl index 7ea81e410..a2d9abb36 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -11,7 +11,7 @@ samples_m = last.(samples) @test mean(samples_s) ≈ 3 atol = 0.2 - @test mean(samples_m) ≈ 0 atol = 0.1 + @test mean(samples_m) ≈ 0 atol = 0.15 end @testset "pmap" begin # Add worker processes. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 6a8c545ca..a1bbfd503 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -58,6 +58,54 @@ end end + @testset "link!! & invlink!! on $(nameof(model))" for model in + DynamicPPL.TestUtils.DEMO_MODELS + values_constrained = rand(NamedTuple, model) + @testset "$(typeof(vi))" for vi in ( + SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) + ) + for vn in DynamicPPL.TestUtils.varnames(model) + vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) + end + vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + lp_orig = getlogp(vi) + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_linked = getlogp(vi_linked) + values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_constrained... + ) + # Should result in the correct logjoint. + @test lp_linked ≈ lp_linked_true + # Should be approx. the same as the "lazy" transformation. + @test logjoint(model, vi_linked) ≈ lp_linked + + # TODO: Should not `VarInfo` also error here? The current implementation + # only warns and acts as a no-op. + if vi isa SimpleVarInfo + @test_throws AssertionError link!!(vi_linked, model) + end + + # `invlink!!` + vi_invlinked = invlink!!(deepcopy(vi_linked), model) + lp_invlinked = getlogp(vi_invlinked) + lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( + model, values_constrained... + ) + # Should result in the correct logjoint. + @test lp_invlinked ≈ lp_invlinked_true + # Should be approx. the same as the "lazy" transformation. + @test logjoint(model, vi_invlinked) ≈ lp_invlinked + + # Should result in same values. + @test all( + DynamicPPL.getindex_raw(vi_invlinked, vn) ≈ get(values_constrained, vn) for + vn in DynamicPPL.TestUtils.varnames(model) + ) + end + end + @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need @@ -170,4 +218,56 @@ @test lp ≈ lp_true end end + + @testset "Static transformation" begin + model = DynamicPPL.TestUtils.demo_static_transformation() + + varinfos = setup_varinfos( + model, rand(NamedTuple, model), [@varname(s), @varname(m)] + ) + @testset "$(short_varinfo_name(vi))" for vi in varinfos + # Initialize varinfo and link. + vi_linked = DynamicPPL.link!!(vi, model) + + # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. + @test !DynamicPPL.istrans( + DynamicPPL.maybe_invlink_before_eval!!( + deepcopy(vi), SamplingContext(), model + ), + ) + + # Resulting varinfo should no longer be transformed. + vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + @test !DynamicPPL.istrans(vi_result) + + # Set the values to something that is out of domain if we're in constrained space. + for vn in keys(vi) + vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) + end + + retval, vi_linked_result = DynamicPPL.evaluate!!( + model, deepcopy(vi_linked), DefaultContext() + ) + + @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≠ retval.s # `s` is unconstrained in original + @test DynamicPPL.getindex_raw(vi_linked_result, @varname(s)) == retval.s # `s` is constrained in result + + # `m` should not be transformed. + @test vi_linked[@varname(m)] == retval.m + @test vi_linked_result[@varname(m)] == retval.m + + # Compare to truth. + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.s, retval.m + ) + + # Realizations in `vi_linked` should all be equal to the unconstrained realization. + @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≈ retval_unconstrained.s + @test DynamicPPL.getindex_raw(vi_linked, @varname(m)) ≈ retval_unconstrained.m + + # The resulting varinfo should hold the correct logp. + lp = getlogp(vi_linked_result) + @test lp ≈ lp_true + end + end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 8ded387d9..a94de4a29 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -45,20 +45,35 @@ @test hash(vn2) == hash(vn1) @test inspace(vn1, (:x,)) - function test_base!(vi) - empty!!(vi) + # Tests for `inspace` + space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) + + @test inspace(@varname(x), space) + @test inspace(@varname(y), space) + @test inspace(@varname(x[1]), space) + @test inspace(@varname(z[1][1]), space) + @test inspace(@varname(z[1][:]), space) + @test inspace(@varname(z[1][2:3:10]), space) + @test inspace(@varname(M[[2, 3], 1]), space) + @test inspace(@varname(M[:, 1:4]), space) + @test inspace(@varname(M[1, [2, 4, 6]]), space) + @test !inspace(@varname(z[2]), space) + @test !inspace(@varname(z), space) + + function test_base!!(vi_original) + vi = empty!!(vi_original) @test getlogp(vi) == 0 - @test get_num_produce(vi) == 0 + @test isempty(vi[:]) vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() + gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -68,37 +83,23 @@ @test vi[vn] == r @test vi[SampleFromPrior()][1] == r - vi[vn] = [2 * r] + vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r @test vi[SampleFromPrior()][1] == 2 * r - vi[SampleFromPrior()] = [3 * r] + vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!!(vi) + vi = empty!!(vi) @test isempty(vi) - push!!(vi, vn, r, dist, gid) - - function test_inspace() - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) - end - return test_inspace() + return push!!(vi, vn, r, dist, gid) end + vi = VarInfo() - test_base!(vi) - test_base!(empty!!(TypedVarInfo(vi))) + test_base!!(vi) + test_base!!(TypedVarInfo(vi)) + test_base!!(SimpleVarInfo()) + test_base!!(SimpleVarInfo(Dict())) end @testset "flags" begin # Test flag setting: