Skip to content

Commit

Permalink
Some small improvements (#291)
Browse files Browse the repository at this point in the history
A couple of small improvements/fixes that I noticed recently.

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Jul 27, 2021
1 parent 5609335 commit 5472d9d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 10 deletions.
6 changes: 5 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ end
function unwrap_right_left_vns(
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
)
# This an expression such as `x .~ MvNormal()` which we interpret as
# x[:, i] ~ MvNormal()
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down
12 changes: 5 additions & 7 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple{}) = x

# assume
Expand Down Expand Up @@ -227,11 +227,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
if !haskey(vi, vn)
error("variable $vn does not exist")
end
r = vi[vn]
return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn))
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
end

# SampleFromPrior and SampleFromUniform
Expand Down Expand Up @@ -430,12 +427,13 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior())
r = vi[vns]
lp = sum(zip(vns, eachcol(r))) do vn, ri
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
return r, lp
end

function dot_assume(
rng,
spl::Union{SampleFromPrior,SampleFromUniform},
Expand All @@ -462,7 +460,7 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior())
r = reshape(vi[vec(vns)], size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
return r, lp
end
Expand Down
2 changes: 2 additions & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ end
end

@inline function findranges(f_ranges, f_idcs)
# Old implementation was using `mapreduce` but turned out
# to be type-unstable.
results = Int[]
for i in f_idcs
append!(results, f_ranges[i])
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.1.3"
AbstractPPL = "0.1.4, 0.2"
Bijectors = "0.9.5"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
Expand Down
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DynamicPPL = "0.12"
DynamicPPL = "0.13"
Turing = "0.15, 0.16"
julia = "1.3"

2 comments on commit 5472d9d

@yebai
Copy link
Member

@yebai yebai commented on 5472d9d Jul 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41839

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" 5472d9dcd4b7d5b16b962ecf7b200274c58791fd
git push origin v0.13.0

Please sign in to comment.