Skip to content

Commit

Permalink
Merge pull request #1047 from amath-idm/add-calibration
Browse files Browse the repository at this point in the history
Add calibration
  • Loading branch information
cliffckerr authored May 18, 2021
2 parents 834b74f + f5c6485 commit af3f6a5
Show file tree
Hide file tree
Showing 15 changed files with 380 additions and 197 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ jobs:
- name: Install Covasim
run: pip install -e .
- name: Install tests
run: pip install pytest
working-directory: ./tests
run: pip install -r requirements_test.txt
- name: Run integration tests
working-directory: ./tests
run: pytest -v test_*.py --durations=0
run: pytest -v test_*.py --workers auto --durations=0
- name: Run unit tests
working-directory: ./tests/unittests
run: pytest -v test_*.py --durations=0
run: pytest -v test_*.py --workers auto --durations=0
11 changes: 10 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Coming soon

These are the major improvements we are currently working on. If there is a specific bugfix or feature you would like to see, please `create an issue <https://github.com/InstituteforDiseaseModeling/covasim/issues/new/choose>`__.

- Expanded tutorials (health care workers, calibration, exercises, etc.)
- Continued updates to vaccine and variant parameters and workflows
- Multi-region and geographical support
- Economics and costing analysis

Expand All @@ -25,6 +25,15 @@ Latest versions (3.0.x)
~~~~~~~~~~~~~~~~~~~~~~~


Version 3.0.3 (2021-05-17)
--------------------------
- Added a new class, ``cv.Calibration``, that can perform automatic calibration. Simplest usage is ``sim.calibrate(calib_pars)``. Note: this requires Optuna, which is not installed by default; please install separately via ``pip install optuna``. See the updated calibration tutorial for more information.
- Added a new result, ``known_deaths``, which counts only deaths among people who have been diagnosed.
- ``sim.compute_fit()`` now returns the fit by default, and creates ``sim.fit`` (previously, this was stored in ``sim.results.fit``).
- *Regression information*: Calls to ``sim.results.fit`` should be replaced with ``sim.fit``. The ``output`` parameter for ``sim.compute_fit()`` has been removed since it now always outputs the ``Fit`` object.
- *GitHub info*: PR `1047 <https://github.com/amath-idm/covasim/pull/1047>`__


Version 3.0.2 (2021-04-26)
--------------------------
- Added Novavax as one of the default vaccines.
Expand Down
189 changes: 188 additions & 1 deletion covasim/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
but which are useful for particular investigations.
'''

import os
import numpy as np
import pylab as pl
import pandas as pd
Expand All @@ -12,9 +13,15 @@
from . import interventions as cvi
from . import settings as cvset
from . import plotting as cvpl
from . import run as cvr
try:
import optuna as op
except ImportError as E: # pragma: no cover
errormsg = f'Optuna import failed ({str(E)}), please install first (pip install optuna)'
op = ImportError(errormsg)


__all__ = ['Analyzer', 'snapshot', 'age_histogram', 'daily_age_stats', 'daily_stats', 'Fit', 'TransTree']
__all__ = ['Analyzer', 'snapshot', 'age_histogram', 'daily_age_stats', 'daily_stats', 'Fit', 'Calibration', 'TransTree']


class Analyzer(sc.prettyobj):
Expand Down Expand Up @@ -1223,6 +1230,186 @@ def plot(self, keys=None, width=0.8, fig_args=None, axis_args=None, plot_args=No
return fig



class Calibration(Analyzer):
'''
A class to handle calibration of Covasim simulations. Uses the Optuna hyperparameter
optimization library (optuna.org), which must be installed separately (via
pip install optuna).
Note: running a calibration does not guarantee a good fit! You must ensure that
you run for a sufficient number of iterations, have enough free parameters, and
that the parameters have wide enough bounds. Please see the tutorial on calibration
for more information.
Args:
sim (Sim): the simulation to calibrate
calib_pars (dict): a dictionary of the parameters to calibrate of the format dict(key1=[best, low, high])
custom_fn (function): a custom function for modifying the simulation; receives the sim and calib_pars as inputs, should return the modified sim
n_trials (int): the number of trials per worker
n_workers (int): the number of parallel workers (default: maximum
total_trials (int): if n_trials is not supplied, calculate by dividing this number by n_workers)
name (str): the name of the database (default: 'covasim_calibration')
db_name (str): the name of the database file (default: 'covasim_calibration.db')
storage (str): the location of the database (default: sqlite)
label (str): a label for this calibration object
verbose (bool): whether to print details of the calibration
kwargs (dict): passed to cv.Calibration()
Returns:
A Calibration object
**Example**::
sim = cv.Sim(datafile='data.csv')
calib_pars = dict(beta=[0.015, 0.010, 0.020])
calib = cv.Calibration(sim, calib_pars, total_trials=100)
calib.calibrate()
calib.plot()
'''

def __init__(self, sim, calib_pars=None, custom_fn=None, n_trials=None, n_workers=None, total_trials=None, name=None, db_name=None, storage=None, label=None, verbose=True):
super().__init__(label=label) # Initialize the Analyzer object
if isinstance(op, Exception): raise op # If Optuna failed to import, raise that exception now
import multiprocessing as mp

# Handle run arguments
if n_trials is None: n_trials = 20
if n_workers is None: n_workers = mp.cpu_count()
if name is None: name = 'covasim_calibration'
if db_name is None: db_name = f'{name}.db'
if storage is None: storage = f'sqlite:///{db_name}'
if total_trials is not None: n_trials = total_trials/n_workers
self.run_args = sc.objdict(n_trials=int(n_trials), n_workers=int(n_workers), name=name, db_name=db_name, storage=storage)

# Handle other inputs
self.sim = sim
self.calib_pars = calib_pars
self.custom_fn = custom_fn
self.verbose = verbose
self.calibrated = False
return


def run_sim(self, calib_pars, label=None, return_sim=False):
''' Create and run a simulation '''
sim = self.sim.copy()
if label: sim.label = label
valid_pars = {k:v for k,v in calib_pars.items() if k in sim.pars}
sim.update_pars(valid_pars)
if self.custom_fn:
sim = self.custom_fn(sim, calib_pars)
else:
if len(valid_pars) != len(calib_pars):
extra = set(calib_pars.keys()) - set(valid_pars.keys())
errormsg = f'The following parameters are not part of the sim, nor is a custom function specified to use them: {sc.strjoin(extra)}'
raise ValueError(errormsg)
sim.run()
sim.compute_fit()
if return_sim:
return sim
else:
return sim.fit.mismatch


def run_trial(self, trial):
''' Define the objective for Optuna '''
pars = {}
for key, (best,low,high) in self.calib_pars.items():
pars[key] = trial.suggest_uniform(key, low, high) # Sample from beta values within this range
mismatch = self.run_sim(pars)
return mismatch


def worker(self):
''' Run a single worker '''
if self.verbose:
op.logging.set_verbosity(op.logging.DEBUG)
else:
op.logging.set_verbosity(op.logging.ERROR)
study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name)
output = study.optimize(self.run_trial, n_trials=self.run_args.n_trials)
return output


def run_workers(self):
''' Run multiple workers in parallel '''
output = sc.parallelize(self.worker, iterarg=self.run_args.n_workers)
return output


def make_study(self):
''' Make a study, deleting one if it already exists '''
if os.path.exists(self.run_args.db_name):
os.remove(self.run_args.db_name)
print(f'Removed existing calibration {self.run_args.db_name}')
output = op.create_study(storage=self.run_args.storage, study_name=self.run_args.name)
return output


def calibrate(self, calib_pars=None, verbose=True, **kwargs):
'''
Actually perform calibration.
Args:
calib_pars (dict): if supplied, overwrite stored calib_pars
kwargs (dict): if supplied, overwrite stored run_args (n_trials, n_workers, etc.)
'''

# Load and validate calibration parameters
if calib_pars is not None:
self.calib_pars = calib_pars
if self.calib_pars is None:
errormsg = 'You must supply calibration parameters either when creating the calibration object or when calling calibrate().'
raise ValueError(errormsg)
self.run_args.update(kwargs) # Update optuna settings

# Run the optimization
t0 = sc.tic()
self.make_study()
self.run_workers()
study = op.load_study(storage=self.run_args.storage, study_name=self.run_args.name)
self.best_pars = sc.objdict(study.best_params)
self.elapsed = sc.toc(t0, output=True)

# Compare the results
self.initial_pars = sc.objdict({k:v[0] for k,v in self.calib_pars.items()})
self.before = self.run_sim(calib_pars=self.initial_pars, label='Before calibration', return_sim=True)
self.after = self.run_sim(calib_pars=self.best_pars, label='After calibration', return_sim=True)

# Tidy up
self.calibrated = True
if verbose:
self.summarize()

return


def summarize(self):
if self.calibrated:
print(f'Calibration for {self.run_args.n_workers*self.run_args.n_trials} total trials completed in {self.elapsed:0.1f} s.')
before = self.before.fit.mismatch
after = self.after.fit.mismatch
print('\nInitial parameter values:')
print(self.initial_pars)
print('\nBest parameter values:')
print(self.best_pars)
print(f'\nMismatch before calibration: {before:n}')
print(f'Mismatch after calibration: {after:n}')
print(f'Percent improvement: {((before-after)/before)*100:0.1f}%')
return before, after
else:
print('Calibration not yet run; please run calib.calibrate()')
return


def plot(self, **kwargs):
msim = cvr.MultiSim([self.before, self.after])
fig = msim.plot(**kwargs)
return fig



class TransTree(Analyzer):
'''
A class for holding a transmission tree. There are several different representations
Expand Down
8 changes: 8 additions & 0 deletions covasim/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self):
'tested',
'diagnosed',
'recovered',
'known_dead',
'dead',
'known_contact',
'quarantined',
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(self):
'recovered': 'Number recovered',
'dead': 'Number dead',
'diagnosed': 'Number of confirmed cases',
'known_dead': 'Number of confirmed deaths',
'quarantined': 'Number in quarantine',
'vaccinated': 'Number of people vaccinated',
}
Expand All @@ -167,6 +169,7 @@ def __init__(self):
'deaths': 'deaths',
'tests': 'tests',
'diagnoses': 'diagnoses',
'known_deaths': 'known deaths',
'quarantined': 'quarantined people',
'vaccinations': 'vaccinations',
'vaccinated': 'vaccinated people'
Expand Down Expand Up @@ -260,6 +263,8 @@ def get_default_colors():
c.critical = '#b86113'
c.deaths = '#000000'
c.dead = c.deaths
c.known_dead = c.deaths
c.known_deaths = c.deaths
c.default = '#000000'
c.pop_nabs = '#32733d'
c.pop_protection = '#9e1149'
Expand All @@ -273,6 +278,7 @@ def get_default_colors():
'cum_severe',
'cum_critical',
'cum_deaths',
'cum_known_deaths',
'cum_diagnoses',
'new_infections',
'new_severe',
Expand Down Expand Up @@ -332,6 +338,7 @@ def get_default_plots(which='default', kind='sim', sim=None):
'cum_severe',
'cum_critical',
'cum_deaths',
'cum_known_deaths',
],
})

Expand All @@ -345,6 +352,7 @@ def get_default_plots(which='default', kind='sim', sim=None):
],
'Cumulative deaths': [
'cum_deaths',
'cum_known_deaths',
],
})

Expand Down
6 changes: 3 additions & 3 deletions covasim/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_strain_pars(default=False):
),

b1351 = dict(
rel_imm_strain = 0.066, # Immunity protection obtained from a natural infection with wild type, relative to wild type. TODO: add source
rel_imm_strain = 1.0, # Immunity protection obtained from a natural infection with wild type, relative to wild type. TODO: add source
rel_beta = 1.4,
rel_symp_prob = 1.0,
rel_severe_prob = 1.4,
Expand Down Expand Up @@ -416,9 +416,9 @@ def get_cross_immunity(default=False):

b1351 = dict(
wild = 0.066, # https://www.nature.com/articles/s41586-021-03471-w
b117 = 0.1, # Assumption
b117 = 0.5, # Assumption
b1351 = 1.0, # Default for own-immunity
p1 = 0.1, # Assumption
p1 = 0.5, # Assumption
),

p1 = dict(
Expand Down
40 changes: 22 additions & 18 deletions covasim/people.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,18 @@ def update_states_pre(self, t):

# Initialize
self.t = t
self.is_exp = self.true('exposed') # For storing the interim values since used in every subsequent calculation
self.is_exp = self.true('exposed') # For storing the interim values since used in every subsequent calculation

# Perform updates
self.init_flows()
self.flows['new_infectious'] += self.check_infectious() # For people who are exposed and not infectious, check if they begin being infectious
self.flows['new_symptomatic'] += self.check_symptomatic()
self.flows['new_severe'] += self.check_severe()
self.flows['new_critical'] += self.check_critical()
self.flows['new_deaths'] += self.check_death()
self.flows['new_recoveries'] += self.check_recovery() # TODO: check logic here
self.flows['new_infectious'] += self.check_infectious() # For people who are exposed and not infectious, check if they begin being infectious
self.flows['new_symptomatic'] += self.check_symptomatic()
self.flows['new_severe'] += self.check_severe()
self.flows['new_critical'] += self.check_critical()
self.flows['new_recoveries'] += self.check_recovery()
new_deaths, new_known_deaths = self.check_death()
self.flows['new_deaths'] += new_deaths
self.flows['new_known_deaths'] += new_known_deaths

return

Expand Down Expand Up @@ -298,20 +300,22 @@ def check_recovery(self, inds=None, filter_inds='is_exp'):
def check_death(self):
''' Check whether or not this person died on this timestep '''
inds = self.check_inds(self.dead, self.date_dead, filter_inds=self.is_exp)
self.susceptible[inds] = False
self.exposed[inds] = False
self.infectious[inds] = False
self.symptomatic[inds] = False
self.severe[inds] = False
self.critical[inds] = False
self.known_contact[inds] = False
self.quarantined[inds] = False
self.recovered[inds] = False
self.dead[inds] = True
self.dead[inds] = True
diag_inds = inds[self.diagnosed[inds]] # Check whether the person was diagnosed before dying
self.known_dead[diag_inds] = True
self.susceptible[inds] = False
self.exposed[inds] = False
self.infectious[inds] = False
self.symptomatic[inds] = False
self.severe[inds] = False
self.critical[inds] = False
self.known_contact[inds] = False
self.quarantined[inds] = False
self.recovered[inds] = False
self.infectious_strain[inds] = np.nan
self.exposed_strain[inds] = np.nan
self.recovered_strain[inds] = np.nan
return len(inds)
return len(inds), len(diag_inds)


def check_diagnosed(self):
Expand Down
Loading

0 comments on commit af3f6a5

Please sign in to comment.