Skip to content

Commit

Permalink
fix test error by discard burn-in's
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 18, 2024
1 parent c7d08b0 commit a425c41
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions test/ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ end
# Infer
m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
DynamicPPL.LogDensityFunction(m_lin_reg),
AdvancedHMC.NUTS(0.65),
200;
1000;
chain_type=MCMCChains.Chains,
param_names=[],
discard_initial=100,
n_adapt=100,
)

# Predict on two last indices
Expand Down Expand Up @@ -156,9 +158,11 @@ end
chain = sample(
DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),
AdvancedHMC.NUTS(0.65),
100;
1000;
chain_type=MCMCChains.Chains,
param_names=param_names[model],
discard_initial=100,
n_adapt=100,
)
chain_predict = DynamicPPL.predict(model(x, missing), chain)
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
Expand Down

0 comments on commit a425c41

Please sign in to comment.