Skip to content

Commit

Permalink
remake_prob w safetycopy added
Browse files Browse the repository at this point in the history
  • Loading branch information
ivborissov committed Jul 11, 2023
1 parent 2b54510 commit 4ec17e4
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 13 deletions.
17 changes: 11 additions & 6 deletions src/estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ function estimator(
scenario_pairs::AbstractVector{Pair{Symbol, C}},
parameters_fitted::Vector{Pair{Symbol,Float64}};
parameters::Union{Nothing, Vector{P}}=nothing,
#parameters::Vector{Pair{Symbol, Float64}}=Pair{Symbol, Float64}[],
alg=DEFAULT_ALG,
reltol=DEFAULT_FITTING_RELTOL,
abstol=DEFAULT_FITTING_ABSTOL,
parallel_type=EnsembleSerial(),
kwargs... # other arguments to sim
) where {C<:AbstractScenario, P<:Pair}
) where {C<:AbstractScenario,P<:Pair}

# names of parameters used in fitting and saved in parameters field of solution
parameters_fitted_names = first.(parameters_fitted)
Expand All @@ -62,15 +63,19 @@ function estimator(

# update saveat and initial values
selected_prob = []
for scn in selected_scenario_pairs
prob_i = remake_saveat(last(scn).prob, last(scn).measurements)
for scen_pair in selected_scenario_pairs
scen = last(scen_pair)
prob_i = !isnothing(parameters) ? remake_prob(scen, NamedTuple(parameters); safetycopy=true) : deepcopy(scen.prob)
prob_i = remake_saveat(prob_i, scen.measurements)
#=
prob_i = if !isnothing(parameters)
constants_total_i = merge_strict(last(scn).parameters, NamedTuple(parameters))
u0, p0 = last(scn).init_func(constants_total_i)
constants_total_i = merge_strict(scen.parameters, NamedTuple(parameters))
u0, p0 = scen.init_func(constants_total_i)
remake(prob_i; u0=u0, p=p0)
else
prob_i
end
=#
push!(selected_prob, prob_i)
end

Expand All @@ -80,7 +85,7 @@ function estimator(
scn = last(selected_scenario_pairs[i])
constants_total = merge_strict(scn.parameters, x)
u0, p0 = scn.init_func(constants_total)
remake(selected_prob[i]; u0=u0, p=p0)
remake(deepcopy(selected_prob[i]); u0=u0, p=p0) # tmp fix, waiting for general solution
end
end

Expand Down
15 changes: 13 additions & 2 deletions src/monte_carlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,15 @@ function mc(
function prob_func(prob,i,repeat)
verbose && println("Processing iteration $i")
progress_bar && (parallel_type != EnsembleDistributed() ? next!(p) : put!(progch, true))

prob_i = remake_prob(scenario, generate_cons(parameters_variation_nt, i); safetycopy=true)
return prob_i
#=
constants_total_i = merge_strict(scenario.parameters, generate_cons(parameters_variation_nt, i))
u0, p0 = scenario.init_func(constants_total_i)
return remake(scenario.prob; u0=u0, p=p0)
=#
end

function _output(sol, i)
Expand All @@ -84,7 +89,8 @@ function mc(
prob = EnsembleProblem(scenario.prob;
prob_func = prob_func,
output_func = _output,
reduction = reduction_func
reduction = reduction_func,
safetycopy = false
)

if progress_bar && (parallel_type == EnsembleDistributed())
Expand Down Expand Up @@ -229,10 +235,14 @@ function mc(
scn_i = last(scenario_pairs[iter_i[2]])
parameters_i = parameters_pregenerated[iter_i[1]]

prob_i = remake_prob(scn_i, parameters_i; safetycopy=true)
return prob_i
#=
constants_total_i = merge_strict(scn_i.parameters, parameters_i)
u0, p0 = scn_i.init_func(constants_total_i)
return remake(scn_i.prob; u0=u0, p=p0)
=#
end

function _output(sol, i)
Expand All @@ -249,7 +259,8 @@ function mc(
prob = EnsembleProblem(last(scenario_pairs[1]).prob;
prob_func = prob_func,
output_func = _output,
reduction = reduction_func
reduction = reduction_func,
safetycopy = false
)

if progress_bar && (parallel_type == EnsembleDistributed())
Expand Down
15 changes: 13 additions & 2 deletions src/ode_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,19 @@ collect_saveat(saveat::Vector{S}) where S<:Real = Float64.(saveat)
collect_saveat(saveat::AbstractRange{S}) where S<:Real = Float64.(saveat)
=#

function remake_saveat(prob, saveat; tspan = prob.tspan)

function remake_prob(scen::Scenario, params::NamedTuple; safetycopy=true)
prob0 = safetycopy ? deepcopy(scen.prob) : scen.prob
if length(params) > 0
constants_total = merge_strict(scen.parameters, params)
u0, p0 = scen.init_func(constants_total)
return remake(prob0; u0=u0, p=p0)
else
return prob0
end
end

function remake_saveat(prob, saveat; tspan=prob.tspan)

scb_orig = prob.kwargs[:callback].discrete_callbacks[1].affect!
utype = eltype(prob.u0)
save_scope = scb_orig.save_scope
Expand Down
13 changes: 10 additions & 3 deletions src/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ function sim(

parameters_nt = NamedTuple(parameters)

prob = remake_prob(scenario, parameters_nt; safetycopy=true)
#=
prob = if length(parameters_nt) > 0
constants_total = merge_strict(scenario.parameters, parameters_nt)
u0, p0 = scenario.init_func(constants_total)
Expand All @@ -46,7 +48,7 @@ function sim(
else
deepcopy(scenario.prob)
end

=#
#= variant 2
prob = let
constants_total = merge_strict(scenario.parameters, parameters_nt)
Expand Down Expand Up @@ -112,13 +114,17 @@ function sim(
function prob_func(prob,i,repeat)
next!(p)
scn_i = last(scenario_pairs[i])
prob_i = remake_prob(scn_i, parameters_nt; safetycopy=true)
return prob_i
#=
constants_total_i = merge_strict(scn_i.parameters, parameters_nt)
if length(parameters_nt) > 0
u0, p0 = scn_i.init_func(constants_total_i)
remake(scn_i.prob; u0=u0, p=p0)
else
scn_i.prob
deepcopy(scn_i.prob)
end
=#
end

function _output(sol,i)
Expand All @@ -134,7 +140,8 @@ function sim(
prob = EnsembleProblem(EMPTY_PROBLEM;
prob_func = prob_func,
output_func = _output,
reduction = _reduction
reduction = _reduction,
safetycopy = false # deepcopy scn_i.prob
)

solution = solve(prob, alg, parallel_type;
Expand Down

0 comments on commit 4ec17e4

Please sign in to comment.