From a425c41e0274efd408a028851302cf1267f2cea1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 18 Nov 2024 11:06:43 +0000 Subject: [PATCH] fix test error by discard burn-in's --- test/ext/DynamicPPLMCMCChainsExt.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 111ee7fbf..b7888acf6 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -35,11 +35,13 @@ end # Infer m_lin_reg = linear_reg(xs_train, ys_train) chain_lin_reg = sample( - DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)), + DynamicPPL.LogDensityFunction(m_lin_reg), AdvancedHMC.NUTS(0.65), - 200; + 1000; chain_type=MCMCChains.Chains, param_names=[:β], + discard_initial=100, + n_adapt=100, ) # Predict on two last indices @@ -156,9 +158,11 @@ end chain = sample( DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)), AdvancedHMC.NUTS(0.65), - 100; + 1000; chain_type=MCMCChains.Chains, param_names=param_names[model], + discard_initial=100, + n_adapt=100, ) chain_predict = DynamicPPL.predict(model(x, missing), chain) mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]