Skip to content

Commit ebb43c5

Browse files
authored
Merge branch 'dev' into rm-baselines
2 parents bd0d1c6 + 2255238 commit ebb43c5

File tree

20 files changed

+171
-50
lines changed

20 files changed

+171
-50
lines changed

.github/workflows/build-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ env:
1414

1515
jobs:
1616
documentation:
17-
runs-on: self-hosted
17+
runs-on: [self-hosted]
1818
steps:
1919
- uses: actions/checkout@v4
2020

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Check out all our tutorials at [torch-uncertainty.github.io/auto_tutorials](http
105105

106106
The following projects use TorchUncertainty:
107107

108+
- _Towards Understanding and Quantifying Uncertainty for Text-to-Image Generation_ - [CVPR 2025](https://openaccess.thecvf.com/content/CVPR2025/papers/Franchi_Towards_Understanding_and_Quantifying_Uncertainty_for_Text-to-Image_Generation_CVPR_2025_paper.pdf)
108109
- _Towards Understanding Why Label Smoothing Degrades Selective Classification and How to Fix It_ - [ICLR 2025](https://arxiv.org/abs/2403.14715)
109110
- _A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors_ - [ICLR 2024](https://arxiv.org/abs/2310.08287)
110111

auto_tutorial_source/Bayesian_Methods/tutorial_bayesian.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
For more information on Bayesian Neural Networks, we refer to the following resources:
1717
1818
- Weight Uncertainty in Neural Networks `ICML2015 <https://arxiv.org/pdf/1505.05424.pdf>`_
19-
- Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine
20-
<https://arxiv.org/pdf/2007.06823.pdf>`_
19+
- Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine <https://arxiv.org/pdf/2007.06823.pdf>`_
2120
2221
Training a Bayesian LeNet using TorchUncertainty models and Lightning
2322
---------------------------------------------------------------------

auto_tutorial_source/Classification/tutorial_ood_detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,5 @@
146146
# ----------
147147
#
148148
# [1] Hendrycks, D., & Gimpel, K. (2016). A baseline for detecting misclassified and out-of-distribution examples in neural networks. In ICLR 2017.
149+
#
149150
# [2] Hendrycks, D., Basart, S., Mazeika, M., Zou, A., Kwon, J., Mostajabi, M., ... & Song, D. (2019). Scaling out-of-distribution detection for real-world settings. In ICML 2022.

auto_tutorial_source/Ensemble_Methods/tutorial_from_de_to_pe.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
Improved Ensemble parameter-efficiency with Packed-Ensembles
44
============================================================
55
6-
*This tutorial is adapted from a notebook part of a lecture given at the `Helmholtz AI Conference <https://haicon24.de/>`_ by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.*
6+
*This tutorial is adapted from a notebook part of a lecture given at the* |conference|_ *by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.*
7+
8+
.. _conference: https://haicon24.de/
9+
10+
.. |conference| replace:: *Helmholtz AI Conference*
711
812
In this notebook will work on the MNIST dataset that was introduced by Corinna Cortes, Christopher J.C. Burges, and later modified by Yann LeCun in the foundational paper:
913
@@ -12,6 +16,7 @@
1216
The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized. The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library.
1317
1418
In this notebook, we will train a model and an ensemble on this task and evaluate their performance. The performance will consist in the following metrics:
19+
1520
- Accuracy: the proportion of correctly classified images,
1621
- Brier score: a measure of the quality of the predicted probabilities,
1722
- Calibration error: a measure of the calibration of the predicted probabilities,
@@ -174,13 +179,16 @@ def optim_recipe(model, lr_mult: float = 1.0):
174179
# This table provides a lot of information:
175180
#
176181
# **OOD Detection: Binary Classification MNIST vs. FashionMNIST**
182+
#
177183
# - AUPR/AUROC/FPR95: Measures the quality of the OOD detection. The higher the better for AUPR and AUROC, the lower the better for FPR95.
178184
#
179185
# **Calibration: Reliability of the Predictions**
186+
#
180187
# - ECE: Expected Calibration Error. The lower the better.
181188
# - aECE: Adaptive Expected Calibration Error. The lower the better. (~More precise version of the ECE)
182189
#
183190
# **Classification Performance**
191+
#
184192
# - Accuracy: The ratio of correctly classified images. The higher the better.
185193
# - Brier: The quality of the predicted probabilities (Mean Squared Error of the predictions vs. ground-truth). The lower the better.
186194
# - Negative Log-Likelihood: The value of the loss on the test set. The lower the better.
@@ -236,7 +244,7 @@ def optim_recipe(model, lr_mult: float = 1.0):
236244
# We need to multiply the learning rate by 2 to account for the fact that we have 2 models
237245
# in the ensemble and that we average the loss over all the predictions.
238246
#
239-
# #### Downloading the pre-trained models
247+
# **Downloading the pre-trained models**
240248
#
241249
# We have put the pre-trained models on Hugging Face that you can download with the utility function
242250
# "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not
@@ -393,9 +401,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
393401
# In constrast to calibration, the values of the confidence scores are not important, only the order of the scores. *Ideally, the best model will order all the correct predictions first, and all the incorrect predictions last.* In this case, there will be a threshold so that all the predictions above the threshold are correct, and all the predictions below the threshold are incorrect.
394402
#
395403
# In TorchUncertainty, we look at 3 different metrics for selective classification:
404+
#
396405
# - **AURC**: The area under the Risk (% of errors) vs. Coverage (% of classified samples) curve. This curve expresses how the risk of the model evolves as we increase the coverage (the proportion of predictions that are above the selection threshold). This metric will be minimized by a model able to perfectly separate the correct and incorrect predictions.
397406
#
398407
# The following metrics are computed at a fixed risk and coverage level and that have practical interests. The idea of these metrics is that you can set the selection threshold to achieve a certain level of risk and coverage, as required by the technical constraints of your application:
408+
#
399409
# - **Coverage at 5% Risk**: The proportion of predictions that are above the selection threshold when it is set for the risk to egal 5%. Set the risk threshold to your application constraints. The higher the better.
400410
# - **Risk at 80% Coverage**: The proportion of errors when the coverage is set to 80%. Set the coverage threshold to your application constraints. The lower the better.
401411
#

auto_tutorial_source/Post_Hoc_Methods/tutorial_scaler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
# We also compute and plot the top-label calibration figure. We see that the
102102
# model is not well calibrated.
103103
fig, ax = ece.plot()
104+
fig.tight_layout()
104105
fig.show()
105106

106107
# %%
@@ -143,6 +144,7 @@
143144
# that the model is now better calibrated. If the temperature is greater than 1,
144145
# the final model is less confident than before.
145146
fig, ax = ece.plot()
147+
fig.tight_layout()
146148
fig.show()
147149

148150
# %%

docs/source/api.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ Classes
159159
BatchEnsemble
160160
CheckpointCollector
161161
EMA
162-
MCDropout
163162
StochasticModel
164163
SWA
165164
SWAG

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
2121
)
2222
author = "Adrien Lafage and Olivier Laurent"
23-
release = "0.7.0"
23+
release = "0.7.0.post1"
2424

2525
# -- General configuration ---------------------------------------------------
2626
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "torch_uncertainty"
7-
version = "0.7.0"
7+
version = "0.7.0.post1"
88
authors = [
99
{ name = "ENSTA U2IS AI", email = "[email protected]" },
1010
{ name = "Adrien Lafage", email = "[email protected]" },
Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
from urllib.error import URLError
3+
14
import pytest
25

36
from torch_uncertainty.datamodules.classification import (
@@ -13,27 +16,30 @@ class TestHTRU2DataModule:
1316
"""Testing the HTRU2DataModule datamodule class."""
1417

1518
def test_htru2(self) -> None:
16-
dm = HTRU2DataModule(root="./data/", batch_size=128)
19+
try:
20+
dm = HTRU2DataModule(root="./data/", batch_size=128)
1721

18-
dm.prepare_data()
19-
dm.setup()
22+
dm.prepare_data()
23+
dm.setup()
2024

21-
dm.train_dataloader()
22-
dm.val_dataloader()
23-
dm.test_dataloader()
25+
dm.train_dataloader()
26+
dm.val_dataloader()
27+
dm.test_dataloader()
2428

25-
dm.setup("test")
26-
dm.test_dataloader()
29+
dm.setup("test")
30+
dm.test_dataloader()
2731

28-
dm = HTRU2DataModule(root="./data/", batch_size=128, val_split=0.1)
32+
dm = HTRU2DataModule(root="./data/", batch_size=128, val_split=0.1)
2933

30-
dm.prepare_data()
31-
dm.setup()
34+
dm.prepare_data()
35+
dm.setup()
3236

33-
with pytest.raises(ValueError):
34-
dm.setup("other")
37+
with pytest.raises(ValueError):
38+
dm.setup("other")
3539

36-
dm = BankMarketingDataModule(root="./data/", batch_size=128)
37-
dm = DOTA2GamesDataModule(root="./data/", batch_size=128)
38-
dm = OnlineShoppersDataModule(root="./data/", batch_size=128)
39-
dm = SpamBaseDataModule(root="./data/", batch_size=128)
40+
dm = BankMarketingDataModule(root="./data/", batch_size=128)
41+
dm = DOTA2GamesDataModule(root="./data/", batch_size=128)
42+
dm = OnlineShoppersDataModule(root="./data/", batch_size=128)
43+
dm = SpamBaseDataModule(root="./data/", batch_size=128)
44+
except URLError as e:
45+
warnings.warn(f"Data download failed due to network error: {e}", stacklevel=2)

0 commit comments

Comments
 (0)