Changes related to the second getting-started notebook, including the… #84

1 change: 1 addition & 0 deletions README.md
@@ -69,6 +69,7 @@ supports tensorflow versions 2.x.
| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting Started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
| Getting Started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |
324 changes: 324 additions & 0 deletions deel/lip/losses.py
@@ -16,6 +16,10 @@
    Reduction,
)
from tensorflow.keras.utils import register_keras_serializable
from deel.lip.layers.base_layer import LipschitzLayer
import logging

logging.basicConfig(level=logging.WARNING)


@register_keras_serializable("deel-lip", "_kr")
@@ -561,3 +565,323 @@ def get_config(self):
        config = {"tau": self.tau.numpy()}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "Certificate_Binary")
class Certificate_Binary(Loss):
    def __init__(
        self, model, K=None, reduction=Reduction.AUTO, name="Certificate_Binary"
    ):
        r"""
        Certificate-based loss in the context of binary classification,
        inspired by the paper
        'Improved deterministic l2 robustness on CIFAR-10 and CIFAR-100'
        (https://openreview.net/forum?id=tD7eCtaSkR).

        Our formula is as follows:

        $$\text{certificate} = |\hat{y}|/K$$

        where $\hat{y}$ is the single logit predicted by our model,
        and $K$ is the Lipschitz constant associated with the model.

        Note that `y_true` and `y_pred` must be of rank 2:
        (batch_size, 1).

        Args:
            model: A tensorflow multi-layered model.
            K: The Lipschitz constant of the model.
                It is calculated from the model's layers if
                not provided by the user upon instantiation.
            reduction: passed to tf.keras.Loss constructor
            name (str): passed to tf.keras.Loss constructor

        """
        self.model = model
        check_last_layer(model)
        if K is None:
            self.K = tf.constant(get_K_(get_layers(self.model)), dtype=tf.float32)
        else:
            self.K = tf.constant(K, dtype=tf.float32)
        super(Certificate_Binary, self).__init__(reduction=reduction, name=name)

    @tf.function
    def call(self, y_true, y_pred):
        # y_true is unused: the certificate only depends on the predicted logit.
        return tf.abs(y_pred[:, 0]) / self.K

    def get_config(self):
        config = {"model": self.model, "K": self.K.numpy()}
        base_config = super(Certificate_Binary, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
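
To make the binary formula concrete, here is a minimal sketch in plain Python that does not depend on TensorFlow or on the classes in this diff. The helper name `binary_certificate` and the sample logits are hypothetical; the only thing taken from the code above is the rule certificate = |ŷ| / K. For a K-Lipschitz model, this value is a lower bound on the L2 distance an input must move before the sign of the logit can change.

```python
def binary_certificate(logits, lipschitz_constant=1.0):
    """Return |y_hat| / K for each single-logit prediction (illustrative helper)."""
    if lipschitz_constant <= 0:
        raise ValueError("The Lipschitz constant must be strictly positive.")
    return [abs(y_hat) / lipschitz_constant for y_hat in logits]


# Hypothetical logits from a model assumed to be 2-Lipschitz.
certificates = binary_certificate([4.0, -3.0, 0.5], lipschitz_constant=2.0)
print(certificates)  # [2.0, 1.5, 0.25]
```

A larger Lipschitz constant shrinks every certificate proportionally, which is why the loss divides by `K` rather than using the raw logit.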


@register_keras_serializable("deel-lip", "Certificate_Multiclass")
class Certificate_Multiclass(Loss):
    def __init__(
        self,
        model,
        K_n_minus_1=None,
        reduction=Reduction.AUTO,
        name="Certificate_Multiclass",
    ):
        r"""
        Certificate-based loss, inspired by the paper
        'Improved deterministic l2 robustness on CIFAR-10 and CIFAR-100'
        (https://openreview.net/forum?id=tD7eCtaSkR).

        Our formula is as follows:

        $$\text{certificate} = \min_{i \in \{2, \dots, C\}} \frac{ \hat{y}_{\pi_1} -
        \hat{y}_{\pi_i}}{||\mathbf{W}_{n,\pi_1}-\mathbf{W}_{n,\pi_i}||K_{n-1}}
        $$

        where $\pi$ is the permutation of $\{1, \dots, C\}$ (C being the number
        of labels) that sorts the elements of $\hat{y}$ from most likely to
        least likely, $\mathbf{W}_{n,k}$ designates the weights of the k-th
        perceptron of the last layer (n being the number of layers),
        and $K_{n-1}$ is the Lipschitz constant associated with the
        first n-1 layers of the model.

        Note that `y_true` and `y_pred` must be of rank 2:
        (batch_size, C) for multiclass classification with C categories

        Args:
            model: A tensorflow multi-layered model.
            K_n_minus_1: The Lipschitz constant of the first n - 1 layers,
                where n is the total number of layers.
                It is calculated from the model's layers if not provided by
                the user upon instantiation.
            reduction: passed to tf.keras.Loss constructor
            name (str): passed to tf.keras.Loss constructor

        """
        self.model = model
        check_last_layer(model)
        if K_n_minus_1 is None:
            self.K_n_minus_1 = tf.constant(
                get_K_(get_layers(self.model)[:-1]), dtype=tf.float32
            )
        else:
            self.K_n_minus_1 = tf.constant(K_n_minus_1, dtype=tf.float32)
        super(Certificate_Multiclass, self).__init__(reduction=reduction, name=name)

    @tf.function
    def call(self, y_true, y_pred):
        last_layer_weights = get_last_layer(self.model).kernel
        # Dynamic shapes, valid even when the static batch size is unknown.
        batch_size = tf.shape(y_pred)[0]
        num_classes = tf.shape(y_true)[1]
        sorted_indices = get_sorted_logits_indices_tensor(y_pred)
        certificate = tf.zeros([batch_size], dtype=tf.float32)
        indices_batch_size = tf.range(batch_size, dtype=tf.int32)
        # Skip rank 0: it is the predicted class itself.
        indices_num_classes = tf.range(1, num_classes, dtype=tf.int32)
        for j in indices_batch_size:
            min_value = tf.constant(
                float("inf"), dtype=tf.float32
            )  # Initialize with positive infinity
            for i in indices_num_classes:
                # Logit gap between the predicted class and the i-th ranked class.
                numerator = (
                    y_pred[j][sorted_indices[j][0]] - y_pred[j][sorted_indices[j][i]]
                )
                # Distance between the corresponding last-layer weight columns.
                denominator = tf.norm(
                    last_layer_weights[:, sorted_indices[j][0]]
                    - last_layer_weights[:, sorted_indices[j][i]]
                )
                formula_value = numerator / (denominator * self.K_n_minus_1)
                min_value = tf.minimum(min_value, formula_value)
            # Store the smallest ratio as the certificate of sample j.
            certificate = tf.tensor_scatter_nd_update(
                certificate,
                indices=tf.expand_dims(tf.expand_dims(j, axis=0), axis=0),
                updates=tf.expand_dims(min_value, axis=0),
            )
        return certificate

    def get_config(self):
        config = {"model": self.model, "K_n_minus_1": self.K_n_minus_1.numpy()}
        base_config = super(Certificate_Multiclass, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
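
The nested loops in `Certificate_Multiclass.call` can be hard to check by eye, so the following NumPy sketch computes the same quantity in a vectorised way and may serve as a reference when testing. It is not the implementation in this diff: the function name `multiclass_certificate`, its arguments, and the toy inputs are assumptions made for illustration. It assumes the last-layer kernel has shape (features, C), with one column per class, and that no two columns are identical.

```python
import numpy as np


def multiclass_certificate(logits, last_kernel, k_n_minus_1=1.0):
    """Vectorised certificate for a batch (illustrative reference only).

    logits:      (batch, C) model outputs.
    last_kernel: (features, C) last-layer weights, one column per class.
    """
    logits = np.asarray(logits, dtype=np.float64)
    columns = np.asarray(last_kernel, dtype=np.float64).T  # (C, features)
    rows = np.arange(logits.shape[0])
    top = np.argmax(logits, axis=1)  # predicted class per sample

    # Logit gap between the predicted class and every class: (batch, C).
    margins = logits[rows, top][:, None] - logits
    # Distance between the predicted class's weights and every class's weights.
    norms = np.linalg.norm(columns[top][:, None, :] - columns[None, :, :], axis=-1)

    # The predicted class compared with itself gives 0 / 0, so exclude it.
    norms[rows, top] = 1.0
    ratios = margins / (norms * k_n_minus_1)
    ratios[rows, top] = np.inf
    return ratios.min(axis=1)


# Toy check: identity kernel, so every pair of columns is sqrt(2) apart.
toy_logits = [[3.0, 1.0, 0.0], [0.0, 2.0, 5.0]]
print(multiclass_certificate(toy_logits, np.eye(3)))
```

With the identity kernel, the first sample's smallest gap is 2 and the second's is 3, so the certificates are 2/√2 and 3/√2 respectively.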


def get_K_(layers):
    check_is_Lipschitz(layers)
    K_ = 1
    for layer in layers:
        # Only deel-lip layers expose a Lipschitz coefficient; every other
        # layer accepted by check_is_Lipschitz is assumed to be 1-Lipschitz.
        if isinstance(layer, LipschitzLayer):
            K_ = layer.k_coef_lip * K_
    return K_
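
`get_K_` relies on the fact that the Lipschitz constant of a composition of layers is bounded by the product of the per-layer constants. The sketch below illustrates that rule with stand-in classes so it runs without TensorFlow. `FakeLipschitzLayer`, `FakeNeutralLayer` and `lipschitz_upper_bound` are hypothetical names; only the `k_coef_lip` attribute mirrors what the diff reads from `LipschitzLayer`.

```python
import math


class FakeLipschitzLayer:
    """Hypothetical stand-in exposing a `k_coef_lip` attribute."""

    def __init__(self, k_coef_lip):
        self.k_coef_lip = k_coef_lip


class FakeNeutralLayer:
    """Hypothetical stand-in for a layer such as Flatten, assumed 1-Lipschitz."""


def lipschitz_upper_bound(layers):
    """Multiply the per-layer constants; layers without one count as 1."""
    return math.prod(getattr(layer, "k_coef_lip", 1.0) for layer in layers)


stack = [
    FakeLipschitzLayer(2.0),
    FakeNeutralLayer(),
    FakeLipschitzLayer(0.5),
    FakeLipschitzLayer(3.0),
]
print(lipschitz_upper_bound(stack))  # 3.0
```

The product is only an upper bound on the true constant of the network, so a certificate computed from it is conservative.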


GLOBAL_CONSTANTS = {
    "supported_neutral_layers": ["Flatten", "InputLayer"],
    # These are matched against layer.name rather than
    # layer.__class__.__name__ because, for Conv2D and
    # GlobalAveragePooling2D, the latter results in 'type'.
    "not_deel": [
        "dense",
        "average_pooling2d",
        "global_average_pooling2d",
        "conv2d",
    ],
    "not_Lipschitz": [
        "Dropout",
        "ELU",
        "LeakyReLU",
        "ThresholdedReLU",
        "BatchNormalization",
    ],
    "unrecommended_activation_functions": [
        tf.keras.activations.relu,
        tf.keras.activations.softmax,
        tf.keras.activations.exponential,
        tf.keras.activations.elu,
        tf.keras.activations.selu,
        tf.keras.activations.tanh,
        tf.keras.activations.sigmoid,
        tf.keras.activations.softplus,
        tf.keras.activations.softsign,
    ],
    "min_max_norm": tf.keras.constraints.MinMaxNorm,
    "recommended_activation_names": [
        "group_sort2",
        "full_sort",
        "group_sort",
        "householder",
        "max_min",
        "p_re_lu",
    ],
    "unrecommended_activation_names": ["ReLU"],
    "no_activation": tf.keras.activations.linear,
}


def get_sorted_logits_indices_tensor(model_output):
    # Class indices of each row of logits, ordered from largest to smallest
    sorted_indices = tf.argsort(model_output, axis=1, direction="DESCENDING")
    return sorted_indices


def get_layers(model):
    return model.layers


def get_weights_last_layer(model):
    return get_last_layer(model).get_weights()[0]


def get_sorted_logits_indices(model_output):
    # NumPy counterpart of get_sorted_logits_indices_tensor, e.g. for the
    # output of model.predict(x): indices ordered from largest to smallest
    sorted_indices = np.argsort(model_output, axis=1)[:, ::-1]
    return sorted_indices


def get_last_layer(model):
    return get_layers(model)[-1]


def check_last_layer(model):
    last_layer = get_last_layer(model)
    # check last layer is a layer with a non-empty set of weights
    if not hasattr(last_layer, "get_weights") or not last_layer.get_weights():
        raise BadLastLayerError(
            "The last layer '%s' must have a set of "
            "weights to calculate the certificate." % last_layer.name
        )
    # check last layer has no activation function set
    activation = getattr(last_layer, "activation", None)
    if activation is not None and activation != GLOBAL_CONSTANTS["no_activation"]:
        logging.warning(
            "We recommend avoiding using an activation function for the last "
            "layer (here the '%s' activation function of the layer '%s').",
            activation,
            last_layer.name,
        )

class NotLipschitzLayerError(Exception):
    pass


class BadLastLayerError(Exception):
    pass


def check_is_Lipschitz(layers):
    for layer in layers:
        check_activation_layer(layer)
        if layer.__class__.__name__ in GLOBAL_CONSTANTS["supported_neutral_layers"]:
            pass
        elif isinstance(layer, LipschitzLayer):
            pass
        elif (
            layer.__class__.__name__ in GLOBAL_CONSTANTS["not_Lipschitz"]
        ):  # triggers when using non-Lipschitz layers such as BatchNormalization
            raise NotLipschitzLayerError(
                "The layer '%s' is not supported" % layer.name
            )
        elif any(
            layer.name.startswith(substring)
            for substring in GLOBAL_CONSTANTS["not_deel"]
        ):
            logging.warning(
                "A deel equivalent exists for '%s'. For practical "
                "purposes, we will assume that the layer is 1-Lipschitz.",
                layer.name,
            )
        elif (
            layer.__class__.__name__
            in GLOBAL_CONSTANTS["unrecommended_activation_names"]
        ):
            logging.warning(
                "The layer '%s' is not recommended. For practical purposes, "
                "we recommend using a deel-lip activation layer such as "
                "GroupSort2 instead.",
                layer.name,
            )
        else:
            logging.warning(
                "Unknown layer '%s' used. For practical purposes, "
                "we will assume that the layer is 1-Lipschitz.",
                layer.name,
            )


def check_activation_layer(layer):
    if hasattr(layer, "activation"):
        activation = getattr(layer, "activation")
        if activation != GLOBAL_CONSTANTS["no_activation"]:
            if activation in GLOBAL_CONSTANTS["unrecommended_activation_functions"]:
                logging.warning(
                    "The '%s' activation function of the layer '%s' is not "
                    "recommended. For practical purposes, we recommend using "
                    "deel-lip activation functions such as GroupSort2 instead.",
                    activation,
                    layer.name,
                )
                return None
            if isinstance(activation, GLOBAL_CONSTANTS["min_max_norm"]):
                return None
            elif hasattr(activation, "name"):
                n = activation.name
                if (
                    layer.activation.__class__.__name__
                    in GLOBAL_CONSTANTS["recommended_activation_names"]
                ):
                    return None
                else:
                    # Use logging, like every other diagnostic in this module.
                    logging.warning(
                        "The '%s' activation function of the layer '%s' is "
                        "unknown. We will assume it is 1-Lipschitz.",
                        n,
                        layer.name,
                    )
            else:
                logging.warning(
                    "The '%s' activation function of the layer '%s' is unknown.",
                    activation,
                    layer.name,
                )
Binary file added docs/assets/noise.PNG
Binary file added docs/assets/pigs.png
1 change: 1 addition & 0 deletions docs/index.md
@@ -66,6 +66,7 @@ supports tensorflow versions 2.x.
| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
| Getting started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |