03_Trained
Finally, I train my own models to classify single-cell images into their respective perturbations. This so-called weakly supervised representation learning makes the pre-trained model more sensitive to cell inputs and, in some cases, yields better metric scores.
- Training Index
- Results
The training index determines the split between training, testing, and unused cells in an index.csv
file that can be found in the /trained/index/
folder. In this file, I define which sites are used for training, which sites are used in the test set (evaluated after every epoch of training), and which sites and cells are not used at all. Furthermore, when sampling from an index file, a sampling factor between 0 and 1 determines the percentage of cells cropped into the sampling folder out of the full amount provided by the index file. In other words: if the sampling factor is 0.2 and the site A01/1 is marked for training in the index.csv, then 20% of all ~200 cells in the A01/1 site are cropped and added to the sampling folder.
Throughout the experiments, different indexes and subsections of the full dataset are used for training and testing. Training on subsections of the data both shows the effect of training on the evaluation metrics and keeps training times down. Below I list the different indexes created and their sampled cell counts:
| Index | Sampled cells |
| --- | --- |
| 817 sample | 1,487,477 |
| 915 sample | 322,238 |
| 826 sample | 1,546,126 |
| Top20 MoA | 4,036,145 |
| 812/811 sample | 1,457,934 |
| 823 sample | 1,533,027 |
- Check EC2 instance types for their technical details
- Technical details of the Nvidia A100-SXM4-40GB (CHTC server)
Inference times are compared by the inference speed per site in seconds. The profiled subsection of the LINCS dataset covers around 90,000 sites, or around 26 plates (roughly 1/5 of the entire LINCS dataset). A quick way to predict how long the inference step will take is to multiply the seconds needed to infer one site by 90,000.
Changing the batch size in the config file will slightly affect the time estimates below.
- CPU: With 4 cores, profiling runs at ~100 seconds per site. This would result in a very, very long inference time.
- P2 GPU: The cheapest GPU instance on AWS uses an Nvidia Tesla K80, which infers a site in ~10 seconds. Still very slow.
- P3 GPU: The next EC2 generation (Nvidia Tesla V100) can infer a site in ~2.5 seconds. This allows inferring the LINCS subsection within a handful of days.
- CHTC: On the CHTC server (Nvidia A100), the speed drops to ~0.2 seconds per site, and the overall time drops to ~5 hours.
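The back-of-the-envelope estimate (seconds per site × ~90,000 sites) can be written out for each hardware option; the per-site timings are the approximate values quoted above:

```python
# Rough inference-time estimate: seconds per site times ~90,000 sites.
SITES = 90_000
seconds_per_site = {
    "CPU (4 cores)": 100,
    "P2 (Tesla K80)": 10,
    "P3 (Tesla V100)": 2.5,
    "CHTC (A100)": 0.2,
}
for hardware, secs in seconds_per_site.items():
    total_hours = secs * SITES / 3600
    print(f"{hardware}: {total_hours:,.1f} hours (~{total_hours / 24:.1f} days)")
```

This reproduces the numbers above: the V100 lands at ~62.5 hours (a handful of days) and the A100 at ~5 hours.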
Epoch speeds obviously depend on the size of the dataset (the number of crops) and on the batch size. Training with augmentations increases the training time by a factor of ~1.4.
The time estimates below were taken from training experiments where I had around 1.5 million cells in the sample folder.
- P3 GPU: On the Tesla V100 servers, an epoch takes ~6 hours.
- CHTC: On the CHTC server (Nvidia A100), an epoch takes around 2 hours. This means that a common training cycle with 20 epochs can be done in two days. Below is an example of the training output.
```
13230/13230 [==============================] - 7301s 551ms/step - batch: 6614.5000 - size: 256.0000 - loss: 5.1071 - acc: 0.0776 - top_5: 0.2157 - val_loss: 6.4854 - val_acc: 0.0455 - val_top_5: 0.1378 - lr: 0.0010
```
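The wall-clock estimate for a full training run (epochs × hours per epoch, with the ~1.4× slowdown from augmentations noted above) can be sketched as a small helper; the function and its name are illustrative, not part of the codebase:

```python
def training_hours(epochs, hours_per_epoch, augmented=False):
    """Estimate total wall-clock training time in hours.

    Augmentations slow each epoch down by roughly 1.4x (observed factor).
    """
    factor = 1.4 if augmented else 1.0
    return epochs * hours_per_epoch * factor

# 20 epochs at ~2 h/epoch on the CHTC A100: ~40 hours, i.e. about two days.
print(training_hours(20, 2))
# With augmentations the same run grows to ~56 hours.
print(training_hours(20, 2, augmented=True))
```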