diff --git a/README.md b/README.md
index c0120a7b..4b2f5b9b 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ supports tensorflow versions 2.x.
| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting Started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
+| Getting Started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |
diff --git a/docs/assets/noise.PNG b/docs/assets/noise.PNG
new file mode 100644
index 00000000..ec1c51a0
Binary files /dev/null and b/docs/assets/noise.PNG differ
diff --git a/docs/assets/pigs.png b/docs/assets/pigs.png
new file mode 100644
index 00000000..2617f8da
Binary files /dev/null and b/docs/assets/pigs.png differ
diff --git a/docs/index.md b/docs/index.md
index 4621a07b..eff7a217 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -66,6 +66,7 @@ supports tensorflow versions 2.x.
| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
+| Getting started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |
diff --git a/docs/notebooks/Getting_started_2.ipynb b/docs/notebooks/Getting_started_2.ipynb
new file mode 100644
index 00000000..35c978a3
--- /dev/null
+++ b/docs/notebooks/Getting_started_2.ipynb
@@ -0,0 +1,811 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "86557797-36de-44bb-b277-2ee7a22b1458",
+ "metadata": {},
+ "source": [
+    "# 🚀 Getting started 2: Training adversarially robust 1-Lipschitz neural networks for classification\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23c0071a-2546-4143-af7e-496d9ffa2fd4",
+ "metadata": {},
+ "source": [
+ "The goal of this series of tutorials is to show the different usages of `deel-lip`.\n",
+ "\n",
+ "In the first notebook, we have shown how to create 1-Lipschitz neural networks with\n",
+ "`deel-lip`. In this second notebook, we will show how to train adversarially robust\n",
+ "1-Lipschitz neural networks with `deel-lip`.\n",
+ "\n",
+ "In particular, we will cover the following:\n",
+ "\n",
+    "1. [📚 Theoretical background](#theoretical_background): A brief theoretical background\n",
+ " on adversarial robustness. This section can be safely skipped if one is not\n",
+ " interested in the theory.\n",
+ "\n",
+    "2. [💪 Training provably adversarially robust 1-Lipschitz neural networks on the MNIST dataset](#deel_keras):\n",
+ " Using the MNIST dataset, we will show examples of training adversarially robust\n",
+ " 1-Lipschitz neural networks using `deel-lip` loss functions\n",
+ " `TauCategoricalCrossentropy` and `MulticlassHKR`.\n",
+ "\n",
+ "We will also see that:\n",
+ "\n",
+ "- when training robust models, there is an accuracy-robustness trade-off\n",
+    "- the `CategoricalProvableAvgRobustness` metric can be used to assess the adversarial\n",
+    "  robustness of the resulting models\n",
+ "\n",
+    "## 📚 Theoretical background \n",
+ "\n",
+ "### Adversarial attacks\n",
+ "\n",
+ "In the context of classification problems, an adversarial attack is the result of adding\n",
+ "an _adversarial perturbation_ $\\epsilon$ to the input data point $x$ of a trained\n",
+ "predictive model $A$, with the intent to change its prediction (for simplicity, $A$\n",
+ "returns a class as opposed to a set of logits in the formalism used below).\n",
+ "\n",
+ "In simple mathematical terms, an adversarial example (i.e. a successful adversarial\n",
+    "attack) can be written as follows:\n",
+ "\n",
+ "$$A(x)=y_1,$$\n",
+ "\n",
+ "$$A(x+\\epsilon)=y_{\\epsilon},$$\n",
+ "\n",
+ "where:\n",
+ "\n",
+ "$$y_1\\neq y_\\epsilon.$$\n"
+ ]
+ },
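+  {
+   "cell_type": "markdown",
+   "id": "7f2a9c41-5d3e-4b8a-9c1d-6e0f2a3b4c5d",
+   "metadata": {},
+   "source": [
+    "The definition above translates directly into a small, purely illustrative check\n",
+    "(here, `A` stands for any hypothetical classifier that returns a class label):\n",
+    "\n",
+    "```python\n",
+    "def is_adversarial(A, x, epsilon):\n",
+    "    # epsilon is a successful adversarial perturbation for x\n",
+    "    # if adding it to x changes the class predicted by A\n",
+    "    return A(x + epsilon) != A(x)\n",
+    "```\n"
+   ]
+  },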
+ {
+ "cell_type": "markdown",
+ "id": "f8769d17-2c47-4491-ad22-7785324f88ec",
+ "metadata": {},
+ "source": [
+ "### An adversarial example\n",
+ "\n",
+ "The following example is directly taken from\n",
+ "https://adversarial-ml-tutorial.org/introduction/.\n",
+ "\n",
+ "![pigs.png](../assets/pigs.png)\n",
+ "\n",
+ "The first image is correctly classified as a **pig** by a classifier. The second image\n",
+ "is incorrectly classified as an **airplane** by the same classifier.\n",
+ "\n",
+    "Although the two images are indistinguishable to the human eye, the second one is in\n",
+    "fact the result of superimposing \"noise\" (i.e. adding an adversarial perturbation)\n",
+    "onto the original image.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "84e3cb53-98f7-44cb-bb9b-29b94e4fc6bd",
+ "metadata": {},
+ "source": [
+ "Below is a visualization of the added noise, zoomed-in by a factor of 50 so that we can\n",
+ "see it:\n",
+ "\n",
+ "![noise.png](../assets/noise.PNG)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4125128-de5d-4267-83ae-892e9d8ced69",
+ "metadata": {},
+ "source": [
+    "### Adversarial robustness of 1-Lipschitz neural networks\n",
+ "\n",
+ "The adversarial robustness of a predictive model is its ability to remain accurate and\n",
+ "reliable when subjected to adversarial perturbations.\n",
+ "\n",
+ "A major advantage of 1-Lipschitz neural networks is that they can offer provable\n",
+ "guarantees on their robustness for any particular input $x$, by providing a\n",
+    "_certificate_ $\epsilon_x$. Such a guarantee can be stated as follows:\n",
+    "\n",
+    "> \"For an input $x$, we can certify that no adversarial perturbation whose norm is\n",
+    "> smaller than the certificate $\epsilon_x$ will change our model's prediction.\"\n",
+ "\n",
+ "In simple mathematical terms:\n",
+ "\n",
+    "For a given $x$ and for all $\epsilon$ such that $\|\epsilon\|<\epsilon_x$, if:\n",
+ "\n",
+ "$$A(x)=y,$$\n",
+ "\n",
+ "$$A(x+\\epsilon)=y_{\\epsilon},$$\n",
+ "\n",
+ "then:\n",
+ "\n",
+ "$$y_{\\epsilon}=y.$$\n",
+ "\n",
+    "💡 In this notebook, we will use certificates as a metric to evaluate the provable\n",
+    "adversarial robustness of 1-Lipschitz deep learning models.\n",
+ "\n",
+    "💡 Depending on the type of norm you choose (e.g. $L_1$ or $L_2$), the guarantee you can\n",
+ "offer will differ, as $\\|\\epsilon\\|_2<\\epsilon_x$ and $\\|\\epsilon\\|_1<\\epsilon_x$ are\n",
+ "not equivalent.\n",
+ "\n",
+    "🚨 **Note**: _`deel-lip` only deals with the $L_2$ norm, as mentioned in the first\n",
+    "notebook 'Getting started 1'._\n",
+ "\n",
+    "A more precise formulation of the guarantee obtained with `deel-lip` would therefore\n",
+    "be:\n",
+    "\n",
+    "> \"For an input $x$, we can certify that no adversarial perturbation within an\n",
+    "> $L_2$-norm ball of radius $\epsilon_{x,L_2}$ (the certificate) will change our\n",
+    "> model's prediction.\"\n",
+ "\n",
+    "For a given $x$ and for all $\epsilon$ such that $\|\epsilon\|_2<\epsilon_{x,L_2}$,\n",
+    "if $A(x)=y$ and $A(x+\epsilon)=y_{\epsilon}$, then $y_{\epsilon}=y$.\n",
+ "\n",
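+    "💡 To make this concrete, here is a rough sketch (not the exact `deel-lip`\n",
+    "implementation) of how per-input $L_2$ certificates can be obtained from the logits of\n",
+    "a 1-Lipschitz classifier with a standard (non-disjoint) output layer: the certificate\n",
+    "is the gap between the two largest logits divided by sqrt(2). Up to implementation\n",
+    "details, this is the quantity that the `CategoricalProvableAvgRobustness` metric used\n",
+    "later in this notebook averages over the data.\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "\n",
+    "def l2_certificates(logits, lip_const=1.0):\n",
+    "    # rough per-input certificates for a 1-Lipschitz classifier:\n",
+    "    # (top-1 logit - top-2 logit) / (sqrt(2) * lip_const)\n",
+    "    sorted_logits = np.sort(logits, axis=-1)\n",
+    "    top1, top2 = sorted_logits[..., -1], sorted_logits[..., -2]\n",
+    "    return (top1 - top2) / (np.sqrt(2) * lip_const)\n",
+    "\n",
+    "\n",
+    "# example usage, with any of the 1-Lipschitz models built below:\n",
+    "# certificates = l2_certificates(model.predict(X_test))\n",
+    "```\n",
+    "\n",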
+    "## 💪 Training provably adversarially robust 1-Lipschitz neural networks on the MNIST dataset \n",
+ "\n",
+    "### 💾 MNIST dataset\n",
+ "\n",
+    "The MNIST dataset contains a large number of 28x28 grayscale images of handwritten\n",
+    "digits, each associated with its digit label.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "d3cf25c6-1691-4935-bdac-9f182a87ef32",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "from tensorflow.keras.datasets import mnist\n",
+ "from tensorflow.keras.utils import to_categorical\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "90da93e3-9c59-4e72-b817-01f97b324c51",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load MNIST Database\n",
+ "(X_train, y_train_ord), (X_test, y_test_ord) = mnist.load_data()\n",
+ "\n",
+    "# scale pixel values to [0, 1] and add a channel dimension\n",
+ "X_train = np.expand_dims(X_train, -1) / 255\n",
+ "X_test = np.expand_dims(X_test, -1) / 255\n",
+ "\n",
+ "# one hot encode the labels\n",
+ "y_train = to_categorical(y_train_ord)\n",
+ "y_test = to_categorical(y_test_ord)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df14bc65-3ebb-4f3c-925b-9a21cefa7e23",
+ "metadata": {},
+ "source": [
+    "### 🎮 Control over the accuracy-robustness trade-off with `deel-lip`'s loss functions\n",
+ "\n",
+ "When training neural networks, there is always a compromise between the robustness and\n",
+ "the accuracy of the models. In simple terms, achieving stronger robustness often\n",
+    "involves sacrificing some performance (at the extreme, the most robust function is\n",
+    "the constant function).\n",
+ "\n",
+ "In this section, we will show the pivotal role of `deel-lip`'s loss functions in\n",
+ "training 1-Lipschitz networks. Each of these functions comes with its own set of\n",
+ "hyperparameters, enabling you to precisely navigate and adjust the balance between\n",
+ "accuracy and robustness.\n",
+ "\n",
+ "We show two cases. In the first case, we use `deel-lip`'s `TauCategoricalCrossentropy`\n",
+ "from the `losses` submodule. In the second case, we use another loss function from\n",
+ "`deel-lip`: `MulticlassHKR`.\n",
+ "\n",
+    "#### 🔮 Prediction Model\n",
+ "\n",
+    "Since we will instantiate the same architecture four times in our examples, we wrap\n",
+    "the model-creation code in a function for conciseness:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "359fa47e-86f3-440d-971f-0a49fe7dcdc7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from deel import lip\n",
+ "\n",
+ "from tensorflow.keras.optimizers import Adam\n",
+ "from tensorflow.keras.layers import Input, Flatten\n",
+ "\n",
+ "\n",
+ "def create_conv_model(name_model, input_shape, output_shape):\n",
+ " \"\"\"\n",
+ " A simple convolutional neural network, made to be 1-Lipschitz.\n",
+ " \"\"\"\n",
+ " model = lip.Sequential(\n",
+ " [\n",
+ " Input(shape=input_shape),\n",
+ " lip.layers.SpectralConv2D(\n",
+ " filters=16,\n",
+ " kernel_size=(3, 3),\n",
+ " use_bias=True,\n",
+ " kernel_initializer=\"orthogonal\",\n",
+ " ),\n",
+ " lip.layers.GroupSort2(),\n",
+ " lip.layers.ScaledL2NormPooling2D(\n",
+ " pool_size=(2, 2), data_format=\"channels_last\"\n",
+ " ),\n",
+ " lip.layers.SpectralConv2D(\n",
+ " filters=32,\n",
+ " kernel_size=(3, 3),\n",
+ " use_bias=True,\n",
+ " kernel_initializer=\"orthogonal\",\n",
+ " ),\n",
+ " lip.layers.GroupSort2(),\n",
+ " lip.layers.ScaledL2NormPooling2D(\n",
+ " pool_size=(2, 2), data_format=\"channels_last\"\n",
+ " ),\n",
+ " Flatten(),\n",
+ " lip.layers.SpectralDense(\n",
+ " 64,\n",
+ " use_bias=True,\n",
+ " kernel_initializer=\"orthogonal\",\n",
+ " ),\n",
+ " lip.layers.GroupSort2(),\n",
+ " lip.layers.SpectralDense(\n",
+ " output_shape,\n",
+ " activation=None,\n",
+ " use_bias=False,\n",
+ " kernel_initializer=\"orthogonal\",\n",
+ " ),\n",
+ " ],\n",
+ " name=name_model,\n",
+ " )\n",
+ "\n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b2d7a4ce-df5e-4bf0-b461-5049fc749680",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_shape = X_train.shape[1:]\n",
+ "output_shape = y_train.shape[-1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4657a28-95c5-4f35-90bd-53a48bb44a5e",
+ "metadata": {},
+ "source": [
+ "#### Cross-entropy loss: `TauCategoricalCrossentropy`\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c8aa1c60-0271-4b76-a0b7-7d2bf89f4550",
+ "metadata": {},
+ "source": [
+    "Similar to the classes we have seen in \"Getting started 1\", the\n",
+    "`TauCategoricalCrossentropy` class inherits from its `keras` equivalent, but it comes\n",
+    "with an additional parameter: the 'temperature', denoted `tau`. This parameter allows\n",
+    "us to adjust the robustness of our model: the lower the temperature, the more robust\n",
+    "the model becomes, but also the less accurate.\n",
+    "\n",
+    "To show the impact of `tau` on both the performance and the robustness of our model, we\n",
+    "will train two models on the MNIST dataset: the first with a temperature of 100, the\n",
+    "second with a temperature of 3.\n",
+ "\n",
+ "Note: The performance achieved in this tutorial is not state-of-the-art. It is\n",
+ "presented solely for educational purposes. Performance can be enhanced by employing a\n",
+ "different network architecture or by training for additional epochs.\n"
+ ]
+ },
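+  {
+   "cell_type": "markdown",
+   "id": "9a4b6c8d-0e1f-4a2b-8c3d-5e6f7a8b9c0d",
+   "metadata": {},
+   "source": [
+    "To build intuition about what the temperature does, here is a minimal sketch of the\n",
+    "idea (an illustration only; the exact `deel-lip` implementation may differ slightly):\n",
+    "the logits are scaled by `tau` before a standard softmax cross-entropy, so a low\n",
+    "temperature flattens the softmax and pushes the network towards larger margins between\n",
+    "logits, i.e. towards robustness.\n",
+    "\n",
+    "```python\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "\n",
+    "def tau_crossentropy_sketch(y_true, y_pred, tau):\n",
+    "    # scale the logits by the temperature before the softmax cross-entropy;\n",
+    "    # dividing by tau keeps loss magnitudes comparable across temperatures\n",
+    "    scaled_ce = tf.keras.losses.categorical_crossentropy(\n",
+    "        y_true, tau * y_pred, from_logits=True\n",
+    "    )\n",
+    "    return scaled_ce / tau\n",
+    "```\n"
+   ]
+  },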
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fc8a3115-ccb4-4ce3-86a2-b3c0a22866b3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# high-temperature model\n",
+ "model_1 = create_conv_model(\"cross_entropy_model_1\", input_shape, output_shape)\n",
+ "\n",
+ "temperature_1 = 100.0\n",
+ "\n",
+ "model_1.compile(\n",
+ " loss=lip.losses.TauCategoricalCrossentropy(tau=temperature_1),\n",
+ " optimizer=Adam(1e-4),\n",
+    "    # notice the use of lip.metrics.CategoricalProvableAvgRobustness\n",
+    "    # as a metric to assess adversarial robustness\n",
+ " metrics=[\n",
+ " \"accuracy\",\n",
+ " lip.metrics.CategoricalProvableAvgRobustness(disjoint_neurons=False),\n",
+ " ],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "795bf0ad-5495-4b55-aa24-baa6205415b5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# low-temperature model\n",
+ "model_2 = create_conv_model(\"cross_entropy_model_2\", input_shape, output_shape)\n",
+ "\n",
+ "temperature_2 = 3.0\n",
+ "\n",
+ "model_2.compile(\n",
+ " loss=lip.losses.TauCategoricalCrossentropy(tau=temperature_2),\n",
+ " optimizer=Adam(1e-4),\n",
+ " metrics=[\n",
+ " \"accuracy\",\n",
+ " lip.metrics.CategoricalProvableAvgRobustness(disjoint_neurons=False),\n",
+ " ],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13859a2a-cd4a-4cc5-bc90-f61a07cd22ad",
+ "metadata": {},
+ "source": [
+    "💡 Notice that we use the accuracy metric to measure performance, and the\n",
+    "`CategoricalProvableAvgRobustness` metric to measure adversarial robustness. The latter\n",
+    "measures our model's average certificate: **the higher this value, the more robust our\n",
+    "model is**.\n",
+    "\n",
+    "**🚨 Note:** _This is true only for 1-Lipschitz neural networks._\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f878ac98-e80a-4d9a-a2a4-ef607e8b8001",
+ "metadata": {},
+ "source": [
+ "We fit both our models and observe the results.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bba315c1-c4df-47aa-ae13-202831555355",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n",
+ "235/235 [==============================] - 6s 10ms/step - loss: 0.0104 - accuracy: 0.7916 - CategoricalProvableAvgRobustness: 0.0290 - val_loss: 0.0028 - val_accuracy: 0.9184 - val_CategoricalProvableAvgRobustness: 0.0392\n",
+ "Epoch 2/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0024 - accuracy: 0.9291 - CategoricalProvableAvgRobustness: 0.0414 - val_loss: 0.0019 - val_accuracy: 0.9443 - val_CategoricalProvableAvgRobustness: 0.0426\n",
+ "Epoch 3/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0017 - accuracy: 0.9491 - CategoricalProvableAvgRobustness: 0.0451 - val_loss: 0.0014 - val_accuracy: 0.9574 - val_CategoricalProvableAvgRobustness: 0.0471\n",
+ "Epoch 4/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 0.0013 - accuracy: 0.9592 - CategoricalProvableAvgRobustness: 0.0481 - val_loss: 0.0011 - val_accuracy: 0.9658 - val_CategoricalProvableAvgRobustness: 0.0512\n",
+ "Epoch 5/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 0.0011 - accuracy: 0.9668 - CategoricalProvableAvgRobustness: 0.0503 - val_loss: 9.6942e-04 - val_accuracy: 0.9701 - val_CategoricalProvableAvgRobustness: 0.0520\n",
+ "Epoch 6/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 9.1108e-04 - accuracy: 0.9722 - CategoricalProvableAvgRobustness: 0.0524 - val_loss: 8.7599e-04 - val_accuracy: 0.9728 - val_CategoricalProvableAvgRobustness: 0.0536\n",
+ "Epoch 7/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 8.5019e-04 - accuracy: 0.9736 - CategoricalProvableAvgRobustness: 0.0545 - val_loss: 8.2655e-04 - val_accuracy: 0.9739 - val_CategoricalProvableAvgRobustness: 0.0561\n",
+ "Epoch 8/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 7.4599e-04 - accuracy: 0.9764 - CategoricalProvableAvgRobustness: 0.0565 - val_loss: 6.7749e-04 - val_accuracy: 0.9784 - val_CategoricalProvableAvgRobustness: 0.0597\n",
+ "Epoch 9/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 6.4887e-04 - accuracy: 0.9799 - CategoricalProvableAvgRobustness: 0.0579 - val_loss: 7.2717e-04 - val_accuracy: 0.9774 - val_CategoricalProvableAvgRobustness: 0.0604\n",
+ "Epoch 10/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 5.9098e-04 - accuracy: 0.9818 - CategoricalProvableAvgRobustness: 0.0597 - val_loss: 6.2794e-04 - val_accuracy: 0.9806 - val_CategoricalProvableAvgRobustness: 0.0623\n"
+ ]
+ }
+ ],
+ "source": [
+ "# fit the high-temperature model\n",
+ "result_1 = model_1.fit(\n",
+ " X_train,\n",
+ " y_train,\n",
+ " batch_size=256,\n",
+ " epochs=10,\n",
+ " validation_data=(X_test, y_test),\n",
+ " shuffle=True,\n",
+ " # verbose=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "914ac450-e529-406e-8233-79facc4c1190",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n",
+ "235/235 [==============================] - 4s 10ms/step - loss: 0.3420 - accuracy: 0.7768 - CategoricalProvableAvgRobustness: 0.2767 - val_loss: 0.1553 - val_accuracy: 0.9009 - val_CategoricalProvableAvgRobustness: 0.4921\n",
+ "Epoch 2/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.1320 - accuracy: 0.9126 - CategoricalProvableAvgRobustness: 0.5823 - val_loss: 0.1054 - val_accuracy: 0.9314 - val_CategoricalProvableAvgRobustness: 0.6599\n",
+ "Epoch 3/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.1022 - accuracy: 0.9321 - CategoricalProvableAvgRobustness: 0.6845 - val_loss: 0.0880 - val_accuracy: 0.9437 - val_CategoricalProvableAvgRobustness: 0.7208\n",
+ "Epoch 4/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0892 - accuracy: 0.9419 - CategoricalProvableAvgRobustness: 0.7354 - val_loss: 0.0794 - val_accuracy: 0.9481 - val_CategoricalProvableAvgRobustness: 0.7705\n",
+ "Epoch 5/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0814 - accuracy: 0.9481 - CategoricalProvableAvgRobustness: 0.7680 - val_loss: 0.0734 - val_accuracy: 0.9548 - val_CategoricalProvableAvgRobustness: 0.7979\n",
+ "Epoch 6/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0761 - accuracy: 0.9522 - CategoricalProvableAvgRobustness: 0.7906 - val_loss: 0.0684 - val_accuracy: 0.9588 - val_CategoricalProvableAvgRobustness: 0.8190\n",
+ "Epoch 7/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0721 - accuracy: 0.9553 - CategoricalProvableAvgRobustness: 0.8094 - val_loss: 0.0645 - val_accuracy: 0.9618 - val_CategoricalProvableAvgRobustness: 0.8292\n",
+ "Epoch 8/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0691 - accuracy: 0.9577 - CategoricalProvableAvgRobustness: 0.8233 - val_loss: 0.0631 - val_accuracy: 0.9599 - val_CategoricalProvableAvgRobustness: 0.8397\n",
+ "Epoch 9/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0666 - accuracy: 0.9593 - CategoricalProvableAvgRobustness: 0.8341 - val_loss: 0.0603 - val_accuracy: 0.9645 - val_CategoricalProvableAvgRobustness: 0.8583\n",
+ "Epoch 10/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.0648 - accuracy: 0.9608 - CategoricalProvableAvgRobustness: 0.8431 - val_loss: 0.0602 - val_accuracy: 0.9621 - val_CategoricalProvableAvgRobustness: 0.8598\n"
+ ]
+ }
+ ],
+ "source": [
+ "# fit the low-temperature model\n",
+ "result_2 = model_2.fit(\n",
+ " X_train,\n",
+ " y_train,\n",
+ " batch_size=256,\n",
+ " epochs=10,\n",
+ " validation_data=(X_test, y_test),\n",
+ " shuffle=True,\n",
+ " # verbose=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "af5c9695-023b-4fd4-a1f8-d630f3298f43",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model accuracy: 0.9806\n",
+ "Model's mean certificate: 0.0623\n",
+ "Loss' temperature: 100.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# metrics for the high-temperature model => performance-oriented\n",
+ "print(f\"Model accuracy: {result_1.history['val_accuracy'][-1]:.4f}\")\n",
+ "print(\n",
+ " f\"Model's mean certificate: {result_1.history['val_CategoricalProvableAvgRobustness'][-1]:.4f}\"\n",
+ ")\n",
+ "print(f\"Loss' temperature: {model_1.loss.tau.numpy():.1f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f5ed957c-4308-4f75-9a41-dcbc88dcc341",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model accuracy: 0.9621\n",
+ "Model's mean certificate: 0.8598\n",
+ "Loss' temperature: 3.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# metrics for the low-temperature model => robustness-oriented\n",
+ "print(f\"Model accuracy: {result_2.history['val_accuracy'][-1]:.4f}\")\n",
+ "print(\n",
+ " f\"Model's mean certificate: {result_2.history['val_CategoricalProvableAvgRobustness'][-1]:.4f}\"\n",
+ ")\n",
+ "print(f\"Loss' temperature: {model_2.loss.tau.numpy():.1f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4c77b7cd-3d8c-4030-9063-d1e142c47b2f",
+ "metadata": {},
+ "source": [
+ "When decreasing the temperature, we observe a large increase in robustness, but a slight\n",
+ "decrease in accuracy.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8353a6a-4591-45d6-aeb5-eb9c49231f09",
+ "metadata": {},
+ "source": [
+    "#### Hinge-Kantorovich-Rubinstein loss: `MulticlassHKR`\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cbea6020-4102-4956-bbb4-87b169caaf39",
+ "metadata": {},
+ "source": [
+    "We proceed in the same way as in the previous section. The difference lies in the\n",
+    "parameters that control the robustness.\n",
+    "\n",
+    "There are two of them: `min_margin` (minimal margin) and `alpha` (regularization factor).\n",
+    "\n",
+    "As shown below, a higher minimal margin and a lower alpha increase robustness.\n"
+ ]
+ },
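+  {
+   "cell_type": "markdown",
+   "id": "2c3d4e5f-6a7b-4c8d-9e0f-1a2b3c4d5e6f",
+   "metadata": {},
+   "source": [
+    "To build intuition about how these two parameters interact, here is a simplified\n",
+    "binary sketch of the hinge-Kantorovich-Rubinstein idea (labels in {-1, +1} and balanced\n",
+    "classes assumed; this is not the exact `MulticlassHKR` implementation): a hinge term\n",
+    "enforces a margin of at least `min_margin`, a Kantorovich-Rubinstein term rewards a\n",
+    "large separation between the classes, and `alpha` balances the two.\n",
+    "\n",
+    "```python\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "\n",
+    "def hkr_sketch(y_true, y_pred, alpha, min_margin):\n",
+    "    # hinge term: penalizes examples whose margin is below min_margin\n",
+    "    hinge = tf.reduce_mean(tf.nn.relu(min_margin - y_true * y_pred))\n",
+    "    # Kantorovich-Rubinstein (Wasserstein) term: rewards class separation\n",
+    "    kr = tf.reduce_mean(y_true * y_pred)\n",
+    "    # a large alpha emphasizes the hinge term (accuracy-oriented), while a\n",
+    "    # small alpha with a larger min_margin favors robustness\n",
+    "    return alpha * hinge - kr\n",
+    "```\n"
+   ]
+  },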
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "a964983d-4347-4f57-8489-98e6a05cfd0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# performance-oriented model\n",
+ "model_3 = create_conv_model(\"HKR_model_3\", input_shape, output_shape)\n",
+ "\n",
+ "min_margin_3 = 0.01\n",
+ "alpha_3 = 1000\n",
+ "\n",
+ "model_3.compile(\n",
+ " loss=lip.losses.MulticlassHKR(min_margin=min_margin_3, alpha=alpha_3),\n",
+ " optimizer=Adam(1e-4),\n",
+ " metrics=[\n",
+ " \"accuracy\",\n",
+ " lip.metrics.CategoricalProvableAvgRobustness(disjoint_neurons=False),\n",
+ " ],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "5b57e8ad-19fc-4791-8e28-d0e613cbddc5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# robustness-oriented model\n",
+ "model_4 = create_conv_model(\"HKR_model_4\", input_shape, output_shape)\n",
+ "\n",
+ "min_margin_4 = 0.2\n",
+ "alpha_4 = 50\n",
+ "\n",
+ "model_4.compile(\n",
+ " loss=lip.losses.MulticlassHKR(min_margin=min_margin_4, alpha=alpha_4),\n",
+ " optimizer=Adam(1e-4),\n",
+ " metrics=[\n",
+ " \"accuracy\",\n",
+ " lip.metrics.CategoricalProvableAvgRobustness(disjoint_neurons=False),\n",
+ " ],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3b9f6ce2-ec2b-4f4d-b282-918d434eb9b3",
+ "metadata": {},
+ "source": [
+ "We fit both our models and observe the results.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "b976968d-4494-4935-8338-06ebd99b6c78",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n",
+ "235/235 [==============================] - 4s 10ms/step - loss: 6.4450 - accuracy: 0.7472 - CategoricalProvableAvgRobustness: 0.0258 - val_loss: 1.7008 - val_accuracy: 0.9056 - val_CategoricalProvableAvgRobustness: 0.0361\n",
+ "Epoch 2/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 1.4576 - accuracy: 0.9168 - CategoricalProvableAvgRobustness: 0.0372 - val_loss: 1.0670 - val_accuracy: 0.9411 - val_CategoricalProvableAvgRobustness: 0.0411\n",
+ "Epoch 3/10\n",
+ "235/235 [==============================] - 2s 6ms/step - loss: 0.9756 - accuracy: 0.9410 - CategoricalProvableAvgRobustness: 0.0410 - val_loss: 0.7664 - val_accuracy: 0.9550 - val_CategoricalProvableAvgRobustness: 0.0431\n",
+ "Epoch 4/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 0.7388 - accuracy: 0.9532 - CategoricalProvableAvgRobustness: 0.0441 - val_loss: 0.6139 - val_accuracy: 0.9621 - val_CategoricalProvableAvgRobustness: 0.0462\n",
+ "Epoch 5/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.5798 - accuracy: 0.9605 - CategoricalProvableAvgRobustness: 0.0464 - val_loss: 0.4938 - val_accuracy: 0.9671 - val_CategoricalProvableAvgRobustness: 0.0477\n",
+ "Epoch 6/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 0.5032 - accuracy: 0.9645 - CategoricalProvableAvgRobustness: 0.0486 - val_loss: 0.4849 - val_accuracy: 0.9683 - val_CategoricalProvableAvgRobustness: 0.0512\n",
+ "Epoch 7/10\n",
+ "235/235 [==============================] - 2s 7ms/step - loss: 0.4317 - accuracy: 0.9688 - CategoricalProvableAvgRobustness: 0.0510 - val_loss: 0.4210 - val_accuracy: 0.9710 - val_CategoricalProvableAvgRobustness: 0.0521\n",
+ "Epoch 8/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.3550 - accuracy: 0.9722 - CategoricalProvableAvgRobustness: 0.0526 - val_loss: 0.3933 - val_accuracy: 0.9743 - val_CategoricalProvableAvgRobustness: 0.0535\n",
+ "Epoch 9/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.3358 - accuracy: 0.9730 - CategoricalProvableAvgRobustness: 0.0546 - val_loss: 0.2810 - val_accuracy: 0.9775 - val_CategoricalProvableAvgRobustness: 0.0565\n",
+ "Epoch 10/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.2858 - accuracy: 0.9758 - CategoricalProvableAvgRobustness: 0.0559 - val_loss: 0.3089 - val_accuracy: 0.9739 - val_CategoricalProvableAvgRobustness: 0.0590\n"
+ ]
+ }
+ ],
+ "source": [
+ "# fit the model\n",
+ "result_3 = model_3.fit(\n",
+ " X_train,\n",
+ " y_train,\n",
+ " batch_size=256,\n",
+ " epochs=10,\n",
+ " validation_data=(X_test, y_test),\n",
+ " shuffle=True,\n",
+ " # verbose=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "8f19da69-9869-463e-8fac-b7f179e00315",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n",
+ "235/235 [==============================] - 4s 9ms/step - loss: 1.9754 - accuracy: 0.8073 - CategoricalProvableAvgRobustness: 0.1097 - val_loss: 0.5700 - val_accuracy: 0.9194 - val_CategoricalProvableAvgRobustness: 0.1783\n",
+ "Epoch 2/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: 0.3412 - accuracy: 0.9265 - CategoricalProvableAvgRobustness: 0.2227 - val_loss: 0.0606 - val_accuracy: 0.9375 - val_CategoricalProvableAvgRobustness: 0.2713\n",
+ "Epoch 3/10\n",
+ "235/235 [==============================] - 2s 9ms/step - loss: -0.0592 - accuracy: 0.9404 - CategoricalProvableAvgRobustness: 0.3110 - val_loss: -0.2699 - val_accuracy: 0.9468 - val_CategoricalProvableAvgRobustness: 0.3581\n",
+ "Epoch 4/10\n",
+ "235/235 [==============================] - 2s 9ms/step - loss: -0.3596 - accuracy: 0.9476 - CategoricalProvableAvgRobustness: 0.3993 - val_loss: -0.5494 - val_accuracy: 0.9537 - val_CategoricalProvableAvgRobustness: 0.4502\n",
+ "Epoch 5/10\n",
+ "235/235 [==============================] - 2s 9ms/step - loss: -0.5982 - accuracy: 0.9505 - CategoricalProvableAvgRobustness: 0.4850 - val_loss: -0.7551 - val_accuracy: 0.9548 - val_CategoricalProvableAvgRobustness: 0.5342\n",
+ "Epoch 6/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: -0.7969 - accuracy: 0.9532 - CategoricalProvableAvgRobustness: 0.5602 - val_loss: -0.9409 - val_accuracy: 0.9573 - val_CategoricalProvableAvgRobustness: 0.6143\n",
+ "Epoch 7/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: -0.9541 - accuracy: 0.9541 - CategoricalProvableAvgRobustness: 0.6243 - val_loss: -1.0462 - val_accuracy: 0.9578 - val_CategoricalProvableAvgRobustness: 0.6508\n",
+ "Epoch 8/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: -1.0742 - accuracy: 0.9566 - CategoricalProvableAvgRobustness: 0.6734 - val_loss: -1.1781 - val_accuracy: 0.9592 - val_CategoricalProvableAvgRobustness: 0.7016\n",
+ "Epoch 9/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: -1.1652 - accuracy: 0.9568 - CategoricalProvableAvgRobustness: 0.7114 - val_loss: -1.2754 - val_accuracy: 0.9624 - val_CategoricalProvableAvgRobustness: 0.7390\n",
+ "Epoch 10/10\n",
+ "235/235 [==============================] - 2s 8ms/step - loss: -1.2430 - accuracy: 0.9577 - CategoricalProvableAvgRobustness: 0.7454 - val_loss: -1.3300 - val_accuracy: 0.9602 - val_CategoricalProvableAvgRobustness: 0.7708\n"
+ ]
+ }
+ ],
+ "source": [
+ "# fit the model\n",
+ "result_4 = model_4.fit(\n",
+ " X_train,\n",
+ " y_train,\n",
+ " batch_size=256,\n",
+ " epochs=10,\n",
+ " validation_data=(X_test, y_test),\n",
+ " shuffle=True,\n",
+ " # verbose=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "5bfe7641-f7af-4ac3-9ff4-fca3300448ff",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model accuracy: 0.9739\n",
+ "Model's mean certificate: 0.0590\n",
+ "Loss' minimum margin: 0.01\n",
+ "Loss' alpha: 1000.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# performance-oriented model\n",
+ "print(f\"Model accuracy: {result_3.history['val_accuracy'][-1]:.4f}\")\n",
+ "print(\n",
+ " f\"Model's mean certificate: {result_3.history['val_CategoricalProvableAvgRobustness'][-1]:.4f}\"\n",
+ ")\n",
+ "print(f\"Loss' minimum margin: {model_3.loss.min_margin.numpy():.2f}\")\n",
+ "print(f\"Loss' alpha: {model_3.loss.alpha.numpy():.1f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "7ce20cb8-797b-4e03-8600-d0030adfccc1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model accuracy: 0.9602\n",
+ "Model's mean certificate: 0.7708\n",
+ "Loss' minimum margin: 0.2\n",
+ "Loss' alpha: 50.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# robustness-oriented model\n",
+ "print(f\"Model accuracy: {result_4.history['val_accuracy'][-1]:.4f}\")\n",
+ "print(\n",
+ " f\"Model's mean certificate: {result_4.history['val_CategoricalProvableAvgRobustness'][-1]:.4f}\"\n",
+ ")\n",
+ "print(f\"Loss' minimum margin: {model_4.loss.min_margin.numpy():.1f}\")\n",
+ "print(f\"Loss' alpha: {model_4.loss.alpha.numpy():.1f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "00e39a0d-a145-4a44-85c3-5f67e64ce444",
+ "metadata": {},
+ "source": [
+    "We have experimentally confirmed the accuracy-robustness trade-off: a higher minimal\n",
+    "margin and a lower alpha increase robustness, but also decrease accuracy.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c4857464-3ed3-436e-8853-9ed67d0ce5c6",
+ "metadata": {},
+ "source": [
+    "## 🎉 Congratulations\n",
+ "\n",
+    "You now know how to train provably adversarially robust 1-Lipschitz neural networks!\n",
+ "\n",
+    "📚 Interested readers can learn more about the role of loss functions and the\n",
+    "accuracy-robustness trade-off that arises when training adversarially robust\n",
+    "1-Lipschitz neural networks in the following paper:\n",
+    "[Pay attention to your loss: understanding misconceptions about 1-Lipschitz neural networks](https://arxiv.org/abs/2104.05097).\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv_tf2.12",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/mkdocs.yml b/mkdocs.yml
index 653174e9..a1a3acfc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -15,6 +15,7 @@ nav:
- deel.lip.utils: api/utils.md
- Tutorials:
- "Getting started 1 - Creating a 1-Lipschitz neural network": notebooks/Getting_started_1.ipynb
+ - "Getting started 2 - Training an adversarially robust 1-Lipschitz neural network": notebooks/Getting_started_2.ipynb
- "Demo 0: Example & Usage": notebooks/demo0.ipynb
- "Demo 1: Wasserstein distance estimation on toy example": notebooks/demo1.ipynb
- "Demo 2: HKR Classifier on toy dataset": notebooks/demo2.ipynb