Skip to content

Commit ffc7267

Browse files
authored
Use interpolations to update forcing in Basin (#2006)
Fixes another bullet point from #601, namely: - use interpolation objects instead of `update_basin`, which still relies on time being sorted first Now we can also use `parse_static_and_time` for Basin, bringing our node types more in line with each other and more code reuse. For this I had to extend it a bit to support other interpolation types, and support IDs not being in time nor static. We still add time of new data as tstops to ensure the solver knows about net data as it happens. Labeling this as breaking since before this we allowed a node ID to be in both the static and time tables. In `parse_static_and_time` (the other node types) we specifically check for this with a clear error message. I had to update `basic_transient`, to remove the static table there.
1 parent bd5442d commit ffc7267

File tree

10 files changed

+168
-202
lines changed

10 files changed

+168
-202
lines changed

core/src/callback.jl

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,7 @@ function create_callbacks(
1010
u0::ComponentVector,
1111
saveat,
1212
)::Tuple{CallbackSet, SavedResults}
13-
(;
14-
starttime,
15-
basin,
16-
flow_boundary,
17-
level_boundary,
18-
user_demand,
19-
tabulated_rating_curve,
20-
) = parameters
13+
(; starttime, basin, flow_boundary, level_boundary, user_demand) = parameters
2114
callbacks = SciMLBase.DECallback[]
2215

2316
# Check for negative storage
@@ -42,7 +35,13 @@ function create_callbacks(
4235
end
4336

4437
# Update Basin forcings
45-
tstops = get_tstops(basin.time.time, starttime)
38+
# All variables are given at the same time, so just precipitation works
39+
times = [itp.t for itp in basin.forcing.precipitation]
40+
tstops = Float64[]
41+
for t in times
42+
append!(tstops, t)
43+
end
44+
unique!(sort!(tstops))
4645
basin_cb = PresetTimeCallback(tstops, update_basin!; save_positions = (false, false))
4746
push!(callbacks, basin_cb)
4847

@@ -699,27 +698,45 @@ function save_subgrid_level(u, t, integrator)
699698
return copy(integrator.p.subgrid.level)
700699
end
701700

702-
"Load updates from 'Basin / time' into the parameters"
701+
"Update one current vertical flux from an interpolation at time t."
702+
function set_flux!(
703+
fluxes::AbstractVector{Float64},
704+
interpolations::Vector{ScalarConstantInterpolation},
705+
i::Int,
706+
t,
707+
)::Nothing
708+
val = interpolations[i](t)
709+
# keep old value if new value is NaN
710+
if !isnan(val)
711+
fluxes[i] = val
712+
end
713+
return nothing
714+
end
715+
716+
"""
717+
Update all current vertical fluxes from an interpolation at time t.
718+
719+
This runs in a callback rather than the RHS since that gives issues with the discontinuities
720+
in the ConstantInterpolations we use, failing the vertical_flux_means test.
721+
"""
703722
function update_basin!(integrator)::Nothing
704-
(; p) = integrator
723+
(; p, t) = integrator
705724
(; basin) = p
706-
(; node_id, time, vertical_flux) = basin
707-
t = datetime_since(integrator.t, integrator.p.starttime)
708725

709-
rows = searchsorted(time.time, t)
710-
timeblock = view(time, rows)
711-
712-
table = (;
713-
vertical_flux.precipitation,
714-
vertical_flux.potential_evaporation,
715-
vertical_flux.drainage,
716-
vertical_flux.infiltration,
717-
)
726+
update_basin!(basin, t)
727+
return nothing
728+
end
718729

719-
for row in timeblock
720-
i = searchsortedfirst(node_id, NodeID(NodeType.Basin, row.node_id, 0))
721-
set_table_row!(table, row, i)
730+
function update_basin!(basin::Basin, t)::Nothing
731+
(; vertical_flux, forcing) = basin
732+
for id in basin.node_id
733+
i = id.idx
734+
set_flux!(vertical_flux.precipitation, forcing.precipitation, i, t)
735+
set_flux!(vertical_flux.potential_evaporation, forcing.potential_evaporation, i, t)
736+
set_flux!(vertical_flux.infiltration, forcing.infiltration, i, t)
737+
set_flux!(vertical_flux.drainage, forcing.drainage, i, t)
722738
end
739+
723740
return nothing
724741
end
725742

core/src/model.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ function Model(config::Config)::Model
8383
# tell the solver to stop when new data comes in
8484
tstops = Vector{Float64}[]
8585
for schema_version in [
86+
BasinTimeV1,
8687
FlowBoundaryTimeV1,
8788
LevelBoundaryTimeV1,
8889
UserDemandTimeV1,

core/src/parameter.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ Base.isless(id_1::NodeID, id_2::NodeID)::Bool = id_1.value < id_2.value
108108
Base.isless(id_1::Integer, id_2::NodeID)::Bool = id_1 < id_2.value
109109
Base.isless(id_1::NodeID, id_2::Integer)::Bool = id_1.value < id_2
110110

111+
"ConstantInterpolation from a Float64 to a Float64"
112+
const ScalarConstantInterpolation =
113+
ConstantInterpolation{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, (1,)}
114+
111115
"LinearInterpolation from a Float64 to a Float64"
112116
const ScalarInterpolation = LinearInterpolation{
113117
Vector{Float64},
@@ -373,6 +377,20 @@ end
373377
Dict{String, ScalarInterpolation}[]
374378
end
375379

380+
"""
381+
Data source for Basin parameter updates over time
382+
383+
This is used for both static and dynamic values,
384+
the length of each Vector is the number of Basins.
385+
"""
386+
@kwdef struct BasinForcing
387+
precipitation::Vector{ScalarConstantInterpolation} = ScalarConstantInterpolation[]
388+
potential_evaporation::Vector{ScalarConstantInterpolation} =
389+
ScalarConstantInterpolation[]
390+
drainage::Vector{ScalarConstantInterpolation} = ScalarConstantInterpolation[]
391+
infiltration::Vector{ScalarConstantInterpolation} = ScalarConstantInterpolation[]
392+
end
393+
376394
"""
377395
Requirements:
378396
@@ -389,7 +407,7 @@ else
389407
T = Vector{Float64}
390408
end
391409
"""
392-
@kwdef struct Basin{V, C, CD, D} <: AbstractParameterNode
410+
@kwdef struct Basin{V, CD, D} <: AbstractParameterNode
393411
node_id::Vector{NodeID}
394412
inflow_ids::Vector{Vector{NodeID}} = [NodeID[]]
395413
outflow_ids::Vector{Vector{NodeID}} = [NodeID[]]
@@ -420,8 +438,7 @@ end
420438
# Values for allocation if applicable
421439
demand::Vector{Float64} = zeros(length(node_id))
422440
allocated::Vector{Float64} = zeros(length(node_id))
423-
# Data source for parameter updates
424-
time::StructVector{BasinTimeV1, C, Int}
441+
forcing::BasinForcing = BasinForcing()
425442
# Storage for each Basin at the previous time step
426443
storage_prev::Vector{Float64} = zeros(length(node_id))
427444
# Level for each Basin at the previous time step
@@ -929,11 +946,11 @@ const ModelGraph = MetaGraph{
929946
Float64,
930947
}
931948

932-
@kwdef mutable struct Parameters{C1, C2, C3, C4, C6, C7, C8, C9, C10, C11}
949+
@kwdef mutable struct Parameters{C1, C3, C4, C6, C7, C8, C9, C10, C11}
933950
const starttime::DateTime
934951
const graph::ModelGraph
935952
const allocation::Allocation
936-
const basin::Basin{C1, C2, C3, C4}
953+
const basin::Basin{C1, C3, C4}
937954
const linear_resistance::LinearResistance
938955
const manning_resistance::ManningResistance
939956
const tabulated_rating_curve::TabulatedRatingCurve

core/src/read.jl

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ function parse_static_and_time(
1515
time::Union{StructVector, Nothing} = nothing,
1616
defaults::NamedTuple = (; active = true),
1717
time_interpolatables::Vector{Symbol} = Symbol[],
18+
interpolation_type::Type{<:AbstractInterpolation} = LinearInterpolation,
19+
is_complete::Bool = true,
1820
)::Tuple{NamedTuple, Bool}
1921
# E.g. `PumpStatic`
2022
static_type = eltype(static)
@@ -42,7 +44,16 @@ function parse_static_and_time(
4244
# If the type is a union, then the associated parameter is optional and
4345
# the type is of the form Union{Missing,ActualType}
4446
parameter_type = if parameter_name in time_interpolatables
45-
ScalarInterpolation
47+
# We need the concrete type to store in the parameters
48+
# The interpolation_type is not concrete because they don't have the
49+
# constructors we use
50+
if interpolation_type == LinearInterpolation
51+
ScalarInterpolation
52+
elseif interpolation_type == ConstantInterpolation
53+
ScalarConstantInterpolation
54+
else
55+
error("Unknown interpolation type.")
56+
end
4657
elseif isa(parameter_type, Union)
4758
nonmissingtype(parameter_type)
4859
else
@@ -120,7 +131,7 @@ function parse_static_and_time(
120131
val = defaults[parameter_name]
121132
end
122133
if parameter_name in time_interpolatables
123-
val = LinearInterpolation(
134+
val = interpolation_type(
124135
[val, val],
125136
trivial_timespan;
126137
cache_parameters = true,
@@ -149,19 +160,16 @@ function parse_static_and_time(
149160
for parameter_name in parameter_names
150161
# If the parameter is interpolatable, create an interpolation object
151162
if parameter_name in time_interpolatables
152-
val, is_valid = get_scalar_interpolation(
163+
val = get_scalar_interpolation(
153164
config.starttime,
154165
t_end,
155166
time,
156167
node_id,
157168
parameter_name;
158169
default_value = hasproperty(defaults, parameter_name) ?
159170
defaults[parameter_name] : NaN,
171+
interpolation_type,
160172
)
161-
if !is_valid
162-
errors = true
163-
@error "A $parameter_name time series for $node_id has repeated times, this can not be interpolated."
164-
end
165173
else
166174
# Activity of transient nodes is assumed to be true
167175
if parameter_name == :active
@@ -173,6 +181,19 @@ function parse_static_and_time(
173181
end
174182
getfield(out, parameter_name)[node_id.idx] = val
175183
end
184+
elseif !is_complete
185+
# Apply the defaults just like if it was in static but missing
186+
for parameter_name in parameter_names
187+
val = defaults[parameter_name]
188+
if parameter_name in time_interpolatables
189+
val = interpolation_type(
190+
[val, val],
191+
trivial_timespan;
192+
cache_parameters = true,
193+
)
194+
end
195+
getfield(out, parameter_name)[node_id.idx] = val
196+
end
176197
else
177198
@error "$node_id data not in any table."
178199
errors = true
@@ -651,24 +672,20 @@ function ConcentrationData(
651672
for group in IterTools.groupby(row -> row.substance, data_id)
652673
first_row = first(group)
653674
substance = first_row.substance
654-
itp, no_duplication = get_scalar_interpolation(
675+
itp = get_scalar_interpolation(
655676
config.starttime,
656677
t_end,
657678
StructVector(group),
658679
NodeID(:Basin, first_row.node_id, 0),
659-
:concentration,
680+
:concentration;
681+
interpolation_type = LinearInterpolation,
660682
)
661683
concentration_external_id["concentration_external.$substance"] = itp
662684
if any(itp.u .< 0)
663685
errors = true
664686
@error "Found negative concentration(s) in `Basin / concentration_external`." node_id =
665687
id, substance
666688
end
667-
if !no_duplication
668-
errors = true
669-
@error "There are repeated time values for in `Basin / concentration_external`." node_id =
670-
id substance
671-
end
672689
end
673690
push!(concentration_external, concentration_external_id)
674691
end
@@ -691,26 +708,52 @@ function ConcentrationData(
691708
end
692709

693710
function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
694-
node_id = get_node_ids(db, NodeType.Basin)
695-
n = length(node_id)
696-
697711
# both static and time are optional, but we need fallback defaults
698712
static = load_structvector(db, config, BasinStaticV1)
699713
time = load_structvector(db, config, BasinTimeV1)
700714
state = load_structvector(db, config, BasinStateV1)
701715

702-
# Forcing
703-
precipitation = zeros(n)
704-
potential_evaporation = zeros(n)
705-
drainage = zeros(n)
706-
infiltration = zeros(n)
707-
table = (; precipitation, potential_evaporation, drainage, infiltration)
716+
_, _, node_id, valid =
717+
static_and_time_node_ids(db, static, time, NodeType.Basin; is_complete = false)
718+
if !valid
719+
error("Problems encountered when parsing Basin static and time node IDs.")
720+
end
721+
722+
time_interpolatables =
723+
[:precipitation, :potential_evaporation, :drainage, :infiltration]
724+
parsed_parameters, valid = parse_static_and_time(
725+
db,
726+
config,
727+
Basin;
728+
static,
729+
time,
730+
time_interpolatables,
731+
interpolation_type = ConstantInterpolation,
732+
defaults = (;
733+
precipitation = NaN,
734+
potential_evaporation = NaN,
735+
drainage = NaN,
736+
infiltration = NaN,
737+
),
738+
is_complete = false,
739+
)
708740

709-
set_static_value!(table, node_id, static)
710-
set_current_value!(table, node_id, time, config.starttime)
711-
check_no_nans(table, "Basin")
741+
forcing = BasinForcing(;
742+
parsed_parameters.precipitation,
743+
parsed_parameters.potential_evaporation,
744+
parsed_parameters.drainage,
745+
parsed_parameters.infiltration,
746+
)
712747

713-
vertical_flux = ComponentVector(; table...)
748+
# Current forcing is stored as separate array for BMI access
749+
# These are updated from the interpolation objects at runtime
750+
n = length(node_id)
751+
vertical_flux = ComponentVector(;
752+
precipitation = zeros(n),
753+
potential_evaporation = zeros(n),
754+
drainage = zeros(n),
755+
infiltration = zeros(n),
756+
)
714757

715758
# Profiles
716759
area, level = create_storage_tables(db, config)
@@ -736,11 +779,14 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
736779
vertical_flux,
737780
storage_to_level,
738781
level_to_area,
739-
time,
782+
forcing,
740783
concentration_data,
741784
concentration_time,
742785
)
743786

787+
# Ensure the initial data is loaded at t0 for BMI
788+
update_basin!(basin, 0.0)
789+
744790
storage0 = get_storages_from_levels(basin, state.level)
745791
@assert length(storage0) == n "Basin / state length differs from number of Basins"
746792
basin.storage0 .= storage0
@@ -1074,39 +1120,30 @@ function user_demand_time!(
10741120

10751121
active[user_demand_idx] = true
10761122
demand_from_timeseries[user_demand_idx] = true
1077-
return_factor_itp, is_valid_return = get_scalar_interpolation(
1123+
return_factor_itp = get_scalar_interpolation(
10781124
config.starttime,
10791125
t_end,
10801126
StructVector(group),
10811127
NodeID(:UserDemand, first_row.node_id, 0),
10821128
:return_factor;
1129+
interpolation_type = LinearInterpolation,
10831130
)
1084-
if is_valid_return
1085-
return_factor[user_demand_idx] = return_factor_itp
1086-
else
1087-
@error "The return_factor(t) relationship for UserDemand $(first_row.node_id) from the time table has repeated timestamps, this can not be interpolated."
1088-
errors = true
1089-
end
1131+
return_factor[user_demand_idx] = return_factor_itp
10901132

10911133
min_level[user_demand_idx] = first_row.min_level
10921134

10931135
priority_idx = findsorted(priorities, first_row.priority)
1094-
demand_p_itp, is_valid_demand = get_scalar_interpolation(
1136+
demand_p_itp = get_scalar_interpolation(
10951137
config.starttime,
10961138
t_end,
10971139
StructVector(group),
10981140
NodeID(:UserDemand, first_row.node_id, 0),
10991141
:demand;
11001142
default_value = 0.0,
1143+
interpolation_type = LinearInterpolation,
11011144
)
11021145
demand[user_demand_idx, priority_idx] = demand_p_itp(0.0)
1103-
1104-
if is_valid_demand
1105-
demand_itp[user_demand_idx][priority_idx] = demand_p_itp
1106-
else
1107-
@error "The demand(t) relationship for UserDemand $(first_row.node_id) of priority $(first_row.priority_idx) from the time table has repeated timestamps, this can not be interpolated."
1108-
errors = true
1109-
end
1146+
demand_itp[user_demand_idx][priority_idx] = demand_p_itp
11101147
end
11111148
return errors
11121149
end

0 commit comments

Comments
 (0)