Skip to content

Commit

Permalink
Merge branch 'master' into tor/symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Jul 29, 2021
2 parents 23141be + 5472d9d commit a787e56
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 12 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.12.4"
version = "0.13.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "2, 3.0"
AbstractPPL = "0.1.2"
AbstractPPL = "0.2"
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
ChainRulesCore = "0.9.7, 0.10"
Distributions = "0.23.8, 0.24, 0.25"
Expand Down
6 changes: 5 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ end
function unwrap_right_left_vns(
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
)
# This an expression such as `x .~ MvNormal()` which we interpret as
# x[:, i] ~ MvNormal()
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down
12 changes: 5 additions & 7 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple{}) = x

# assume
Expand Down Expand Up @@ -227,11 +227,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
if !haskey(vi, vn)
error("variable $vn does not exist")
end
r = vi[vn]
return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn))
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
end

# SampleFromPrior and SampleFromUniform
Expand Down Expand Up @@ -430,12 +427,13 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior())
r = vi[vns]
lp = sum(zip(vns, eachcol(r))) do vn, ri
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
return r, lp
end

function dot_assume(
rng,
spl::Union{SampleFromPrior,SampleFromUniform},
Expand All @@ -462,7 +460,7 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior())
r = reshape(vi[vec(vns)], size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
return r, lp
end
Expand Down
2 changes: 2 additions & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ end
end

@inline function findranges(f_ranges, f_idcs)
# Old implementation was using `mapreduce` but turned out
# to be type-unstable.
results = Int[]
for i in f_idcs
append!(results, f_ranges[i])
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.1.3"
AbstractPPL = "0.1.4, 0.2"
Bijectors = "0.9.5"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
Expand Down
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DynamicPPL = "0.12"
DynamicPPL = "0.13"
Turing = "0.15, 0.16"
julia = "1.3"

0 comments on commit a787e56

Please sign in to comment.