|
3 | 3 | Improved Ensemble parameter-efficiency with Packed-Ensembles
|
4 | 4 | ============================================================
|
5 | 5 |
|
6 |
| -*This tutorial is adapted from a notebook part of a lecture given at the `Helmholtz AI Conference <https://haicon24.de/>`_ by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* |
| 6 | +*This tutorial is adapted from a notebook part of a lecture given at the* |conference|_ *by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* |
| 7 | +
|
| 8 | +.. _conference: https://haicon24.de/ |
| 9 | +
|
| 10 | +.. |conference| replace:: *Helmholtz AI Conference* |
7 | 11 |
|
8 | 12 | In this notebook will work on the MNIST dataset that was introduced by Corinna Cortes, Christopher J.C. Burges, and later modified by Yann LeCun in the foundational paper:
|
9 | 13 |
|
|
12 | 16 | The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized. The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library.
|
13 | 17 |
|
14 | 18 | In this notebook, we will train a model and an ensemble on this task and evaluate their performance. The performance will consist in the following metrics:
|
| 19 | +
|
15 | 20 | - Accuracy: the proportion of correctly classified images,
|
16 | 21 | - Brier score: a measure of the quality of the predicted probabilities,
|
17 | 22 | - Calibration error: a measure of the calibration of the predicted probabilities,
|
@@ -174,13 +179,16 @@ def optim_recipe(model, lr_mult: float = 1.0):
|
174 | 179 | # This table provides a lot of information:
|
175 | 180 | #
|
176 | 181 | # **OOD Detection: Binary Classification MNIST vs. FashionMNIST**
|
| 182 | +# |
177 | 183 | # - AUPR/AUROC/FPR95: Measures the quality of the OOD detection. The higher the better for AUPR and AUROC, the lower the better for FPR95.
|
178 | 184 | #
|
179 | 185 | # **Calibration: Reliability of the Predictions**
|
| 186 | +# |
180 | 187 | # - ECE: Expected Calibration Error. The lower the better.
|
181 | 188 | # - aECE: Adaptive Expected Calibration Error. The lower the better. (~More precise version of the ECE)
|
182 | 189 | #
|
183 | 190 | # **Classification Performance**
|
| 191 | +# |
184 | 192 | # - Accuracy: The ratio of correctly classified images. The higher the better.
|
185 | 193 | # - Brier: The quality of the predicted probabilities (Mean Squared Error of the predictions vs. ground-truth). The lower the better.
|
186 | 194 | # - Negative Log-Likelihood: The value of the loss on the test set. The lower the better.
|
@@ -236,7 +244,7 @@ def optim_recipe(model, lr_mult: float = 1.0):
|
236 | 244 | # We need to multiply the learning rate by 2 to account for the fact that we have 2 models
|
237 | 245 | # in the ensemble and that we average the loss over all the predictions.
|
238 | 246 | #
|
239 |
| -# #### Downloading the pre-trained models |
| 247 | +# **Downloading the pre-trained models** |
240 | 248 | #
|
241 | 249 | # We have put the pre-trained models on Hugging Face that you can download with the utility function
|
242 | 250 | # "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not
|
@@ -393,9 +401,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
393 | 401 | # In constrast to calibration, the values of the confidence scores are not important, only the order of the scores. *Ideally, the best model will order all the correct predictions first, and all the incorrect predictions last.* In this case, there will be a threshold so that all the predictions above the threshold are correct, and all the predictions below the threshold are incorrect.
|
394 | 402 | #
|
395 | 403 | # In TorchUncertainty, we look at 3 different metrics for selective classification:
|
| 404 | +# |
396 | 405 | # - **AURC**: The area under the Risk (% of errors) vs. Coverage (% of classified samples) curve. This curve expresses how the risk of the model evolves as we increase the coverage (the proportion of predictions that are above the selection threshold). This metric will be minimized by a model able to perfectly separate the correct and incorrect predictions.
|
397 | 406 | #
|
398 | 407 | # The following metrics are computed at a fixed risk and coverage level and that have practical interests. The idea of these metrics is that you can set the selection threshold to achieve a certain level of risk and coverage, as required by the technical constraints of your application:
|
| 408 | +# |
399 | 409 | # - **Coverage at 5% Risk**: The proportion of predictions that are above the selection threshold when it is set for the risk to egal 5%. Set the risk threshold to your application constraints. The higher the better.
|
400 | 410 | # - **Risk at 80% Coverage**: The proportion of errors when the coverage is set to 80%. Set the coverage threshold to your application constraints. The lower the better.
|
401 | 411 | #
|
|
0 commit comments