Getting-started with deel-lip #78

Closed
wants to merge 19 commits into from
3 changes: 2 additions & 1 deletion README.md
@@ -68,7 +68,8 @@ supports tensorflow versions 2.x.

| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting Started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo0.ipynb) |
| Getting Started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
| Getting Started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |
294 changes: 290 additions & 4 deletions deel/lip/losses.py
@@ -13,6 +13,9 @@
from tensorflow.keras.utils import register_keras_serializable





@register_keras_serializable("deel-lip", "_kr")
def _kr(y_true, y_pred, epsilon):
"""Returns the element-wise KR loss.
@@ -62,10 +65,7 @@ def __init__(self, multi_gpu=False, reduction=Reduction.AUTO, name="KR"):
The Kantorovich-Rubinstein duality is formulated as following:

$$
- W_1(\mu, \nu) =
- \sup_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim \mu}{\mathbb{E}}
- \left[f(\textbf{x} )\right] -
- \underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x} )\right]
+ \text{certificate} = \min_i \frac{ \hat{y}_{\pi_1} - \hat{y}_{\pi_i}}{||\mathbf{W}_{n,\pi_1}-\mathbf{W}_{n,\pi_i}||}
$$

Where mu and nu stands for the two distributions, the distribution where the
@@ -493,3 +493,289 @@ def get_config(self):
config = {"tau": self.tau.numpy()}
base_config = super(TauCategoricalCrossentropy, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

















import logging  # used by the check_* helpers further down in this file

from deel.lip.layers.base_layer import LipschitzLayer


@register_keras_serializable("deel-lip", "Certificate_Binary")
class Certificate_Binary(Loss):
    def __init__(
        self, model, K=None, reduction=Reduction.AUTO, name="Certificate_Binary"
    ):
        r"""
        Certificate-based loss for binary classification, inspired by the paper
        'Improved deterministic l2 robustness on CIFAR-10 and CIFAR-100'
        (https://openreview.net/forum?id=tD7eCtaSkR).

        The certificate is computed as follows:

        $$\text{certificate} = |\hat{y}|/K$$

        where $\hat{y}$ is the single logit predicted by the model,
        and $K$ is the Lipschitz constant of the model.

        Note that `y_true` and `y_pred` must be of rank 2; only the first
        column of `y_pred` (the single logit) is used.

        Args:
            model: A tensorflow multi-layered model.
            K: The Lipschitz constant of the model. If not provided, it is
                computed from the model's layers upon instantiation.
            reduction: passed to tf.keras.Loss constructor
            name (str): passed to tf.keras.Loss constructor

        """
        self.model = model

        if K is None:
            self.K = tf.constant(get_K_(get_layers(self.model)), dtype=tf.float32)
        else:
            self.K = tf.constant(K, dtype=tf.float32)

        super(Certificate_Binary, self).__init__(reduction=reduction, name=name)

    @tf.function
    def call(self, y_true, y_pred):
        # The certificate depends only on the predicted logit, not on y_true.
        return tf.abs(y_pred[:, 0]) / self.K

    def get_config(self):
        config = {"model": self.model, "K": self.K.numpy()}
        base_config = super(Certificate_Binary, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
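
The binary certificate in the docstring above can be checked by hand. The following is a minimal pure-Python sketch of the formula $|\hat{y}|/K$, not the PR's TensorFlow implementation; the function name `binary_certificate` and its parameters are hypothetical and chosen only for illustration. It relies on the standard fact that a $K$-Lipschitz function cannot change its output by more than $K\|\delta\|$ under a perturbation $\delta$, so the sign of the logit cannot flip while $\|\delta\| < |\hat{y}|/K$.

```python
def binary_certificate(logit: float, lipschitz_constant: float = 1.0) -> float:
    """Return |logit| / K, the radius inside which the logit's sign is fixed.

    `logit` is the single model output and `lipschitz_constant` is K.
    Both names are illustrative assumptions, not part of deel-lip.
    """
    if lipschitz_constant <= 0:
        raise ValueError("lipschitz_constant must be positive")
    return abs(logit) / lipschitz_constant


# A logit of -0.8 from a 2-Lipschitz model gives a certificate of 0.4.
print(binary_certificate(-0.8, 2.0))
```

For a 1-Lipschitz network the certificate is simply the absolute value of the logit, which is why the default `lipschitz_constant` in this sketch is 1.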


@register_keras_serializable("deel-lip", "Certificate_Multiclass")
class Certificate_Multiclass(Loss):
    def __init__(
        self,
        model,
        K_n_minus_1=None,
        reduction=Reduction.AUTO,
        name="Certificate_Multiclass",
    ):
        r"""
        Certificate-based loss for multiclass classification, inspired by the paper
        'Improved deterministic l2 robustness on CIFAR-10 and CIFAR-100'
        (https://openreview.net/forum?id=tD7eCtaSkR).

        The certificate is computed as follows:

        $$
        \text{certificate} = \min_{i \neq 1} \frac{ \hat{y}_{\pi_1} - \hat{y}_{\pi_i}}{||\mathbf{W}_{n,\pi_1}-\mathbf{W}_{n,\pi_i}||K_{n-1}}
        $$

        where $\pi$ is the permutation of {1, ..., C} (C being the number of classes)
        that sorts the elements of $\hat{y}$ from most likely to least likely,
        $\mathbf{W}_{n,k}$ designates the weights of the k-th perceptron of the last
        layer (n being the number of layers), and $K_{n-1}$ is the Lipschitz constant
        of the first n-1 layers of the model.

        Note that `y_true` and `y_pred` must be of rank 2:
        (batch_size, C) for multiclass classification with C classes.

        Args:
            model: A tensorflow multi-layered model.
            K_n_minus_1: The Lipschitz constant of the first n - 1 layers, where n
                is the total number of layers. If not provided, it is computed
                from the model's layers upon instantiation.
            reduction: passed to tf.keras.Loss constructor
            name (str): passed to tf.keras.Loss constructor

        """
        self.model = model

        if K_n_minus_1 is None:
            self.K_n_minus_1 = tf.constant(
                get_K_(get_layers(self.model)[:-1]), dtype=tf.float32
            )
        else:
            self.K_n_minus_1 = tf.constant(K_n_minus_1, dtype=tf.float32)

        self.last_layer_weights = get_weights_last_layer_tensor(self.model)

        super(Certificate_Multiclass, self).__init__(reduction=reduction, name=name)

    @tf.function
    def call(self, y_true, y_pred):
        num_classes = tf.shape(y_true)[1]

        # Class indices of each sample, sorted from highest to lowest logit.
        sorted_indices = get_sorted_logits_indices_tensor(y_pred)

        certificate = tf.zeros(len(y_pred), dtype=tf.float32)

        indices_batch_size = tf.range(len(y_pred), dtype=tf.int32)
        # Start at 1 to skip the top class itself.
        indices_num_classes = tf.range(1, num_classes, dtype=tf.int32)

        for j in indices_batch_size:
            # Initialize with positive infinity
            min_value = tf.constant(float("inf"), dtype=tf.float32)

            for i in indices_num_classes:
                numerator = (
                    y_pred[j][sorted_indices[j][0]] - y_pred[j][sorted_indices[j][i]]
                )

                # Distance between the last-layer weight columns of the top
                # class and of the i-th ranked class.
                denominator = tf.norm(
                    self.last_layer_weights[:, sorted_indices[j][0]]
                    - self.last_layer_weights[:, sorted_indices[j][i]]
                )

                formula_value = numerator / (denominator * self.K_n_minus_1)
                min_value = tf.minimum(min_value, formula_value)

            certificate = tf.tensor_scatter_nd_update(
                certificate,
                indices=tf.expand_dims(tf.expand_dims(j, axis=0), axis=0),
                updates=tf.expand_dims(min_value, axis=0),
            )

        return certificate

    def get_config(self):
        config = {"model": self.model, "K_n_minus_1": self.K_n_minus_1.numpy()}
        base_config = super(Certificate_Multiclass, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
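
To make the multiclass formula concrete, here is a small pure-Python sketch for a single sample. It is an illustration under stated assumptions rather than the PR's code: `multiclass_certificate` and its parameters are hypothetical names, `class_weights[k]` is assumed to hold the last-layer weight vector of class `k` (one column of the Dense kernel), and those vectors are assumed to be pairwise distinct so that no denominator is zero.

```python
import math


def multiclass_certificate(logits, class_weights, k_n_minus_1=1.0):
    """Minimum over non-top classes of margin / (||W_top - W_i|| * K_{n-1}).

    logits:        list of C floats, one per class.
    class_weights: list of C weight vectors (assumed pairwise distinct).
    k_n_minus_1:   Lipschitz constant of all layers before the last one.
    """
    # Index of the predicted (highest-logit) class.
    top = max(range(len(logits)), key=lambda k: logits[k])
    best = math.inf
    for i in range(len(logits)):
        if i == top:
            continue  # the top class is never compared with itself
        margin = logits[top] - logits[i]
        distance = math.dist(class_weights[top], class_weights[i])
        best = min(best, margin / (distance * k_n_minus_1))
    return best


logits = [2.0, 1.0, 0.0]
class_weights = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
# Class 1: margin 1 over distance sqrt(2); class 2: margin 2 over distance 2.
# The smaller ratio, 1 / sqrt(2), is the certificate.
print(multiclass_certificate(logits, class_weights))
```

The class with the smallest logit gap is not always the one that sets the certificate, because each gap is rescaled by the distance between the corresponding weight vectors.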


def get_K_(layers):
    """Returns the product of the Lipschitz coefficients of `layers`."""
    check_is_Lipschitz(layers)

    K_ = 1
    # Multiply the coefficients of the Lipschitz layers; every other layer is
    # assumed to be 1-Lipschitz and leaves the product unchanged.
    for layer in layers:
        if isinstance(layer, LipschitzLayer):
            K_ = layer.k_coef_lip * K_
    return K_


GLOBAL_CONSTANTS = {
    "supported_neutral_layers": ["Flatten", "InputLayer"],  # "KerasTensor"
    # We don't use layer.__class__.__name__ to find these as for Conv2D and
    # GlobalAveragePooling2D, it results in 'type'.
    "not_deel": [
        "dense",
        "average_pooling2d",
        "global_average_pooling2d",
        "conv2d",
    ],
    "not_Lipschitz": [
        "Dropout",
        "ELU",
        "LeakyReLU",
        "ThresholdedReLU",
        "BatchNormalization",
    ],
    "unrecommended_activation_functions": [
        tf.keras.activations.relu,
        tf.keras.activations.softmax,
        tf.keras.activations.exponential,
        tf.keras.activations.elu,
        tf.keras.activations.selu,
        tf.keras.activations.tanh,
        tf.keras.activations.sigmoid,
        tf.keras.activations.softplus,
        tf.keras.activations.softsign,
    ],
    "min_max_norm": tf.keras.constraints.MinMaxNorm,
    "recommended_activation_names": [
        "group_sort2",
        "full_sort",
        "group_sort",
        "householder",
        "max_min",
        "p_re_lu",
    ],
    "unrecommended_activation_names": ["ReLU"],
    "no_activation": tf.keras.activations.linear,
}


def get_weights_last_layer_tensor(model):
    last_layer_weights = get_last_layer(model).get_weights()[0]
    return tf.convert_to_tensor(last_layer_weights, dtype=tf.float32)


def get_sorted_logits_indices_tensor(model_output):
    # Sort the class indices of each sample by decreasing logit
    sorted_indices = tf.argsort(model_output, axis=1, direction="DESCENDING")
    return sorted_indices


def get_layers(model):
    return model.layers


def get_last_layer(model):
    return get_layers(model)[-1]


def check_last_layer(model):
    last_layer = get_last_layer(model)

    # check last layer is a layer with weights
    if not hasattr(last_layer, "get_weights") or last_layer.get_weights() == []:
        raise BadLastLayerError(
            "The last layer '%s' must have a set of weights to calculate the "
            "certificate." % last_layer.name
        )

    # check last layer has no activation function set
    activation = getattr(last_layer, "activation", None)
    if activation is not None and activation != GLOBAL_CONSTANTS["no_activation"]:
        logging.warning(
            "We recommend avoiding using an activation function for the last layer "
            "(here the '%s' activation function of the layer '%s').\n"
            % (activation, last_layer.name)
        )

class NotLipschtzLayerError(Exception):
pass
class BadLastLayerError(Exception):
pass

def check_is_Lipschitz(layers):
    for layer in layers:
        check_activation_layer(layer)
        if layer.__class__.__name__ in GLOBAL_CONSTANTS["supported_neutral_layers"]:
            pass
        elif isinstance(layer, LipschitzLayer):
            pass
        elif layer.__class__.__name__ in GLOBAL_CONSTANTS["not_Lipschitz"]:
            # triggers when using non-Lipschitz layers such as "batch_normalization"
            raise NotLipschtzLayerError("The layer '%s' is not supported" % layer.name)
        elif any(
            layer.name.startswith(substring)
            for substring in GLOBAL_CONSTANTS["not_deel"]
        ):
            logging.warning(
                "A deel equivalent exists for '%s'. For practical purposes, "
                "we will assume that the layer is 1-Lipschitz." % layer.name
            )
        elif layer.__class__.__name__ in GLOBAL_CONSTANTS["unrecommended_activation_names"]:
            # triggers when using tf.keras.layers.ReLU()
            logging.warning(
                "The layer '%s' is not recommended. For practical purposes, we "
                "recommend using a deel-lip activation layer such as GroupSort2 "
                "instead.\n" % layer.name
            )
        else:
            logging.warning(
                "Unknown layer '%s' used. For practical purposes, "
                "we will assume that the layer is 1-Lipschitz." % layer.name
            )

def check_activation_layer(layer):
    if not hasattr(layer, "activation"):
        return None

    activation = layer.activation
    if activation == GLOBAL_CONSTANTS["no_activation"]:
        return None

    if activation in GLOBAL_CONSTANTS["unrecommended_activation_functions"]:
        # triggers when calling an unrecommended activation function, e.g.
        # lip.layers.SpectralDense(64, activation='selu')(x)
        logging.warning(
            "The '%s' activation function of the layer '%s' is not recommended. "
            "For practical purposes, we recommend using a deel-lip activation "
            "function such as GroupSort2 instead.\n" % (activation, layer.name)
        )
        return None

    if isinstance(activation, GLOBAL_CONSTANTS["min_max_norm"]):
        return None
    elif hasattr(activation, "name"):
        n = activation.name
        if (
            layer.activation.__class__.__name__
            in GLOBAL_CONSTANTS["recommended_activation_names"]
        ):
            return None
        else:
            logging.warning(
                "The '%s' activation function of the layer '%s' is unknown. "
                "We will assume it is 1-Lipschitz.\n" % (n, layer.name)
            )
    else:
        logging.warning(
            "The '%s' activation function of the layer '%s' is unknown.\n"
            % (activation, layer.name)
        )
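
The helper `get_K_` in this file bounds the Lipschitz constant of a stack of layers by multiplying per-layer coefficients, which follows from the fact that the composition of a $K_1$-Lipschitz and a $K_2$-Lipschitz function is $K_1 K_2$-Lipschitz. The sketch below reproduces that accumulation in pure Python. `LayerInfo` and `lipschitz_upper_bound` are hypothetical stand-ins introduced for illustration; they are not deel-lip or Keras types, and a layer without a declared coefficient is assumed to be 1-Lipschitz, as the warnings in this file state.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class LayerInfo:
    """Hypothetical stand-in for a layer; not a deel-lip or Keras class."""

    name: str
    k_coef_lip: Optional[float] = None  # None: no declared coefficient


def lipschitz_upper_bound(layers: Iterable[LayerInfo]) -> float:
    """Multiply the declared coefficients; other layers count as 1-Lipschitz."""
    bound = 1.0
    for layer in layers:
        if layer.k_coef_lip is not None:
            bound *= layer.k_coef_lip
    return bound


layers = [
    LayerInfo("spectral_dense", 2.0),
    LayerInfo("flatten"),  # neutral layer, leaves the bound unchanged
    LayerInfo("spectral_dense_1", 0.5),
]
print(lipschitz_upper_bound(layers))  # 2.0 * 0.5 = 1.0
```

The product is an upper bound rather than the exact constant: if a layer assumed to be 1-Lipschitz is not, the certificates derived from the bound are no longer guaranteed.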
Binary file added docs/assets/noise.PNG
Binary file added docs/assets/pigs.png
3 changes: 2 additions & 1 deletion docs/index.md
@@ -65,7 +65,8 @@ supports tensorflow versions 2.x.

| **Tutorial Name** | Notebook |
| :-------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Getting Started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo0.ipynb) |
| Getting Started 1 - Creating a 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_1.ipynb) |
| Getting Started 2 - Training an adversarially robust 1-Lipschitz neural network | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/Getting_started_2.ipynb) |
| Wasserstein distance estimation on toy example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo1.ipynb) |
| HKR Classifier on toy dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo2.ipynb) |
| HKR classifier on MNIST dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deel-ai/deel-lip/blob/master/docs/notebooks/demo3.ipynb) |