Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 13, 2024
1 parent 1c1c907 commit bdf90b4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 12 deletions.
6 changes: 0 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,6 @@ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
[-1.704630494695469]
```
"""
function transitions_from_chain(
model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
)
return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
end

function transitions_from_chain(
rng::DynamicPPL.Random.AbstractRNG,
model::DynamicPPL.Model,
Expand Down
11 changes: 5 additions & 6 deletions test/ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,16 @@ end
return y ~ MvNormal(mu, error^2 * I)
end

# Some data
x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

param_names = Dict(
simple_linear1 => [:intercept, :coef],
simple_linear2 => [:intercept, :coef],
simple_linear3 => [:intercept, Symbol.(["coef[$i]" for i in 1:2])...],
simple_linear4 => [:intercept, :coef1, :coef2],
simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear4 => [:intercept, :coef1, :coef2, :error],
)
for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
@testset "$model" for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
m = model(x, y)
chain = sample(
DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),
Expand Down

0 comments on commit bdf90b4

Please sign in to comment.