Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored Sep 6, 2023
1 parent a18b435 commit 34e422a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
7 changes: 5 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ end

# TODO: Add proper overload of `Base.getindex` to Turing.jl?
function _getindex(c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx)
DynamicPPL.supports_varname_indexing(c) || error("Chains do not support indexing using $vn.")
DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end

Expand All @@ -23,7 +24,9 @@ function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCCha
for vn in keys(chain.info.varname_to_symbol)
# FIXME: Make it so we can support `chain[sample_idx, vn, chain_idx]`
# indexing instead of the `chain[vn][sample_idx, chain_idx]` below.
DynamicPPL.nested_setindex!(varinfo, _getindex(chain, sample_idx, vn, chain_idx), vn)
DynamicPPL.nested_setindex!(
varinfo, _getindex(chain, sample_idx, vn, chain_idx), vn
)
end
else
# NOTE: This can be quite unreliable (but will warn the uesr in that case).
Expand Down
3 changes: 1 addition & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1061,8 +1061,7 @@ end

# HACK: This should not be here.
@generated function ConstructionBase.setproperties(
C::LinearAlgebra.Cholesky,
patch::NamedTuple{names}
C::LinearAlgebra.Cholesky, patch::NamedTuple{names}
) where {names}
# Return early if we need be.
(:L in names && :U in names) && return :(error("Cannot set both L and U"))
Expand Down
4 changes: 3 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,9 @@ function _nested_getindex(varinfo::VarInfo, md::Metadata, vn::VarName)
return get(val, lens)
end

nested_setindex!(vi::VarInfo, val, vn::VarName) = _nested_setindex!(vi, getmetadata(vi, vn), val, vn)
function nested_setindex!(vi::VarInfo, val, vn::VarName)
return _nested_setindex!(vi, getmetadata(vi, vn), val, vn)
end
function _nested_setindex!(vi::VarInfo, md::Metadata, val, vn::VarName)
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
vns = md.vns
Expand Down
4 changes: 1 addition & 3 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ end
vns_to_syms = OrderedDict(zip(vns, syms))

chain = MCMCChains.Chains(
permutedims(stack(vals)),
syms;
info = (varname_to_symbol = vns_to_syms,)
permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,)
)
display(chain)

Expand Down

0 comments on commit 34e422a

Please sign in to comment.