Skip to content

Commit

Permalink
add EventClock
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Mar 4, 2024
1 parent 2d4668d commit 5108b3e
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 12 deletions.
25 changes: 22 additions & 3 deletions src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,14 @@ end
is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous}

"""
ContinuousClock <: AbstractClock
ContinuousClock([t]; dt)
ContinuousClock()
ContinuousClock(t)
A clock that ticks at each solver step. This clock does generally not have equidistant tick intervals, instead, the tick interval depends on the adaptive step-size slection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`.
"""
struct ContinuousClock <: AbstractClock
"Independent variable"
t::Union{Nothing, Symbolic}
"Period"
ContinuousClock(t::Union{Num, Symbolic}) = new(value(t))
end
ContinuousClock() = ContinuousClock(nothing)
Expand All @@ -144,3 +143,23 @@ Base.hash(c::ContinuousClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91
function Base.:(==)(c1::ContinuousClock, c2::ContinuousClock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))
end

"""
EventClock(t)
EventClock(t, root_equation)
A clock that ticks each time the continuously evaluated `root_equation` is true. This clock is used to trigger the exection of a discrete system when a continuous event occurs.
"""
struct EventClock <: AbstractClock
"Independent variable"
t::Union{Nothing, Symbolic}
cond::Any
EventClock(t::Union{Num, Symbolic}, ex) = new(value(t), ex)
end

sampletime(c) = nothing
Base.hash(c::EventClock, seed::UInt) = hash(c.cond, seed 0x253d7b9a18874b91)
function Base.:(==)(c1::EventClock, c2::EventClock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) &&
isequal(c1.cond, c2.cond)
end
7 changes: 4 additions & 3 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,17 @@ end
function (xn::Num)(k::ShiftIndex)
@unpack clock, steps = k
x = value(xn)
t = clock.t
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
vars = Symbolics.get_variables(x)
length(vars) == 1 ||
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t)
length(args) == 1 ||
error("Cannot shift an expression with multiple independent variables $x.")
isequal(args[], t) ||
if isa(clock, Clock)
isequal(args[], t) ||
error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)")
end

# d, _ = propagate_time_domain(xn)
# if d != clock # this is only required if the variable has another clock
Expand All @@ -193,7 +194,7 @@ function (xn::Num)(k::ShiftIndex)
if steps == 0
return xn # x(k) needs no shift operator if the step of k is 0
end
Shift(t, steps)(xn) # a shift of k steps
Shift(clock isa Inferred ? nothing : clock.t, steps)(xn) # a shift of k steps
end

Base.:+(k::ShiftIndex, i::Int) = ShiftIndex(k.clock, k.steps + i)
Expand Down
1 change: 0 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
end

rhss = map(x -> x.rhs, eqs)
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))

u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
Expand Down
1 change: 1 addition & 0 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function infer_clocks!(ci::ClockInference)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
display(fullvars[c′])
display(var_domain)
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
vd = var_domain[first(idxs)]
Expand Down
13 changes: 10 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1054,9 +1054,16 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
elseif clock isa ContinuousClock
affect = DiscreteSaveAffect(affect, sv)
DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
daffect = DiscreteSaveAffect(affect, sv)
DiscreteCallback(Returns(true), daffect,
initialize = (c, u, t, integrator) -> daffect(integrator))
elseif clock isa EventClock
tempsys = @set sys.continuous_events = [SymbolicContinuousCallback(clock.cond)]
cb = generate_rootfinding_callback(tempsys)
daffect = DiscreteSaveAffect(affect, sv)
@set! cb.affect! = daffect
@set! cb.affect_neg! = daffect
cb
else
error("$clock is not a supported clock type.")
end
Expand Down
51 changes: 49 additions & 2 deletions test/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ end
## Test continuous clock

c = ModelingToolkit.ContinuousClock(t)

k = ShiftIndex()
@mtkmodel Counter begin
@variables begin
count(t) = 0
Expand Down Expand Up @@ -515,4 +515,51 @@ end
prob = ODEProblem(model, [], (0.0, 10.0))
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)

@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-time system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
@test sol.prob.kwargs[:disc_saved_values][1].saveval[2:end] == sol.u[1:2:(end - 2)]

## Test event clock

k = ShiftIndex()
@mtkmodel CrossingCounter begin
@variables begin
count(t) = 0
u(t) = 0
end
@equations begin
count(k+1) ~ u
end
end

@mtkmodel FirstOrder begin
@variables begin
x(t) = 0
end
@equations begin
D(x) ~ -x + sin(t)
end
end

@mtkmodel FirstOrderWithCrossingCounter begin
@components begin
counter = CrossingCounter()
fo = FirstOrder()
end
begin
c2 = ModelingToolkit.EventClock(t, fo.x ~ 0.1)
end
@equations begin
counter.u ~ Sample(c2)(fo.x)
end
end

@mtkbuild model = FirstOrderWithCrossingCounter()
prob = ODEProblem(model, [], (0.0, 30.0))
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)

@test all(x -> isapprox(0.1, x[], rtol = 1e-6),
sol.prob.kwargs[:disc_saved_values][1].saveval[2:end]) # omit first value due to initial value of count in count(k+1) ~ u
@test length(sol.prob.kwargs[:disc_saved_values][1].t) == 10 # number of crossings of 0.1

# plot(sol)
# vline!(sol.prob.kwargs[:disc_saved_values][1].t)

0 comments on commit 5108b3e

Please sign in to comment.