Skip to content

Commit

Permalink
Merge pull request #53 from hetalang/paramsfx
Browse files Browse the repository at this point in the history
ArrayPartition for parameters
  • Loading branch information
ivborissov committed Mar 8, 2024
2 parents fbe6a20 + ad04094 commit 572707b
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/HetaSimulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module HetaSimulator

# diffeq-related pkgs
using SciMLBase
using SciMLBase.RecursiveArrayTools: VectorOfArray, vecarr_to_vectors, DiffEqArray
using SciMLBase.RecursiveArrayTools: VectorOfArray, vecarr_to_vectors, DiffEqArray, ArrayPartition #, NamedArrayPartition
@reexport using SciMLBase.EnsembleAnalysis
@reexport using OrdinaryDiffEq
using Sundials
Expand Down
3 changes: 1 addition & 2 deletions src/estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ function estimator(
#update_init_values(selected_prob[i], last(selected_scenario_pairs[i]).init_func, x)
scn = last(selected_scenario_pairs[i])
params_total = merge_strict(scn.parameters, x)
u0, p0 = scn.init_func(params_total)
remake(deepcopy(selected_prob[i]); u0=u0, p=p0) # tmp fix, waiting for general solution
remake_prob(selected_prob[i], scn.init_func, params_total; safetycopy=true) #?safetycopy=false
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function add_event(evt::TimeEvent, events_save::Tuple{Bool, Bool}=(false,false),
#scn_func(u, t, integrator) = t in tstops

function init_time_event(cb,u,t,integrator)
append!(tstops, evt.condition_func(integrator.sol.prob.p, integrator.sol.prob.tspan))
append!(tstops, evt.condition_func(integrator.sol.prob.p.x[2], integrator.sol.prob.tspan))
tf = integrator.sol.prob.tspan[2]
[add_tstop!(integrator, tstop) for tstop in tstops if tstop <= tf]
#[add_tstop!(integrator, tstop) for tstop in tstops]
Expand Down
10 changes: 3 additions & 7 deletions src/monte_carlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,10 @@ function mc(

# check input names
scn_cons_names = collect(keys(scenario.parameters))
y_indexes = indexin(first.(parameters_variation), scn_cons_names)
y_lost = isnothing.(y_indexes)
cons_indexes = indexin(first.(parameters_variation), scn_cons_names)
y_lost = isnothing.(cons_indexes)
@assert !any(y_lost) "The following keys are not found: $(first.(parameters_variation)[y_lost])."

# changes reflecting the order of statics ans constants in params vector
len_statics = length(scenario.prob.p) - length(scn_cons_names)
cons_indexes = len_statics .+ y_indexes

parameters_variation_nt = NamedTuple(parameters_variation)

#(parallel_type == EnsembleSerial()) # tmp fix
Expand All @@ -82,7 +78,7 @@ function mc(

function _output(sol, i)
# take numbers from p
values_i = sol.prob.p[cons_indexes]
values_i = sol.prob.p.x[2][cons_indexes]
params_i = NamedTuple(zip(first.(parameters_variation), values_i))
# take simulated values from solution
sv = sol.prob.kwargs[:callback].discrete_callbacks[1].affect!.saved_values
Expand Down
28 changes: 20 additions & 8 deletions src/ode_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ function build_ode_problem( # used in Scenario constructor only
_saveat = isnothing(saveat) ? time_type[] : time_type.(saveat)

# init
u0, p0 = model.init_func(params)

u0, _p0 = model.init_func(params)

# !!! temporary workaround to match the order of statics and constants
# ArrayPartition should be replaced with NamedArrayPartition when moving to Julia 1.10
#p0 =NamedArrayPartition(statics=_p0[1:length(_p0)-length(params)], constants=collect(eltype(_p0), params))
p0 = ArrayPartition(_p0[1:length(_p0)-length(params)], collect(eltype(_p0), params))

# check observables
if !isnothing(observables_)
records_ind = indexin(observables_, records(model))
Expand Down Expand Up @@ -89,17 +94,24 @@ collect_saveat(saveat::AbstractRange{S}) where S<:Real = Float64.(saveat)
=#

function remake_prob(scen::Scenario, params::NamedTuple; safetycopy=true)
prob0 = safetycopy ? deepcopy(scen.prob) : scen.prob
params_total = merge_strict(scen.parameters, params)
remake_prob(scen.prob, scen.init_func, params_total; safetycopy)
end

function remake_prob(prob::ODEProblem, init_func::Function, params::NamedTuple; safetycopy=true)
prob0 = safetycopy ? deepcopy(prob) : prob
if length(params) > 0
params_total = merge_strict(scen.parameters, params)
u0, p0 = scen.init_func(params_total)
#u0, dep_p0 = init_func(params)
#p0 = NamedArrayPartition(statics=dep_p0, constants=params)
u0, _p0 = init_func(params)
# !!! temporary workaround to match the order of statics and constants
# ArrayPartition should be replaced with NamedArrayPartition when moving to Julia 1.10
#p0 = NamedArrayPartition(statics=_p0[1:length(_p0)-length(params)], constants=collect(eltype(_p0), params))
p0 = ArrayPartition(_p0[1:length(_p0)-length(params)], collect(eltype(_p0), params))
prob0.u0 .= u0
# tmp to if additional params are provided
length(prob0.p) == length(p0) ? prob0.p .= p0 : remake(prob0; p=p0)

return prob0
#return remake(prob0; u0=u0, p=p0)
#tmp. remake produces StackOverflow with EnsembleDistributed(), Julia 1.7 and SciMLBase >= 1.36.0
else
return prob0
end
Expand Down
35 changes: 27 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# check keys in y before merge
function merge_strict(x::NamedTuple, y::NamedTuple)

miss_keys = setdiff(keys(y), keys(x))
!isempty(miss_keys) && @warn "Keys $(miss_keys) not found."
# yidxs = findall(x->x ∉ keys(x), keys(y))
# @assert isempty(yidxs) "Cannot merge elements with keys $(keys(y)[yidxs]) in strict mode."

merge(x, y)
if length(y) > 0
miss_keys = setdiff(keys(y), keys(x))
!isempty(miss_keys) && @warn "Keys $(miss_keys) not found."
# yidxs = findall(x->x ∉ keys(x), keys(y))
# @assert isempty(yidxs) "Cannot merge elements with keys $(keys(y)[yidxs]) in strict mode."

return merge(x, y)
else
return x
end
end

dictkeys(d::Dict) = (collect(keys(d))...,)
Expand Down Expand Up @@ -88,4 +91,20 @@ function bool(s::AbstractString)
end
bool(b::Bool) = b

sanitizenames!(df::DataFrame) = rename!(df, strip.(names(df)))
sanitizenames!(df::DataFrame) = rename!(df, strip.(names(df)))

# tmp adding methods to ArrayPartition interface to support events
function Base.setindex!(A::ArrayPartition, X::AbstractArray, I::AbstractVector{Int})
Base.@_propagate_inbounds_meta
Base.@boundscheck Base.setindex_shape_check(X, length(I))
Base.require_one_based_indexing(X)
X′ = Base.unalias(A, X)
I′ = Base.unalias(A, I)
count = 1
for i in I′
@inbounds x = X′[count]
A[i] = x
count += 1
end
return A
end
8 changes: 8 additions & 0 deletions test/examples/single_comp_events/index.heta
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// heta build --source ./in_vitro_model/index.heta

include ./qsp-units.heta
// include model.heta

include one-compartment.heta

// #export { format: DBSolve, filepath: dbsolve_in_vitro };
47 changes: 47 additions & 0 deletions test/examples/single_comp_events/one-compartment.heta
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//heta build --source ./in_vitro_model/index.heta

t {units: hour};

F @Const {units: UL} = 0.97;
dose @Const {units: pg} = 10;
kabs @Const {units: 1/h} = 1e-2;
kel @Const {units: 1/h} = 1.2e-3;

time_start @Const {units: h} = 100;
time_end @Const {units: h} = 100.1;
time_inj @Const {units: h} = .1;
dose_inj @Const {units: pg} = 1;
switch @Record {units: UL} .= 0;

gut @Compartment {units: L} .= 1;
Vd @Compartment {units: L} .= 5;

a0 @Species {
compartment: gut,
units: pg,
output: true,
isAmount: true
};
c1 @Species {
compartment: Vd,
units: pg/L,
output: true
};
r_inj @Reaction { actors: > a0, units: pg/h };
r_abs @Reaction { actors: a0 > c1, units: pg/h };
r_el @Reaction { actors: c1 > , units: pg/h };

a0 .= F * dose;
c1 .= 0;

r_inj := switch * dose_inj / time_inj;
r_abs := kabs * a0;
r_el := kel * c1 * Vd;

sw @TimeSwitcher { start: time_start, active: false };
switch [sw]= 1;

sw_end @TimeSwitcher { start: time_end, active: false };
switch [sw_end]= 0;

// #export { format: SBML, filepath: sbml_output };
153 changes: 153 additions & 0 deletions test/examples/single_comp_events/qsp-units.heta
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
fmole #defineUnit {
units: [ { kind: mole, multiplier: 1e-15 } ]
};
pmole #defineUnit {
units: [ { kind: mole, multiplier: 1e-12 } ]
};
nmole #defineUnit {
units: [ { kind: mole, multiplier: 1e-9 } ]
};
umole #defineUnit {
units: [ { kind: mole, multiplier: 1e-6 } ]
};
mmole #defineUnit {
units: [ { kind: mole, multiplier: 1e-3 } ]
};
fM #defineUnit {
units: [ { kind: mole, multiplier: 1e-15 }, { kind: litre, exponent: -1} ]
};

pM #defineUnit {
units: [ { kind: mole, multiplier: 1e-12 }, { kind: litre, exponent: -1} ]
};
nM #defineUnit {
units: [ { kind: mole, multiplier: 1e-9 }, { kind: litre, exponent: -1} ]
};
uM #defineUnit {
units: [ { kind: mole, multiplier: 1e-6 }, { kind: litre, exponent: -1} ]
};
mM #defineUnit {
units: [ { kind: mole, multiplier: 1e-3 }, { kind: litre, exponent: -1} ]
};
M #defineUnit {
units: [ { kind: mole }, { kind: litre, exponent: -1} ]
};
kM #defineUnit {
units: [ { kind: mole, multiplier: 1e3 }, { kind: litre, exponent: -1 } ]
};

fL #defineUnit {
units: [ { kind: litre, multiplier: 1e-15 } ]
};
pL #defineUnit {
units: [ { kind: litre, multiplier: 1e-12 } ]
};
nL #defineUnit {
units: [ { kind: litre, multiplier: 1e-9 } ]
};
uL #defineUnit {
units: [ { kind: litre, multiplier: 1e-6 } ]
};
mL #defineUnit {
units: [ { kind: litre, multiplier: 1e-3 } ]
};
dL #defineUnit {
units: [ { kind: litre, multiplier: 1e-1 } ]
};
L #defineUnit {
units: [ { kind: litre } ]
};

fs #defineUnit {
units: [ { kind: second, multiplier: 1e-15 } ]
};
ps #defineUnit {
units: [ { kind: second, multiplier: 1e-12 } ]
};
ns #defineUnit {
units: [ { kind: second, multiplier: 1e-9 } ]
};
us #defineUnit {
units: [ { kind: second, multiplier: 1e-6 } ]
};
ms #defineUnit {
units: [ { kind: second, multiplier: 1e-3 } ]
};
s #defineUnit {
units: [ { kind: second } ]
};
h #defineUnit {
units: [ { kind: hour, multiplier: 1 } ]
};
week #defineUnit {
units: [ { kind: day, multiplier: 7 } ]
};

fg #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-18 } ]
};
pg #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-15 } ]
};
ng #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-12 } ]
};
ug #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-9 } ]
};
mg #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-6 } ]
};
g #defineUnit {
units: [ { kind: kilogram, multiplier: 1e-3 } ]
};
kg #defineUnit {
units: [ { kind: kilogram } ]
};

kat #defineUnit {
units: [ { kind: katal } ]
};

cell #defineUnit {
units: [ { kind: item } ]
};
kcell #defineUnit {
units: [ { kind: item, multiplier: 1e3 } ]
};

cal #defineUnit {
units: [ { kind: joule, multiplier: 4.1868 } ]
};
kcal #defineUnit {
units: [ { kind: joule, multiplier: 4.1868e3 } ]
};

fm #defineUnit {
units: [ { kind: metre, multiplier: 1e-15 } ]
};
pm #defineUnit {
units: [ { kind: metre, multiplier: 1e-12 } ]
};
nm #defineUnit {
units: [ { kind: metre, multiplier: 1e-9 } ]
};
um #defineUnit {
units: [ { kind: metre, multiplier: 1e-6 } ]
};
mm #defineUnit {
units: [ { kind: metre, multiplier: 1e-13 } ]
};
cm #defineUnit {
units: [ { kind: metre, multiplier: 1e-2 } ]
};
m #defineUnit {
units: [ { kind: metre } ]
};

UL #defineUnit {
units: [ { kind: dimensionless } ]
};
percent #defineUnit {
units: [ { kind: dimensionless, multiplier: 1e-2 } ]
};
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ end

@testset "HetaSimulator" begin
@testset "Single-compartment model without events" begin include("single_comp_test.jl") end
@testset "Single-compartment model with events" begin include("single_comp_events_test.jl") end
@testset "Functions used in heta models" begin include("heta_funcs_test.jl") end
end

12 changes: 12 additions & 0 deletions test/single_comp_events_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
platform = load_platform("$HetaSimulatorDir/test/examples/single_comp_events");
model = platform.models[:nameless];

scn0 = Scenario(model, (0.,4000.); observables=[:a0,:c1])
s0 = sim(scn0; tstops=[100., 100.1])
@test isapprox(s0(100.1)[:a0]-s0(100.)[:a0], 0.0; atol=1e-2)
@test isapprox(s0(100.1)[:c1]-s0(100.)[:c1], 0.0; atol=1e-2)

scn1 = Scenario(model, (0.,4000.); observables=[:a0,:c1], events_active=[:sw=>true, :sw_end=>true], events_save=(true,true))
s1 = sim(scn1)
@test isapprox(s1(100.1)[:a0]-s1(100.)[:a0], 1.0; atol=1e-2)
@test isapprox(s1(100.1)[:c1]-s1(100.)[:c1], 0.0; atol=1e-2)

0 comments on commit 572707b

Please sign in to comment.