Code repository for Mathematical Model-Driven Deep Learning Enables Personalized Adaptive Therapy. Trains an A3C Deep Reinforcement Learning (DRL) model to treat a virtual cancer patient, modelled by a simple Lotka-Volterra ODE model.
- demo/ - Documented example scripts for single and multi-patient training
- images/ - Image output files to document research. Does not contain final paper images.
- models/ - Pre-trained models, for use in evaluation and figure scripts
- paper_figures/ - All code used to generate figures for the paper
- utils/ - Utility functions used in training and figure generation, including the virtual patient model
To set up the Python environment within which to run this:
conda create -n DRL_env_tf15 python=3.5 seaborn tqdm py4j
conda activate DRL_env_tf15
pip install tensorflow==1.5
conda install ipykernel
This section describes the workflow implemented in drlUtils/run_evaluation - it does not need to be carried out manually by a user, but provides a step-by-step guide to the functionality of that script. It uses a pre-trained model, as produced by drlUtils/run_training with classes from drlModel.py. This creates a Worker class which copies the network - multiple workers can be run asynchronously to train the A3C network.
1. Define parameters for the model. This includes parameters for both the DRL model, such as the reward metric and allowed treatments, and for the tumour model. It is also possible to define the model and results filepaths at this point. Further details are available in a separate file, parameters.md.
2. Set up the A3C network - this contains one global network, and num_workers copies used during training to increase learning performance. It is imported from utils/drlModel, and initialised with the parameters provided in step 1.
3. The DRL model is repeatedly evaluated on an instance of the virtual patient (represented by an ODE system), parametrised as in step 1. This requires repeated iteration over time until an end condition is reached (either 100 years of survival, or the tumour growing past 1.2 times its original size). Each timestep carries out the following (a minimal sketch of this loop is given after the list):
- Record the current state of the system in an internal observation list. This contains the initial tumour size, the tumour sizes each day for the past week and the daily growth (i.e. the per-day change for each day over the past week).
- Decide on next treatment using probabilities from policy network output, based on the observation input.
- Treat the patient and observe the response.
- Calculate the reward from that action.
- Record this timestep - see detail below.
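A minimal, hypothetical sketch of this per-timestep loop is shown below. The names used (evaluate_patient, policy_probabilities, patient.simulate_treatment, compute_reward) are illustrative placeholders rather than the actual API of drlUtils/run_evaluation; compute_reward is sketched in the reward section further down.

```python
import numpy as np

def evaluate_patient(patient, policy_probabilities, compute_reward,
                     initial_size, max_days=365 * 100, progression_factor=1.2):
    """Sketch of the evaluation loop; all callables are illustrative placeholders."""
    history = []                   # one dict per timestep (see output columns below)
    sizes = [initial_size] * 8     # initial size plus a (flat) week of history
    for day in range(max_days):
        # 15-dimensional observation: initial size, last 7 sizes, last 7 daily growths
        obs = np.concatenate(([initial_size], sizes[-7:], np.diff(sizes[-8:])))

        # Decide on the next treatment using the policy network's output probabilities
        p_holiday, p_treat = policy_probabilities(obs)
        action = np.random.choice(["H", "T"], p=[p_holiday, p_treat])

        # Treat the patient (or give a holiday) and observe the response
        new_size = patient.simulate_treatment(action)
        sizes.append(new_size)

        # Calculate the reward for this action and record the timestep
        reward = compute_reward(action, new_size, initial_size)
        history.append({"Time": day, "TumourSize": new_size,
                        "Action": action, "Reward": reward})

        # End condition: tumour has grown past 1.2 times its original size
        if new_size > progression_factor * initial_size:
            break
    return history
```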
This is all packaged in the single- and multiple-patient examples in the demo/ directory. Various model performance metrics are then calculated in the Jupyter notebooks within the same directory.
It uses a pre-trained model, as produced by train.py with classes from drlModel.py. This creates a Worker class which copies the network - multiple workers can be run asynchronously to train the A3C network.
Output data is stored as a list of dictionaries, one per timestep, which are then converted into a pandas DataFrame for CSV output (a minimal sketch of this step is given after the column list). The output contains the following columns:

- ReplicateId - Iteration number
- Time - Timestep (i.e. the time the current treatment starts at)
- S - Volume* of susceptible cells in the tumour (not accessible to the DRL model)
- R - Volume* of resistant cells in the tumour (not accessible to the DRL model)
- TumourSize - Volume* of the total tumour, i.e. the sum of S and R (accessible to the DRL model)
- Support_Hol - Probability of taking the Holiday action
- Support_Treat - Probability of taking the Treatment action
- Action - Action taken in the timestep - either "H" for Holiday or "T" for Treatment
- DrugConcentration - Concentration of drug delivered in the treatment (0 if holiday)
Note that values such as the cell distribution within the tumour are output for analysis, but this information is not available to the DRL during evaluation. See step 3.1 in the workflow summary for more information on the data provided to the DRL model.
* That is, a continuous metric directly proportional to the number of cells in the tumour.
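The conversion to CSV described above can be sketched as follows; the example records and the output filename results.csv are illustrative placeholders, not values produced by the repository.

```python
import pandas as pd

# Illustrative per-timestep records (invented values, columns as listed above)
history = [
    {"ReplicateId": 0, "Time": 0, "S": 0.70, "R": 0.05, "TumourSize": 0.75,
     "Support_Hol": 0.42, "Support_Treat": 0.58, "Action": "T", "DrugConcentration": 1.0},
    {"ReplicateId": 0, "Time": 1, "S": 0.62, "R": 0.05, "TumourSize": 0.67,
     "Support_Hol": 0.55, "Support_Treat": 0.45, "Action": "H", "DrugConcentration": 0.0},
]

# Convert the list of dictionaries into a DataFrame and write it to CSV
df = pd.DataFrame(history)
df.to_csv("results.csv", index=False)  # hypothetical output filename
```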
The DRL model was trained on a virtual patient, represented by an ODE model. This was a simple 2-population Lotka-Volterra tumour model, comprising a drug-susceptible population (S) and a drug-resistant population (R). Both species follow a modified logistic growth model with species-specific growth rates. For the susceptible population, this growth rate is also modified by the drug concentration. Finally, both species have a natural death rate.
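The equations themselves are not reproduced in this README; purely as an illustration, a 2-population Lotka-Volterra model of this kind is commonly written as follows (the symbols and the structure of the drug term are assumptions, not necessarily those used in the paper):

```latex
\begin{aligned}
\frac{dS}{dt} &= r_S\, S \left(1 - \frac{S + R}{K}\right)\bigl(1 - d_D\, D(t)\bigr) - \delta_S\, S ,\\
\frac{dR}{dt} &= r_R\, R \left(1 - \frac{S + R}{K}\right) - \delta_R\, R ,
\end{aligned}
```

where S and R are the susceptible and resistant volumes, r_S and r_R the growth rates, K the shared carrying capacity, D(t) the drug concentration with drug effect d_D, and delta_S, delta_R the natural death rates.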
This model is implemented in the LotkaVolterraModel
class, which inherits from ODEModel
. This parent class sets parameters such as error tolerances for the solver, and then solves the ODE model for each treatment period sequentially.
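For illustration only, a minimal sketch of this class structure, assuming scipy.integrate.solve_ivp as the solver and the assumed equations above (the actual ODEModel and LotkaVolterraModel classes in utils/ may differ in interface, parameter names and values):

```python
from scipy.integrate import solve_ivp

class ODEModel:
    """Base class: holds solver settings and solves one treatment period at a time."""
    def __init__(self, rtol=1e-6, atol=1e-8):
        self.rtol, self.atol = rtol, atol       # solver error tolerances

    def rhs(self, t, y, drug_concentration):
        raise NotImplementedError               # defined by subclasses

    def simulate_period(self, y0, t_span, drug_concentration):
        # Solve the ODE system over a single treatment period with a constant dose
        sol = solve_ivp(lambda t, y: self.rhs(t, y, drug_concentration),
                        t_span, y0, rtol=self.rtol, atol=self.atol)
        return sol.y[:, -1]                     # state at the end of the period

class LotkaVolterraModel(ODEModel):
    """2-population (susceptible/resistant) tumour model; parameter values are placeholders."""
    def __init__(self, r_s=0.03, r_r=0.03, K=1.0, d_d=1.5, delta=0.003, **kwargs):
        super().__init__(**kwargs)
        self.r_s, self.r_r, self.K, self.d_d, self.delta = r_s, r_r, K, d_d, delta

    def rhs(self, t, y, drug_concentration):
        S, R = y
        logistic = 1.0 - (S + R) / self.K
        dS = self.r_s * S * logistic * (1.0 - self.d_d * drug_concentration) - self.delta * S
        dR = self.r_r * R * logistic - self.delta * R
        return [dS, dR]
```

A full treatment schedule is then simulated by calling simulate_period once per treatment interval, passing the end state of one period in as the initial condition of the next.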
The A3C network consists of a global network, with a number of duplicates (workers) that update it asynchronously. An example (run on a machine with 4 cores, and so 3 workers) is given below:
Each network (global or worker) takes a 15-dimensional input:
- Initial tumour size
- Tumour size for the last 7 timesteps
- Tumour growth for the last 7 timesteps (i.e. differences between sequential tumour size measurements)
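As a small illustration of how such a 15-dimensional observation could be assembled (the variable names and size values are invented placeholders):

```python
import numpy as np

initial_size = 0.75
# Daily tumour size measurements for the past week plus the current day (oldest first)
sizes = np.array([0.75, 0.74, 0.72, 0.71, 0.69, 0.70, 0.68, 0.66])

observation = np.concatenate((
    [initial_size],        # 1 value: initial tumour size
    sizes[-7:],            # 7 values: tumour size for the last 7 timesteps
    np.diff(sizes[-8:]),   # 7 values: daily growth (sequential differences)
))
assert observation.shape == (15,)
```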
- Input Layers
  - LSTM (Long Short-Term Memory) layer, giving a 15-dimensional output, which can detect useful time-lagged correlations between components of the input.
- Hidden Layers
  - Fully connected layers of sizes [128, 64, 32, 16, 10].
  - Each takes the previous layer's output and produces a tensor of hidden units.
  - Each uses a rectified linear (ReLU) activation function.
- Output Layers
  - Policy - fully connected layer of output size n_DoseOptions, with a softmax activation function.
  - Value - fully connected layer of output size 1, with no activation function (linear behaviour).
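A minimal sketch of such an architecture in TensorFlow 1.x is given below for illustration; it is not the code in utils/drlModel.py, and names such as n_dose_options are placeholders.

```python
import tensorflow as tf  # TensorFlow 1.x, as pinned by the setup instructions (tensorflow==1.5)

n_dose_options = 2  # placeholder, e.g. Holiday / Treatment

# 15-dimensional observations, treated as length-1 sequences for the LSTM
observations = tf.placeholder(tf.float32, [None, 15], name="observations")
sequence = tf.expand_dims(observations, axis=1)

# Input layer: LSTM with a 15-dimensional output
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(15)
lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, sequence, dtype=tf.float32)
hidden = tf.reshape(lstm_out, [-1, 15])

# Hidden layers: fully connected stack with ReLU activations
for units in [128, 64, 32, 16, 10]:
    hidden = tf.layers.dense(hidden, units, activation=tf.nn.relu)

# Output layers: policy head (softmax) and value head (linear)
policy = tf.layers.dense(hidden, n_dose_options, activation=tf.nn.softmax)
value = tf.layers.dense(hidden, 1, activation=None)
```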
- Base reward of 0.1 per cycle (timestep) survived
- Additional reward of 0.05 for cycles without treatment
- Punishment of -0.1 for exceeding 20% tumour growth
- Ultimate reward of 5 for indefinite survival (100 years*)
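A minimal sketch of a per-timestep reward consistent with the scheme above is shown below; the function name and signature are illustrative, not those used in the training scripts.

```python
def compute_reward(action, tumour_size, initial_size, survived_to_end=False):
    """Illustrative reward shaping matching the scheme described above."""
    reward = 0.1                            # base reward per cycle survived
    if action == "H":                       # cycle without treatment (holiday)
        reward += 0.05
    if tumour_size > 1.2 * initial_size:    # exceeded 20% tumour growth
        reward -= 0.1
    if survived_to_end:                     # indefinite survival (100 years)
        reward += 5.0
    return reward
```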
This uses discounting to determine how important future rewards are to the current state, based on a discount factor. As well as ensuring that all reward sums converge, this factor is used to prioritise shorter-term results, given the uncertainty in the future behaviour of the system (tumour). However, the value tends to be very close to one, to avoid particularly short-sighted decisions, as we are primarily focused on the final outcome in this context.
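For reference, this is the standard definition of the discounted return used to value a state at time t, with the discount factor written here as gamma:

```latex
R_t = \sum_{k=0}^{\infty} \gamma^{k}\, r_{t+k}, \qquad 0 \le \gamma < 1 ,
```

where r_{t+k} is the reward received k timesteps into the future; values of gamma close to one weight long-term rewards almost as heavily as immediate ones.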
The training scripts will generate a number of files in the models/
directory, including checkpoints of the trained model, and log files of the training.
The model checkpoints are saved regularly, and stored in directories such as models/test_currSizeOnly_p25_monthly. These can be read back into TensorFlow for further training or evaluation using the tf.train.get_checkpoint_state() function.
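For illustration, the usual TensorFlow 1.x pattern for restoring the most recent checkpoint from such a directory is sketched below; it is a generic sketch, not code copied from this repository.

```python
import tensorflow as tf  # TensorFlow 1.x

def restore_latest(sess, saver, checkpoint_dir):
    """Restore the most recent checkpoint in checkpoint_dir, if one exists."""
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    return False

# Usage (assumes the training graph has already been rebuilt in this session):
#   saver = tf.train.Saver()
#   with tf.Session() as sess:
#       restore_latest(sess, saver, "models/test_currSizeOnly_p25_monthly")
```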
The log files for training are not essential for evaluation, but may be useful for determining the state of training - they are updated iteratively and so can be viewed while training is still running. They may be viewed with TensorBoard, which can be run using tensorboard --logdir path/to/logs from the home directory, and accessed at http://localhost:6006. These logs are provided on a per-worker basis, as well as for the global network.
A treatment history is also provided for each worker, recording the complete treatment schedule delivered to each virtual patient used in training, together with their survival time (i.e. time to progression).
Finally, a parameters file is provided to detail the parameter values used in this simulation.