diff --git a/tutorials/10-bayesian-stochastic-differential-equations/index.qmd b/tutorials/10-bayesian-stochastic-differential-equations/index.qmd
index 6166ec270..785d77f26 100644
--- a/tutorials/10-bayesian-stochastic-differential-equations/index.qmd
+++ b/tutorials/10-bayesian-stochastic-differential-equations/index.qmd
@@ -17,6 +17,7 @@ Pkg.instantiate();
 ```{julia}
 using Turing
 using DifferentialEquations
+using DifferentialEquations.EnsembleAnalysis
 
 # Load StatsPlots for visualizations and diagnostics.
 using StatsPlots
@@ -117,6 +118,10 @@ prob_sde = SDEProblem(lotka_volterra!, multiplicative_noise!, u0, tspan, p)
 ensembleprob = EnsembleProblem(prob_sde)
 data = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories=1000)
 plot(EnsembleSummary(data))
+
+# We generate new noisy observations based on the stochastic model for the parameter estimation tasks in this tutorial.
+# We create our observations by adding random normally distributed noise to the mean of the ensemble simulation.
+sdedata = reduce(hcat, timeseries_steps_mean(data).u) + 0.8 * randn(size(reduce(hcat, timeseries_steps_mean(data).u)))
 ```
 
 ```{julia}
@@ -132,17 +137,24 @@ plot(EnsembleSummary(data))
     # Simulate stochastic Lotka-Volterra model.
     p = [α, β, γ, δ, ϕ1, ϕ2]
-    predicted = solve(prob, SOSRI(); p=p, saveat=0.1)
+    prob = remake(prob; p=p)
+    ensembleprob = EnsembleProblem(prob)
+    predicted = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories=1000)
 
     # Early exit if simulation could not be computed successfully.
-    if predicted.retcode !== :Success
-        Turing.@addlogprob! -Inf
-        return nothing
+    for i in 1:length(predicted)
+        if !SciMLBase.successful_retcode(predicted[i])
+            Turing.@addlogprob! -Inf
+            return nothing
+        end
     end
 
     # Observations.
-    for i in 1:length(predicted)
-        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
+    # We compute the likelihood for each trajectory of the simulation to better approximate the overall likelihood of the chosen parameters.
+    for j in 1:length(predicted)
+        for i in 1:length(predicted[j])
+            data[:, i] ~ MvNormal(predicted[j][i], σ^2 * I)
+        end
     end
 
     return nothing
@@ -154,9 +166,8 @@ Therefore we use NUTS with a low target acceptance rate of `0.25` and specify a
 SGHMC might be a more suitable algorithm to be used here.
 
 ```{julia}
-model_sde = fitlv_sde(odedata, prob_sde)
+model_sde = fitlv_sde(sdedata, prob_sde)
 
-setadbackend(:forwarddiff)
 chain_sde = sample(
     model_sde,
     NUTS(0.25),
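
For reference, here is a standalone, illustrative sketch (not part of the patch) of the computation the new model body performs: it remakes the problem with the sampled parameters, solves an ensemble, rejects the draw if any trajectory failed, and accumulates a Gaussian log-density per trajectory at each save point. The helper `ensemble_loglik`, its `trajectories` keyword, and the example parameter values are hypothetical; `prob_sde`, `sdedata`, and `σ` refer to the objects defined in the tutorial, and `SciMLBase` is assumed accessible via `using DifferentialEquations` as in the patch.

```julia
# Illustrative sketch only; mirrors the model body added in this patch.
# `prob_sde`, `sdedata`, and `σ` are the tutorial's objects; `ensemble_loglik`
# and its `trajectories` keyword are hypothetical names for this example.
using DifferentialEquations, Distributions, LinearAlgebra

function ensemble_loglik(prob, data, p, σ; trajectories=100)
    # remake returns a new problem with the given parameters; it does not mutate `prob`.
    prob = remake(prob; p=p)
    sols = solve(EnsembleProblem(prob), SOSRI(); saveat=0.1, trajectories=trajectories)
    ll = 0.0
    for j in 1:length(sols)
        # Reject the parameter draw if any trajectory failed, as the model does.
        SciMLBase.successful_retcode(sols[j]) || return -Inf
        for i in 1:length(sols[j])
            # Gaussian observation density of the data at each save point of this trajectory.
            ll += logpdf(MvNormal(sols[j][i], σ^2 * I), data[:, i])
        end
    end
    return ll
end

# Example call with illustrative parameter values [α, β, γ, δ, ϕ1, ϕ2] and σ = 0.8:
# ensemble_loglik(prob_sde, sdedata, [1.5, 1.0, 3.0, 1.0, 0.1, 0.1], 0.8)
```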