-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from SABS-R3-Epidemiology/AddProspensity
Add prospensity
- Loading branch information
Showing
13 changed files
with
6,451 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# | ||
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/) | ||
# which is released under the BSD 3-clause license. See accompanying LICENSE.md | ||
# for copyright notice and full license details. | ||
# | ||
|
||
import numpy as np | ||
import seirmo as se | ||
from ._gillespie import solve_gillespie | ||
|
||
|
||
class StochasticSEIRModel(se.SEIRForwardModel): | ||
r""" | ||
ODE model: Stochastic SEIR | ||
The SEIR Model has four compartments: | ||
susceptible individuals (:math:`S`), | ||
exposed but not yet infectious (:math:`E`), | ||
infectious (:math:`I`) and recovered (:math:`R`): | ||
Possible processes between compartments: | ||
Exposure: S -> E, at rate :math:\beta S(t)I(t)`` | ||
Infection: E -> I, at rate :math:\kappa E(t)`` | ||
Recovery: I -> R, at rate :math:\gamma I(t)`` | ||
Can be used in conjunction with solve_gillespie(), | ||
a stochastic ODE solver implemented in this package. | ||
Extends :class:`SEIRForwardModel`. | ||
""" | ||
def __init__(self, params_names: list): | ||
super(StochasticSEIRModel, self).__init__() | ||
self._parameters = se.SEIRParameters(params_names) | ||
# sets up n compartments, returns names of output variables | ||
# Define the names of the compartments to record - default all | ||
self._output_collector = se.StochasticOutputCollector( | ||
['S', 'E', 'I', 'R']) | ||
|
||
def update_propensity(self, current_states: np.ndarray) -> np.ndarray: | ||
''' This function takes the current populations in each | ||
of the N compartments and returns a NxN array where the entry (i,j) | ||
gives the rate of transfer of the population of compartment i | ||
to compartment j. | ||
Each non-zero element here corresponds to one equation in the | ||
SEIR model. | ||
Non-zero diagonal elements would correspond to no change in the | ||
overall population. | ||
Warning - negative elements should be avoided - a negative value at | ||
(i,j) corresponds to a positive element at (j,i) and should be | ||
implemented as such if required. | ||
''' | ||
|
||
params_names = self._parameters.parameter_names() | ||
beta = self._parameters[params_names.index('beta')] | ||
kappa = self._parameters[params_names.index('kappa')] | ||
gamma = self._parameters[params_names.index('gamma')] | ||
|
||
[t, S, E, I, R] = current_states | ||
N = len(current_states) - 1 | ||
propens_matrix = np.zeros((N, N)) | ||
propens_matrix[0, 1] = beta * S * I | ||
propens_matrix[1, 2] = kappa * E | ||
propens_matrix[2, 3] = gamma * I | ||
|
||
return propens_matrix | ||
|
||
def simulate(self, parameters: np.ndarray, times: list, | ||
max_t_step: float = 0.01): | ||
self._parameters.configure_parameters(parameters) # array of length 7 | ||
# with values of beta | ||
# gamma kappa and initial | ||
self._output_collector.begin(times) | ||
|
||
initial_states = self._parameters[:4] # input initial values | ||
|
||
for point in solve_gillespie( | ||
lambda states: self.update_propensity(states), # states | ||
# includes t as first argument | ||
initial_states, | ||
[times[0], times[-1]], max_t_step): | ||
|
||
self._output_collector.report(point) | ||
self._output_collector.retrieve() | ||
|
||
return self._output_collector.retrieve() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# | ||
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/) | ||
# which is released under the BSD 3-clause license. See accompanying LICENSE.md | ||
# for copyright notice and full license details. | ||
# | ||
|
||
import numpy as np | ||
import seirmo as se | ||
|
||
|
||
class StochasticOutputCollector(se.SEIROutputCollector): | ||
def begin(self, times): | ||
self._data = np.full((len(times), len(self._output_names)), np.nan) | ||
self._index = 0 | ||
self._times = np.array(times) | ||
|
||
def report(self, data: np.ndarray) -> np.array: | ||
"""Report data as a column vector into an array at each timestep. | ||
:param data: numpy array containing the data of the model resolution | ||
:return: numpy array containing the model solution | ||
""" | ||
if self._index >= self._data.shape[0]: | ||
return | ||
assert data.shape == (self._data.shape[1] + 1,), 'Invalid Data Shape' | ||
gill_time = data[0] | ||
if gill_time >= self._times[self._index]: | ||
self._data[self._index, :] = np.transpose(data[1:]) | ||
self._index += 1 | ||
|
||
def retrieve_time(self, index: int) -> np.ndarray: | ||
"""Return data as a column vector at a time point requested. Asserts | ||
timepoint is within the 'past' of the model. | ||
:param time_point: specified time at which we want the data | ||
:return: data as a column for the specified time step | ||
:rtype: numpy array column | ||
""" | ||
|
||
assert ( | ||
index < self._index | ||
and index >= 0 | ||
and index < self._data.shape[0] | ||
) | ||
return np.transpose(self._data[index, :]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,5 @@ | |
IncidenceNumberPlot, | ||
CompartmentPlot, | ||
SubplotFigure) | ||
|
||
from ._plot_from_numpy import ConfigurablePlotter # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# | ||
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/) | ||
# which is released under the BSD 3-clause license. See accompanying LICENSE.md | ||
# for copyright notice and full license details. | ||
# | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import matplotlib.colors | ||
|
||
|
||
class ConfigurablePlotter: | ||
""" | ||
A figure class that visualises the population of each compartment over time | ||
Configurable to plot multiple subplots in one figure, with customised | ||
labels or colours | ||
Implements addfill() method to plot a shaded region between two datasets | ||
(I.e. when plotting confidence intervals) | ||
""" | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def begin(self, subplots_rows: int = 1, subplots_columns: int = 1): | ||
""" | ||
Begins creating a figure, with given number of subfigures | ||
Replaces init class so object can be reused""" | ||
if type(subplots_rows) != int: | ||
raise TypeError("Number of rows of subplots must be an integer") | ||
if type(subplots_columns) != int: | ||
raise TypeError("Number of columns of subplots must be an integer") | ||
if subplots_rows <= 0: | ||
raise ValueError("Number of rows of subplots must be positive") | ||
if subplots_columns <= 0: | ||
raise ValueError("Number of columns of subplots must be positive") | ||
|
||
self._fig, self._axes = plt.subplots(subplots_rows, subplots_columns) | ||
self._size = subplots_columns * subplots_rows | ||
self._nrows = subplots_rows | ||
self._ncolumns = subplots_columns | ||
# we store a figure object and multiple axes objects | ||
|
||
# Ensure self._axes array is always 2D | ||
if self._nrows == 1 and self._ncolumns == 1: | ||
self._axes = np.array(self._axes)[np.newaxis, np.newaxis] | ||
elif self._nrows == 1: | ||
self._axes = np.array(self._axes)[np.newaxis, :] | ||
elif self._ncolumns == 1: | ||
self._axes = np.array(self._axes)[:, np.newaxis] | ||
|
||
def __getitem__(self, index): | ||
"""If figure = ConfigurablePlotter(), then figure.begin(). | ||
Figure[0] will return the matplot figure, and figure[1] will | ||
return the subplot axis objects""" | ||
if index == 0: | ||
return self._fig | ||
elif index == 1: | ||
return self._axes | ||
else: | ||
raise ValueError("Index must be 0 (for figure) or 1 (for axes)") | ||
|
||
def add_data_to_plot( | ||
self, | ||
times: np.ndarray, | ||
data_array: np.ndarray, | ||
position: list = [0, 0], | ||
xlabel: str = "time", | ||
ylabels: list = [], | ||
colours: list = [], | ||
new_axis=False, | ||
): | ||
"""Main code to add new data into the plot | ||
:params:: times: np.ndarray, independent x- variable | ||
:params:: data_array: np.ndarray, multiple dependent y- variables | ||
Data should has one row per timestep, | ||
and one column for each dependent variable | ||
:params:: position: list of integers, gives index of subplot to use | ||
:params:: xlabel: str | ||
:params:: ylabel: list of strings (a single string is also accepted) | ||
:params:: colours: list of valid colour specifiers (ie strings or | ||
rgb tuples) | ||
:params:: new_axis: boolean, set to true if data should | ||
be plotted on a second x axis""" | ||
|
||
if len(data_array.shape) == 1: # Turn any 1D input into 2D | ||
if type(times) != np.ndarray or np.sum(np.shape(times)) == 1: | ||
# I.e. if only one np.int, or one element array | ||
times = np.array(times, ndmin=2) | ||
data_array = data_array[np.newaxis, :] | ||
else: | ||
data_array = data_array[:, np.newaxis] | ||
|
||
assert ( | ||
times.shape[0] == data_array.shape[0] | ||
), "data and times are not the same length" | ||
data_width = data_array.shape[1] # saves the number of y-var | ||
|
||
assert ( | ||
position[0] < self._nrows and position[1] < self._ncolumns | ||
), "position and shape are not compatible" | ||
|
||
if new_axis: | ||
axis = self._axes[position[0], position[1]].twinx() | ||
else: | ||
axis = self._axes[position[0], position[1]] | ||
|
||
# Format user inputs | ||
if len(colours) == 0: # Default value, if no colous specified | ||
colours = plt.cm.viridis(np.linspace(0, 1, data_width)) | ||
else: | ||
colours = matplotlib.colors.to_rgba_array(colours) | ||
assert data_width == np.shape(colours)[0],\ | ||
'Unexpected number of colours' | ||
|
||
if isinstance(ylabels, str): | ||
ylabels = [ylabels] # Converts string input to list | ||
try: | ||
iter(ylabels) | ||
except TypeError: | ||
raise TypeError('Unexpected type of ylabels') | ||
|
||
# Plot over data array iteratively | ||
if len(ylabels) > 0: # If ylabels have been specified for inclusion | ||
assert data_width == len(ylabels), 'Unexpected number of ylabels' | ||
for i in range(data_width): | ||
axis.plot(times, data_array[:, i], color=colours[i], | ||
label=ylabels[i]) | ||
axis.legend() | ||
else: # Plot without a figure legend | ||
for i in range(data_width): | ||
axis.plot(times, data_array[:, i], color=colours[i]) | ||
|
||
plt.xlabel(xlabel) | ||
self._fig.tight_layout() | ||
return self._fig, self._axes | ||
|
||
def add_fill( | ||
self, | ||
times: np.ndarray, | ||
ymin: np.ndarray, | ||
ymax: np.ndarray, | ||
position: list = [0, 0], | ||
xlabel: str = "time", | ||
ylabel: str = "number of people", | ||
colour: str = ["b"], | ||
alpha: float = 0.2, | ||
): | ||
"""Code to plot shaded region between two datasets | ||
:params:: times: np.ndarray, independent x- variable | ||
:params:: ymin: np.ndarray, dependent y- variables | ||
:params:: ymin: np.ndarray, comparison y- variables | ||
:params:: position: list of integers, gives index of subplot to use | ||
:params:: xlabel: str | ||
:params:: ylabel: list of strings | ||
:params:: colour: any valid colour specifier | ||
:params:: alpha: float, indicate transparency of filled region | ||
N.B While it is recommended that y_min should be the (generally) | ||
smaller dataset for readability, this is not required, and the | ||
datasets may cross (i.e. y_min may be larger in sections)""" | ||
|
||
assert ( | ||
position[0] < self._nrows and position[1] < self._ncolumns | ||
), "position and shape are not compatible" | ||
axis = self._axes[position[0], position[1]] | ||
|
||
# plots the data | ||
axis.fill_between( | ||
times, | ||
np.squeeze(ymin), | ||
np.squeeze(ymax), | ||
color=colour, | ||
alpha=alpha, | ||
label=ylabel, | ||
) | ||
axis.legend() | ||
plt.xlabel(xlabel) | ||
self._fig.tight_layout() | ||
return self._fig, self._axes | ||
|
||
def show(self): | ||
plt.show() | ||
|
||
def write_to_file(self, filename: str = "SEIR_stochastic_simulation.pdf"): | ||
self._fig.savefig(filename) | ||
|
||
def __del__(self): | ||
if hasattr(self, "_fig"): | ||
plt.close(self._fig) # Close figure upon deletion |
Oops, something went wrong.