diff --git a/Project.toml b/Project.toml index 3d7f636f9..b00eb4531 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.12.4" +version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" -AbstractPPL = "0.1.2" +AbstractPPL = "0.2" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" diff --git a/src/compiler.jl b/src/compiler.jl index 91fe78e2b..c70bbff1e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -98,8 +98,12 @@ end function unwrap_right_left_vns( right::MultivariateDistribution, left::AbstractMatrix, vn::VarName ) + # This an expression such as `x .~ MvNormal()` which we interpret as + # x[:, i] ~ MvNormal() + # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, + # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return VarName(vn, (vn.indexing..., Tuple(i))) + return VarName(vn, (vn.indexing..., Colon(), Tuple(i))) end return unwrap_right_left_vns(right, left, vns) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..cd7a92535 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,7 +14,7 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume @@ -227,11 +227,8 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - if !haskey(vi, vn) - error("variable $vn does not exist") - end r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end # SampleFromPrior and SampleFromUniform @@ -430,12 +427,13 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + r = vi[vns] lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -462,7 +460,7 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + r = reshape(vi[vec(vns)], size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) return r, lp end diff --git a/src/varinfo.jl b/src/varinfo.jl index d7991c73d..15a9bb02b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -511,6 +511,8 @@ end end @inline function findranges(f_ranges, f_idcs) + # Old implementation was using `mapreduce` but turned out + # to be type-unstable. results = Int[] for i in f_idcs append!(results, f_ranges[i]) diff --git a/test/Project.toml b/test/Project.toml index 37c509610..9ca62c79e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0" -AbstractPPL = "0.1.3" +AbstractPPL = "0.1.4, 0.2" Bijectors = "0.9.5" Distributions = "< 0.25.11" DistributionsAD = "0.6.3" diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 67b8d5645..05e6fb55e 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.12" +DynamicPPL = "0.13" Turing = "0.15, 0.16" julia = "1.3"