Fixes to Bayes SDE notebook #449

Merged: 2 commits, May 27, 2024
28 changes: 19 additions & 9 deletions tutorials/10-bayesian-stochastic-differential-equations/index.qmd
@@ -17,6 +17,7 @@ Pkg.instantiate();
```{julia}
using Turing
using DifferentialEquations
using DifferentialEquations.EnsembleAnalysis

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots
@@ -117,6 +118,10 @@ prob_sde = SDEProblem(lotka_volterra!, multiplicative_noise!, u0, tspan, p)
ensembleprob = EnsembleProblem(prob_sde)
data = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories=1000)
plot(EnsembleSummary(data))

# Generate noisy observations for the parameter-estimation tasks in this tutorial
# by adding normally distributed noise to the step-wise mean of the ensemble simulation.
ensemble_mean = reduce(hcat, timeseries_steps_mean(data).u)
sdedata = ensemble_mean + 0.8 * randn(size(ensemble_mean))
```
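For context, the `reduce(hcat, timeseries_steps_mean(data).u)` pattern stacks the per-step ensemble means into a 2×N matrix before observation noise is added. A minimal self-contained sketch of the same idea on hypothetical stand-in data (no SciML dependency; `fake_trajectories` and `obs` are illustrative names, not part of the tutorial):

```julia
using Random, Statistics

Random.seed!(1)
# 1000 hypothetical trajectories, each a 2×101 matrix (2 species, 101 saved steps),
# centered around 5.0 to mimic a noisy ensemble of solutions.
fake_trajectories = [randn(2, 101) .+ 5.0 for _ in 1:1000]
# Step-wise ensemble mean, analogous to timeseries_steps_mean(data).
ensemble_mean = sum(fake_trajectories) / length(fake_trajectories)
# Add N(0, 0.8²) observation noise, as done for sdedata above.
obs = ensemble_mean + 0.8 * randn(size(ensemble_mean))
size(obs)  # (2, 101): one column per saved time step
```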

```{julia}
@@ -132,17 +137,23 @@ plot(EnsembleSummary(data))

# Simulate stochastic Lotka-Volterra model.
p = [α, β, γ, δ, ϕ1, ϕ2]
predicted = solve(prob, SOSRI(); p=p, saveat=0.1)
prob = remake(prob; p=p)  # remake is non-mutating, so the result must be assigned
ensembleprob = EnsembleProblem(prob)
predicted = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories = 1000)

# Early exit if simulation could not be computed successfully.
if predicted.retcode !== :Success
Turing.@addlogprob! -Inf
return nothing
for i in 1:length(predicted)
    if !SciMLBase.successful_retcode(predicted[i])
        Turing.@addlogprob! -Inf
        return nothing
    end
end

# Observations.
for i in 1:length(predicted)
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
for j in 1:length(predicted)
    for i in 1:length(predicted[j])
        data[:, i] ~ MvNormal(predicted[j][i], σ^2 * I)
Review comment (Member): Consider adding a bit of explanation on the reasoning behind this likelihood function, e.g. following discussions in TuringLang/Turing.jl#2216
    end
end

return nothing
@@ -154,9 +165,8 @@ Therefore we use NUTS with a low target acceptance rate of `0.25` and specify a
SGHMC might be a more suitable algorithm to be used here.
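A hedged sketch of what the suggested SGHMC alternative could look like, assuming Turing's keyword constructor `SGHMC(; learning_rate, momentum_decay)` (check the current Turing.jl docs before use; the hyperparameter values below are illustrative guesses, not tuned settings):

```julia
using Turing

# Hypothetical alternative sampler configuration in place of NUTS(0.25).
sghmc_sampler = SGHMC(; learning_rate=1e-5, momentum_decay=0.1)
# chain_sghmc = sample(model_sde, sghmc_sampler, 5000; progress=false)
```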

```{julia}
model_sde = fitlv_sde(odedata, prob_sde)
model_sde = fitlv_sde(sdedata, prob_sde)

setadbackend(:forwarddiff)
chain_sde = sample(
model_sde,
NUTS(0.25),
@@ -165,4 +175,4 @@ chain_sde = sample(
progress=false,
)
plot(chain_sde)
```
```