Skip to content

Commit

Permalink
Merge pull request #97 from SABS-R3-Epidemiology/AddProspensity
Browse files Browse the repository at this point in the history
Add prospensity
  • Loading branch information
rccreswell authored Nov 5, 2021
2 parents 8725efa + dc7d787 commit da925e2
Show file tree
Hide file tree
Showing 13 changed files with 6,451 additions and 191 deletions.
4 changes: 4 additions & 0 deletions docs/source/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Overview:
- :class:`SEIRForwardModel`
- :class:`SEIRParameters`
- :class:`SEIROutputCollector`
- :class:`StochasticOutputCollector`

SEIR Core
*********
Expand All @@ -23,3 +24,6 @@ SEIR Core

.. autoclass:: SEIROutputCollector
:members:

.. autoclass:: StochasticOutputCollector
:members:
4 changes: 4 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Overview:
- :class:`ReducedModel`
- :class:`SEIRModel`
- :class:`DeterministicSEIRModel`
- :class:`StochasticSEIRModel`

SEIR Model
**********
Expand All @@ -27,3 +28,6 @@ SEIR Model

.. autoclass:: DeterministicSEIRModel
:members:

.. autoclass:: StochasticSEIRModel
:members:
4 changes: 4 additions & 0 deletions docs/source/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Overview:
- :class:`IncidenceNumberPlot`
- :class:`CompartmentPlot`
- :class:`SubplotFigure`
- :class:`ConfigurablePlotter`


.. autoclass:: IncidenceNumberPlot
Expand All @@ -21,4 +22,7 @@ Overview:
:members:

.. autoclass:: SubplotFigure
:members:

.. autoclass:: ConfigurablePlotter
:members:
416 changes: 225 additions & 191 deletions examples/SEIR_simulation.ipynb

Large diffs are not rendered by default.

5,498 changes: 5,498 additions & 0 deletions examples/Stoch_simulation.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions seirmo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
ReducedModel,
SEIRModel
)
from ._stoch_model import StochasticSEIRModel

from ._stochastic_output_collector import StochasticOutputCollector

from ._simulation import (
SimulationController
Expand Down
87 changes: 87 additions & 0 deletions seirmo/_stoch_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/)
# which is released under the BSD 3-clause license. See accompanying LICENSE.md
# for copyright notice and full license details.
#

import numpy as np
import seirmo as se
from ._gillespie import solve_gillespie


class StochasticSEIRModel(se.SEIRForwardModel):
r"""
ODE model: Stochastic SEIR
The SEIR Model has four compartments:
susceptible individuals (:math:`S`),
exposed but not yet infectious (:math:`E`),
infectious (:math:`I`) and recovered (:math:`R`):
Possible processes between compartments:
Exposure: S -> E, at rate :math:\beta S(t)I(t)``
Infection: E -> I, at rate :math:\kappa E(t)``
Recovery: I -> R, at rate :math:\gamma I(t)``
Can be used in conjunction with solve_gillespie(),
a stochastic ODE solver implemented in this package.
Extends :class:`SEIRForwardModel`.
"""
def __init__(self, params_names: list):
super(StochasticSEIRModel, self).__init__()
self._parameters = se.SEIRParameters(params_names)
# sets up n compartments, returns names of output variables
# Define the names of the compartments to record - default all
self._output_collector = se.StochasticOutputCollector(
['S', 'E', 'I', 'R'])

def update_propensity(self, current_states: np.ndarray) -> np.ndarray:
''' This function takes the current populations in each
of the N compartments and returns a NxN array where the entry (i,j)
gives the rate of transfer of the population of compartment i
to compartment j.
Each non-zero element here corresponds to one equation in the
SEIR model.
Non-zero diagonal elements would correspond to no change in the
overall population.
Warning - negative elements should be avoided - a negative value at
(i,j) corresponds to a positive element at (j,i) and should be
implemented as such if required.
'''

params_names = self._parameters.parameter_names()
beta = self._parameters[params_names.index('beta')]
kappa = self._parameters[params_names.index('kappa')]
gamma = self._parameters[params_names.index('gamma')]

[t, S, E, I, R] = current_states
N = len(current_states) - 1
propens_matrix = np.zeros((N, N))
propens_matrix[0, 1] = beta * S * I
propens_matrix[1, 2] = kappa * E
propens_matrix[2, 3] = gamma * I

return propens_matrix

def simulate(self, parameters: np.ndarray, times: list,
max_t_step: float = 0.01):
self._parameters.configure_parameters(parameters) # array of length 7
# with values of beta
# gamma kappa and initial
self._output_collector.begin(times)

initial_states = self._parameters[:4] # input initial values

for point in solve_gillespie(
lambda states: self.update_propensity(states), # states
# includes t as first argument
initial_states,
[times[0], times[-1]], max_t_step):

self._output_collector.report(point)
self._output_collector.retrieve()

return self._output_collector.retrieve()
45 changes: 45 additions & 0 deletions seirmo/_stochastic_output_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/)
# which is released under the BSD 3-clause license. See accompanying LICENSE.md
# for copyright notice and full license details.
#

import numpy as np
import seirmo as se


class StochasticOutputCollector(se.SEIROutputCollector):
def begin(self, times):
self._data = np.full((len(times), len(self._output_names)), np.nan)
self._index = 0
self._times = np.array(times)

def report(self, data: np.ndarray) -> np.array:
"""Report data as a column vector into an array at each timestep.
:param data: numpy array containing the data of the model resolution
:return: numpy array containing the model solution
"""
if self._index >= self._data.shape[0]:
return
assert data.shape == (self._data.shape[1] + 1,), 'Invalid Data Shape'
gill_time = data[0]
if gill_time >= self._times[self._index]:
self._data[self._index, :] = np.transpose(data[1:])
self._index += 1

def retrieve_time(self, index: int) -> np.ndarray:
"""Return data as a column vector at a time point requested. Asserts
timepoint is within the 'past' of the model.
:param time_point: specified time at which we want the data
:return: data as a column for the specified time step
:rtype: numpy array column
"""

assert (
index < self._index
and index >= 0
and index < self._data.shape[0]
)
return np.transpose(self._data[index, :])
2 changes: 2 additions & 0 deletions seirmo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
IncidenceNumberPlot,
CompartmentPlot,
SubplotFigure)

from ._plot_from_numpy import ConfigurablePlotter # noqa
191 changes: 191 additions & 0 deletions seirmo/plots/_plot_from_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/)
# which is released under the BSD 3-clause license. See accompanying LICENSE.md
# for copyright notice and full license details.
#

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors


class ConfigurablePlotter:
"""
A figure class that visualises the population of each compartment over time
Configurable to plot multiple subplots in one figure, with customised
labels or colours
Implements addfill() method to plot a shaded region between two datasets
(I.e. when plotting confidence intervals)
"""

def __init__(self):
pass

def begin(self, subplots_rows: int = 1, subplots_columns: int = 1):
"""
Begins creating a figure, with given number of subfigures
Replaces init class so object can be reused"""
if type(subplots_rows) != int:
raise TypeError("Number of rows of subplots must be an integer")
if type(subplots_columns) != int:
raise TypeError("Number of columns of subplots must be an integer")
if subplots_rows <= 0:
raise ValueError("Number of rows of subplots must be positive")
if subplots_columns <= 0:
raise ValueError("Number of columns of subplots must be positive")

self._fig, self._axes = plt.subplots(subplots_rows, subplots_columns)
self._size = subplots_columns * subplots_rows
self._nrows = subplots_rows
self._ncolumns = subplots_columns
# we store a figure object and multiple axes objects

# Ensure self._axes array is always 2D
if self._nrows == 1 and self._ncolumns == 1:
self._axes = np.array(self._axes)[np.newaxis, np.newaxis]
elif self._nrows == 1:
self._axes = np.array(self._axes)[np.newaxis, :]
elif self._ncolumns == 1:
self._axes = np.array(self._axes)[:, np.newaxis]

def __getitem__(self, index):
"""If figure = ConfigurablePlotter(), then figure.begin().
Figure[0] will return the matplot figure, and figure[1] will
return the subplot axis objects"""
if index == 0:
return self._fig
elif index == 1:
return self._axes
else:
raise ValueError("Index must be 0 (for figure) or 1 (for axes)")

def add_data_to_plot(
self,
times: np.ndarray,
data_array: np.ndarray,
position: list = [0, 0],
xlabel: str = "time",
ylabels: list = [],
colours: list = [],
new_axis=False,
):
"""Main code to add new data into the plot
:params:: times: np.ndarray, independent x- variable
:params:: data_array: np.ndarray, multiple dependent y- variables
Data should has one row per timestep,
and one column for each dependent variable
:params:: position: list of integers, gives index of subplot to use
:params:: xlabel: str
:params:: ylabel: list of strings (a single string is also accepted)
:params:: colours: list of valid colour specifiers (ie strings or
rgb tuples)
:params:: new_axis: boolean, set to true if data should
be plotted on a second x axis"""

if len(data_array.shape) == 1: # Turn any 1D input into 2D
if type(times) != np.ndarray or np.sum(np.shape(times)) == 1:
# I.e. if only one np.int, or one element array
times = np.array(times, ndmin=2)
data_array = data_array[np.newaxis, :]
else:
data_array = data_array[:, np.newaxis]

assert (
times.shape[0] == data_array.shape[0]
), "data and times are not the same length"
data_width = data_array.shape[1] # saves the number of y-var

assert (
position[0] < self._nrows and position[1] < self._ncolumns
), "position and shape are not compatible"

if new_axis:
axis = self._axes[position[0], position[1]].twinx()
else:
axis = self._axes[position[0], position[1]]

# Format user inputs
if len(colours) == 0: # Default value, if no colous specified
colours = plt.cm.viridis(np.linspace(0, 1, data_width))
else:
colours = matplotlib.colors.to_rgba_array(colours)
assert data_width == np.shape(colours)[0],\
'Unexpected number of colours'

if isinstance(ylabels, str):
ylabels = [ylabels] # Converts string input to list
try:
iter(ylabels)
except TypeError:
raise TypeError('Unexpected type of ylabels')

# Plot over data array iteratively
if len(ylabels) > 0: # If ylabels have been specified for inclusion
assert data_width == len(ylabels), 'Unexpected number of ylabels'
for i in range(data_width):
axis.plot(times, data_array[:, i], color=colours[i],
label=ylabels[i])
axis.legend()
else: # Plot without a figure legend
for i in range(data_width):
axis.plot(times, data_array[:, i], color=colours[i])

plt.xlabel(xlabel)
self._fig.tight_layout()
return self._fig, self._axes

def add_fill(
self,
times: np.ndarray,
ymin: np.ndarray,
ymax: np.ndarray,
position: list = [0, 0],
xlabel: str = "time",
ylabel: str = "number of people",
colour: str = ["b"],
alpha: float = 0.2,
):
"""Code to plot shaded region between two datasets
:params:: times: np.ndarray, independent x- variable
:params:: ymin: np.ndarray, dependent y- variables
:params:: ymin: np.ndarray, comparison y- variables
:params:: position: list of integers, gives index of subplot to use
:params:: xlabel: str
:params:: ylabel: list of strings
:params:: colour: any valid colour specifier
:params:: alpha: float, indicate transparency of filled region
N.B While it is recommended that y_min should be the (generally)
smaller dataset for readability, this is not required, and the
datasets may cross (i.e. y_min may be larger in sections)"""

assert (
position[0] < self._nrows and position[1] < self._ncolumns
), "position and shape are not compatible"
axis = self._axes[position[0], position[1]]

# plots the data
axis.fill_between(
times,
np.squeeze(ymin),
np.squeeze(ymax),
color=colour,
alpha=alpha,
label=ylabel,
)
axis.legend()
plt.xlabel(xlabel)
self._fig.tight_layout()
return self._fig, self._axes

def show(self):
plt.show()

def write_to_file(self, filename: str = "SEIR_stochastic_simulation.pdf"):
self._fig.savefig(filename)

def __del__(self):
if hasattr(self, "_fig"):
plt.close(self._fig) # Close figure upon deletion
Loading

0 comments on commit da925e2

Please sign in to comment.