Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spatial constant rate jump #343

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ function JumpSet(vj, cj, rj, maj::MassActionJump{S, T, U, V}) where {S <: Number
end

JumpSet(jump::ConstantRateJump) = JumpSet((), (jump,), nothing, nothing)
JumpSet(jumps::AbstractVector{ConstantRateJump}) = JumpSet((), jumps, nothing, nothing)
JumpSet(jump::VariableRateJump) = JumpSet((jump,), (), nothing, nothing)
JumpSet(jump::RegularJump) = JumpSet((), (), jump, nothing)
JumpSet(jump::AbstractMassActionJump) = JumpSet((), (), nothing, jump)
Expand Down
7 changes: 5 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,15 @@

## Spatial jumps handling
if spatial_system !== nothing && hopping_constants !== nothing
(num_crjs(jumps) == num_vrjs(jumps) == 0) ||
error("Spatial aggregators only support MassActionJumps currently.")
(num_vrjs(jumps) == 0) ||
error("Spatial aggregators currently only support MassActionJumps and ConstantRateJumps.")

if is_spatial(aggregator)
kwargs = merge((; hopping_constants, spatial_system), kwargs)
else
if num_crjs(jumps) != 0
error("Use a spatial SSA, e.g. DirectCRDirect in order to use ConstantRateJumps.")

Check warning on line 208 in src/problem.jl

View check run for this annotation

Codecov / codecov/patch

src/problem.jl#L208

Added line #L208 was not covered by tests
end
prob, maj = flatten(maj, prob, spatial_system, hopping_constants; kwargs...)
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/spatial/directcrdirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@

# a dependency graph is needed
if dep_graph === nothing
if length(rx_rates.cr_jumps) != 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if length(rx_rates.cr_jumps) != 0
if !isempty(rx_rates.cr_jumps)

error("Provide a dependency graph to use DirectCRDirect with constant rate jumps.")

Check warning on line 43 in src/spatial/directcrdirect.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/directcrdirect.jl#L43

Added line #L43 was not covered by tests
end
dg = make_dependency_graph(num_specs, rx_rates.ma_jumps)
else
dg = dep_graph
Expand All @@ -54,6 +57,9 @@
end

if jumptovars_map === nothing
if length(rx_rates.cr_jumps) != 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if length(rx_rates.cr_jumps) != 0
if !isempty(rx_rates.cr_jumps)

error("Provide a jump-to-species dependency graph to use DirectCRDirect with constant rate jumps.")

Check warning on line 61 in src/spatial/directcrdirect.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/directcrdirect.jl#L61

Added line #L61 was not covered by tests
end
jtov_map = jump_to_vars_map(rx_rates.ma_jumps)
else
jtov_map = jumptovars_map
Expand Down Expand Up @@ -94,7 +100,7 @@

next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder
next_jump_time = typemax(typeof(end_time))
rx_rates = RxRates(num_sites(spatial_system), majumps)
rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps)
hop_rates = HopRates(hopping_constants, spatial_system)
site_rates = zeros(typeof(end_time), num_sites(spatial_system))

Expand Down Expand Up @@ -199,4 +205,4 @@

number of constant rate jumps
"""
num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = 0
num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = length(aggregator.rx_rates.cr_jumps)

Check warning on line 208 in src/spatial/directcrdirect.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/directcrdirect.jl#L208

Added line #L208 was not covered by tests
10 changes: 8 additions & 2 deletions src/spatial/nsm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@

# a dependency graph is needed
if dep_graph === nothing
if length(rx_rates.cr_jumps) != 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

error("Provide a dependency graph to use NSM with constant rate jumps.")

Check warning on line 36 in src/spatial/nsm.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/nsm.jl#L36

Added line #L36 was not covered by tests
end
dg = make_dependency_graph(num_specs, rx_rates.ma_jumps)
else
dg = dep_graph
Expand All @@ -47,6 +50,9 @@
end

if jumptovars_map === nothing
if length(rx_rates.cr_jumps) != 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

error("Provide a jump-to-species dependency graph to use NSM with constant rate jumps.")

Check warning on line 54 in src/spatial/nsm.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/nsm.jl#L54

Added line #L54 was not covered by tests
end
jtov_map = jump_to_vars_map(rx_rates.ma_jumps)
else
jtov_map = jumptovars_map
Expand Down Expand Up @@ -83,7 +89,7 @@

next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder
next_jump_time = typemax(typeof(end_time))
rx_rates = RxRates(num_sites(spatial_system), majumps)
rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps)
hop_rates = HopRates(hopping_constants, spatial_system)

NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates,
Expand Down Expand Up @@ -187,4 +193,4 @@

number of constant rate jumps
"""
num_constant_rate_jumps(aggregator::NSMJumpAggregation) = 0
num_constant_rate_jumps(aggregator::NSMJumpAggregation) = length(aggregator.rx_rates.cr_jumps)

Check warning on line 196 in src/spatial/nsm.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/nsm.jl#L196

Added line #L196 was not covered by tests
46 changes: 36 additions & 10 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs
A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs.
Massaction jumps go first in the indexing, then constant rate jumps.
"""

### spatial rx rates ###
struct RxRates{F, M}
struct RxRates{F, M, C}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future consideration, do you know if storing the jumps in RxRates instead of within the aggregator directly results in an extra layer of indirection when accessing them (i.e. an extra pointer call)? If so, we might want to consider moving them up a level if that shows any performance benefit.

"rx_rates[i,j] is rate of reaction i at site j"
rates::Matrix{F}

Expand All @@ -12,20 +13,25 @@ struct RxRates{F, M}

"AbstractMassActionJump"
ma_jumps::M

"indexable collection of ConstantRateJump"
cr_jumps::C
end

"""
RxRates(num_sites::Int, ma_jumps::M) where {M}
RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C}

initializes RxRates with zero rates
"""
function RxRates(num_sites::Int, ma_jumps::M) where {M}
numrxjumps = get_num_majumps(ma_jumps)
function RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C}
numrxjumps = get_num_majumps(ma_jumps) + length(cr_jumps)
rates = zeros(Float64, numrxjumps, num_sites)
RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps)
RxRates{Float64, M, C}(rates, vec(sum(rates, dims = 1)), ma_jumps, cr_jumps)
end
RxRates(num_sites::Int, ma_jumps::M) where {M<:AbstractMassActionJump} = RxRates(num_sites, ma_jumps, ConstantRateJump[])
RxRates(num_sites::Int, cr_jumps::C) where {C} = RxRates(num_sites, SpatialMassActionJump(), cr_jumps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
RxRates(num_sites::Int, cr_jumps::C) where {C} = RxRates(num_sites, SpatialMassActionJump(), cr_jumps)
RxRates(num_sites::Int, cr_jumps::C) where {C} = RxRates(num_sites, SpatialMassActionJump(), cr_jumps)

I'd specify a type for C here or not include C at all. Generally a generic type like this would be suggested to just be dropped since it isn't being used (it may even give a warning in tests).


num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps)
num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) + length(rx_rates.cr_jumps)

"""
reset!(rx_rates::RxRates)
Expand All @@ -48,16 +54,21 @@ function total_site_rx_rate(rx_rates::RxRates, site)
end

"""
update_rx_rates!(rx_rates, rxs, u, site)
update_rx_rates!(rx_rates, rxs, integrator, site)

update rates of all reactions in rxs at site
"""
function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator,
site)
ma_jumps = rx_rates.ma_jumps
@inbounds for rx in rxs
rate = eval_massaction_rate(u, rx, ma_jumps, site)
set_rx_rate_at_site!(rx_rates, site, rx, rate)
if is_massaction(rx_rates, rx)
rate = eval_massaction_rate(u, rx, ma_jumps, site)
set_rx_rate_at_site!(rx_rates, site, rx, rate)
else
cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(ma_jumps)]
set_rx_rate_at_site!(rx_rates, site, rx, cr_jump.rate(u, integrator.p, integrator.t, site))
Comment on lines +69 to +70
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't type stable is it? That is why we split the ConstantRateJump rates and affects in the non-spatial solvers and wrap them inside FunctionWrappers.

end
end
end

Expand All @@ -77,6 +88,16 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng)
rand(rng) * total_site_rx_rate(rx_rates, site))
end

function execute_rx_at_site!(integrator, rx_rates::RxRates, rx, site)
if is_massaction(rx_rates, rx)
@inbounds executerx!((@view integrator.u[:, site]), rx,
rx_rates.ma_jumps)
else
cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(rx_rates.ma_jumps)]
cr_jump.affect!(integrator, site)
Comment on lines +96 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't type stable is it? That is why we split the ConstantRateJump rates and affects in the non-spatial solvers and wrap them inside FunctionWrappers.

end
end

# helper functions
function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate)
@inbounds old_rate = rx_rates.rates[rx, site]
Expand All @@ -90,5 +111,10 @@ function Base.show(io::IO, ::MIME"text/plain", rx_rates::RxRates)
println(io, "RxRates with $num_rxs reactions and $num_sites sites")
end

"Return true if jump is a massaction jump."
function is_massaction(rx_rates::RxRates, rx)
rx <= get_num_majumps(rx_rates.ma_jumps)
end

eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: SpatialMassActionJump} = evalrxrate(u, rx, ma_jumps, site)
eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: MassActionJump} = evalrxrate((@view u[:, site]), rx, ma_jumps)
6 changes: 6 additions & 0 deletions src/spatial/spatial_massaction_jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates
scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy)
end

function SpatialMassActionJump()
empty_majump = MassActionJump(Vector{Float64}(),
Vector{Vector{Pair{Int, Int}}}(),
Vector{Vector{Pair{Int, Int}}}())
SpatialMassActionJump(empty_majump)
end
##############################################

function get_num_majumps(smaj::SpatialMassActionJump{Nothing, Nothing, S, U, V}) where
Expand Down
5 changes: 2 additions & 3 deletions src/spatial/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ struct SpatialJump{J}
"source location"
src::J

"index of jump as a hop or reaction"
"index of jump as a hop or reaction, hops first, then massaction reactions, then constant rate reactions"
jidx::Int

"destination location, equal to src for within-site reactions"
Expand Down Expand Up @@ -69,8 +69,7 @@ function update_state!(p, integrator)
execute_hop!(integrator, jump.src, jump.dst, jump.jidx)
else
rx_index = reaction_id_from_jump(p, jump)
@inbounds executerx!((@view integrator.u[:, jump.src]), rx_index,
p.rx_rates.ma_jumps)
execute_rx_at_site!(integrator, p.rx_rates, rx_index, jump.src)
end
# save jump that was just exectued
p.prev_jump = jump
Expand Down
22 changes: 22 additions & 0 deletions test/spatial/ABC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]]
rates = [0.1 / mesh_size, 1.0]
majumps = MassActionJump(rates, reactstoch, netstoch)

# equivalent constant rate jumps
rate1(u,p,t,site) = u[1,site]*u[2,site] / 2
rate2(u,p,t,site) = u[3,site]
affect1!(integrator,site) = begin
integrator.u[1, site] -= 1
integrator.u[2, site] -= 1
integrator.u[3, site] += 1
end
affect2!(integrator,site) = begin
integrator.u[1, site] += 1
integrator.u[2, site] += 1
integrator.u[3, site] -= 1
end
crjumps = JumpSet([ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!)])
dep_graph = [[1,2],[1,2]]
jumptovars_map = [[1,2,3],[1,2,3]]

# spatial system setup
hopping_rate = diffusivity * (linear_size / domain_size)^2

Expand Down Expand Up @@ -56,6 +73,11 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps,
push!(jump_problems,
JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants,
spatial_system = grids[1], save_positions = (false, false), rng = rng))
# setup constant rate jump problems
push!(jump_problems, JumpProblem(prob, NSM(), crjumps, hopping_constants = hopping_constants,
spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng))
push!(jump_problems, JumpProblem(prob, DirectCRDirect(), crjumps, hopping_constants = hopping_constants,
spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng))
# setup flattenned jump prob
push!(jump_problems,
JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants,
Expand Down
19 changes: 14 additions & 5 deletions test/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,31 @@ num_species = 3
reactstoch = [[1 => 1, 2 => 1], [3 => 1]]
netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]]
rates = [0.1, 1.0]
num_rxs = length(rates)
ma_jumps = MassActionJump(rates, reactstoch, netstoch)
spatial_ma_jumps = SpatialMassActionJump(rates, reactstoch, netstoch)
rate_fn = (u, p, t, site) -> 1.0
affect_fn!(integrator) = nothing # a dummy reaction, does nothing
cr_jumps = [ConstantRateJump(rate_fn, affect_fn!)]
num_rxs = 3
u = ones(Int, num_species, num_nodes)
integrator = DummyIntegrator(u,nothing,nothing)
rng = StableRNG(12345)

# Test constructors
@test JP.RxRates(num_nodes, ma_jumps).ma_jumps == ma_jumps
@test JP.RxRates(num_nodes, spatial_ma_jumps).ma_jumps == spatial_ma_jumps
@test JP.RxRates(num_nodes, cr_jumps).cr_jumps == cr_jumps

# Tests for RxRates
rx_rates_list = [JP.RxRates(num_nodes, ma_jumps), JP.RxRates(num_nodes, spatial_ma_jumps)]
rx_rates_list = [JP.RxRates(num_nodes, ma_jumps, cr_jumps), JP.RxRates(num_nodes, spatial_ma_jumps, cr_jumps)]
for rx_rates in rx_rates_list
@test JP.num_rxs(rx_rates) == length(rates)
@test JP.num_rxs(rx_rates) == num_rxs
show(io, "text/plain", rx_rates)
for site in 1:num_nodes
JP.update_rx_rates!(rx_rates, 1:num_rxs, integrator, site)
@test JP.total_site_rx_rate(rx_rates, site) == 1.1
rx_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:num_rxs]
@test JP.total_site_rx_rate(rx_rates, site) == 2.1
majump_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:2]
rx_props = [majump_props..., 1.0]
rx_probs = rx_props / sum(rx_props)
d = Dict{Int, Int}()
for i in 1:num_samples
Expand Down
Loading