Skip to content

Commit

Permalink
Update to use MTK v9 (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor authored May 24, 2024
1 parent c9c6995 commit 3d725b0
Show file tree
Hide file tree
Showing 7 changed files with 707 additions and 578 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-threaded.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1.9'
version: '1.10'
- uses: julia-actions/cache@v1
with:
cache-compiled: "true"
Expand Down
1,239 changes: 683 additions & 556 deletions Manifest.toml

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
julia = "1.10.1"
ModelingToolkit = "9.13.0"
julia = "1.10"
EasyModelAnalysis = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 1 addition & 1 deletion src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import JSON3
import JSONSchema
import LinearAlgebra: norm
import MathML
import ModelingToolkit: @parameters, substitute, Differential, Num, @variables, ODESystem, ODEProblem, ODESolution, structural_simplify, states, observed, parameters
import ModelingToolkit: @parameters, substitute, Differential, Num, @variables, ODESystem, ODEProblem, ODESolution, structural_simplify, unknowns, observed, parameters
import OpenAPI
import Oxygen
import Pkg
Expand Down
2 changes: 1 addition & 1 deletion src/model_parsers/RegNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,5 @@ function ASKEM_ACSet_to_MTK(sg::ASKEMRegNetUntyped)

eqs = [D(vertex_funcs[i]) ~ v_rate_vars[i]*vertex_funcs[i] + sum((sg[e,:sign] ? 1 : -1) * e_rate_vars[e] * vertex_funcs[i] * vertex_funcs[sg[e,:src]] for e in incident(sg,i,:tgt); init = 0.0) for i in 1:nv(sg)]

sys = ODESystem(eqs, t, vertex_funcs, all_params; name = :system, defaults)
sys = structural_simplify(ODESystem(eqs, t, vertex_funcs, all_params; name = :system, defaults))
end
16 changes: 8 additions & 8 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ end
# data
function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data})
@info "parse dataset into calibrate format"
statelist = states(sys)
statelist = unknowns(sys)
statenames = string.(statelist)
statenames = [replace(nm, "(t)" => "") for nm in statenames]

Expand All @@ -188,7 +188,7 @@ function (o::IntermediateResults)(integrator)
(; iter, f, t, u, p) = integrator
if o.last_callback + o.every == iter
o.last_callback = iter
state_dict = Dict(states(f.sys) .=> u)
state_dict = Dict(unknowns(f.sys) .=> u)
param_dict = Dict(parameters(f.sys) .=> p)
publish_to_rabbitmq(; iter=iter, state=state_dict, params = param_dict, id=o.id,
retcode=SciMLBase.check_error(integrator))
Expand All @@ -197,20 +197,20 @@ function (o::IntermediateResults)(integrator)
end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(p,lossval, ode_sol, ts)
function (o::IntermediateResults)(state,loss_val, ode_sol, ts)
if o.last_callback + o.every == o.iter
o.last_callback = o.iter
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol(first(ts))[state] for state in states(ode_sol.prob.f.sys)])
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, timesteps = first(ts), params = param_dict, id=o.id)
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> state.u)
state_dict = Dict([state => ode_sol(first(ts))[state] for state in unknowns(ode_sol.prob.f.sys)])
publish_to_rabbitmq(; iter = o.iter, loss = loss_val, sol_data = state_dict, timesteps = first(ts), params = param_dict, id=o.id)
end
o.iter = o.iter + 1
return false
end
#----------------------------------------------------------------------# dataframe_with_observables
function dataframe_with_observables(sol::ODESolution)
sys = sol.prob.f.sys
names = [states(sys); getproperty.(observed(sys), :lhs)]
names = [unknowns(sys); getproperty.(observed(sys), :lhs)]
cols = ["timestamp" => sol.t; [string(n) => sol[n] for n in names]]
DataFrame(cols)
end
Expand Down Expand Up @@ -280,7 +280,7 @@ end

function solve(o::Calibrate; callback)
prob = ODEProblem(o.sys, [], o.timespan)
statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
statenames = [unknowns(o.sys);getproperty.(observed(o.sys), :lhs)]

# bayesian datafit
if o.calibrate_method == "bayesian"
Expand Down
20 changes: 10 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ end
@testset "Petrinet AMR parsing" begin
amr = JSON3.read(HTTP.get("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/sidarthe.json").body)
sys = SimulationService.amr_get(amr.configuration, ODESystem)
@test string.(states(sys)) == ["Susceptible(t)", "Diagnosed(t)", "Infected(t)", "Ailing(t)", "Recognized(t)", "Healed(t)", "Threatened(t)", "Extinct(t)"]
@test string.(unknowns(sys)) == ["Susceptible(t)", "Diagnosed(t)", "Infected(t)", "Ailing(t)", "Recognized(t)", "Healed(t)", "Threatened(t)", "Extinct(t)"]
@test string.(parameters(sys)) == ["beta", "gamma", "delta", "alpha", "epsilon", "zeta", "lambda", "eta", "rho", "theta", "kappa", "mu", "nu", "xi", "tau", "sigma"]
@test map(x->string(x.lhs), observed(sys)) == ["Cases(t)", "Hospitalizations(t)", "Deaths(t)"]

Expand All @@ -117,14 +117,14 @@ end
df = CSV.read(HTTP.get("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/datasets/SIDARTHE_dataset.csv").body,DataFrame)
data = SimulationService.amr_get(df, sys, Val(:data))
@test data isa Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Tuple{Vector{Int64}, Vector{Float64}}}}
@test string.(first.(data)) == string.(states(sys))
@test string.(first.(data)) == string.(unknowns(sys))
@test all(all.(map(first.(last.(data))) do x; x .== 0:89; end))
end

@testset "Stock and Flow AMR parsing" begin
amr = JSON3.read(HTTP.get("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/SIR_stockflow.json").body)
sys = SimulationService.amr_get(amr, ODESystem)
@test string.(states(sys)) == ["S(t)", "I(t)", "R(t)"]
@test string.(unknowns(sys)) == ["S(t)", "I(t)", "R(t)"]
@test string.(parameters(sys)) == ["p_cbeta","p_N", "p_tr"]

priors = SimulationService.amr_get(amr, sys, Val(:priors))
Expand All @@ -138,7 +138,7 @@ end
@testset "RegNet AMR parsing" begin
amr = JSON3.read(HTTP.get("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/LV_sheep_foxes_regnet.json").body)
sys = SimulationService.amr_get(amr, ODESystem)
@test string.(states(sys)) == ["S(t)", "F(t)"]
@test string.(unknowns(sys)) == ["S(t)", "F(t)"]
@test string.(parameters(sys)) == ["beta","delta", "alpha","gamma"]

priors = SimulationService.amr_get(amr, sys, Val(:priors))
Expand Down Expand Up @@ -247,19 +247,19 @@ end

dfsim, dfparam = solve(o, callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys); getproperty.(observed(o.sys), :lhs)]
statenames = [unknowns(o.sys); getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",reduce(vcat,[string.("ensemble",i,"_", statenames) for i in 1:size(dfsim,2)÷length(statenames)]))
@test names(dfparam) == string.(parameters(sys))

#calibrate_method = "local"
#o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)
#dfsim, dfparam = SimulationService.solve(o; callback = nothing)
calibrate_method = "local"
o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)
dfsim, dfparam = SimulationService.solve(o; callback = SimulationService.get_callback(op,SimulationService.Calibrate))

calibrate_method = "global"
o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)
dfsim, dfparam = SimulationService.solve(o; callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
statenames = [unknowns(o.sys);getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",string.(statenames))
@test names(dfparam) == string.(parameters(sys))
end
Expand Down Expand Up @@ -415,7 +415,7 @@ end
op.id = "1"
dfsim, dfparam = SimulationService.solve(o, callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
statenames = [unknowns(o.sys);getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",string.(statenames))
end
end
Expand Down

0 comments on commit 3d725b0

Please sign in to comment.