Skip to content

Commit

Permalink
Merge pull request #7 from Sharing-Sam-Work/getting-started
Browse files Browse the repository at this point in the history
Getting started
  • Loading branch information
Sharing-Sam-Work authored Aug 9, 2023
2 parents 29c0108 + 2a7d956 commit c625d2e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 27 deletions.
2 changes: 2 additions & 0 deletions deel/lip/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from keras.utils import conv_utils # in Keras for TF >= 2.6
except ModuleNotFoundError:
from tensorflow.python.keras.utils import conv_utils # in TF.python for TF <= 2.5
except ImportError:
from tensorflow.python.keras.utils import conv_utils # in TF.python for TF <= 2.5


def _compute_conv_lip_factor(kernel_size, strides, input_shape, data_format):
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/Getting_started_1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
"id": "de38b069-1705-408c-8dd3-99ed17cf519f",
"metadata": {},
"source": [
"- specify the Lipschitz constant of each layer of the model through the `k_coef_lip` attribute of the `Sequential` object, e.g.:"
"- specify the Lipschitz constant of the whole model through the `k_coef_lip` attribute of the `Sequential` object, e.g.:"
]
},
{
Expand All @@ -261,7 +261,7 @@
"K1_model = lip.model.Sequential([ \n",
" ....\n",
" ],\n",
" # This parameter sets the Lipschitz constant of each layer. Its value is 1 by default.\n",
" # This parameter sets the Lipschitz constant of the whole model. Its value is 1 by default.\n",
" k_coef_lip=1,\n",
")"
]
Expand Down
50 changes: 25 additions & 25 deletions docs/notebooks/Getting_started_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "d3cf25c6-1691-4935-bdac-9f182a87ef32",
"metadata": {},
"outputs": [],
Expand All @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "90da93e3-9c59-4e72-b817-01f97b324c51",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -165,7 +165,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "359fa47e-86f3-440d-971f-0a49fe7dcdc7",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -417,16 +417,16 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 13,
"id": "a964983d-4347-4f57-8489-98e6a05cfd0c",
"metadata": {},
"outputs": [],
"source": [
"# performance-oriented model\n",
"model_3 = create_conv_model(\"HKR_model_3\")\n",
"\n",
"min_margin_3=1\n",
"alpha_3=30\n",
"min_margin_3=0.1\n",
"alpha_3=50\n",
"\n",
"model_3.compile(\n",
" loss=lip.losses.MulticlassHKR(min_margin=min_margin_3,alpha=alpha_3),\n",
Expand All @@ -437,16 +437,16 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 14,
"id": "5b57e8ad-19fc-4791-8e28-d0e613cbddc5",
"metadata": {},
"outputs": [],
"source": [
"# robustness-oriented model\n",
"model_4 = create_conv_model(\"HKR_model_4\")\n",
"\n",
"min_margin_4=3\n",
"alpha_4=10\n",
"min_margin_4=1\n",
"alpha_4=30\n",
"\n",
"model_4.compile(\n",
" loss=lip.losses.MulticlassHKR(min_margin=min_margin_4,alpha=alpha_4),\n",
Expand All @@ -465,7 +465,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 15,
"id": "b976968d-4494-4935-8338-06ebd99b6c78",
"metadata": {},
"outputs": [
Expand All @@ -474,9 +474,9 @@
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"235/235 [==============================] - 26s 96ms/step - loss: 10.3240 - accuracy: 0.7246 - MulticlassKR: 0.8375 - val_loss: 3.9340 - val_accuracy: 0.8969 - val_MulticlassKR: 1.3823\n",
"235/235 [==============================] - 23s 91ms/step - loss: 0.8637 - accuracy: 0.8007 - MulticlassKR: 0.2155 - val_loss: 0.1633 - val_accuracy: 0.9228 - val_MulticlassKR: 0.3141\n",
"Epoch 2/2\n",
"235/235 [==============================] - 22s 92ms/step - loss: 2.9818 - accuracy: 0.9019 - MulticlassKR: 1.5695 - val_loss: 1.9182 - val_accuracy: 0.9244 - val_MulticlassKR: 1.7646\n"
"235/235 [==============================] - 21s 90ms/step - loss: 0.0387 - accuracy: 0.9298 - MulticlassKR: 0.3854 - val_loss: -0.1238 - val_accuracy: 0.9444 - val_MulticlassKR: 0.4765\n"
]
}
],
Expand All @@ -495,7 +495,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 16,
"id": "8f19da69-9869-463e-8fac-b7f179e00315",
"metadata": {},
"outputs": [
Expand All @@ -504,9 +504,9 @@
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"235/235 [==============================] - 25s 95ms/step - loss: 16.5619 - accuracy: 0.5336 - MulticlassKR: 1.1116 - val_loss: 8.8919 - val_accuracy: 0.7831 - val_MulticlassKR: 2.0187\n",
"235/235 [==============================] - 23s 89ms/step - loss: 10.4944 - accuracy: 0.7172 - MulticlassKR: 0.8307 - val_loss: 3.8939 - val_accuracy: 0.8901 - val_MulticlassKR: 1.3772\n",
"Epoch 2/2\n",
"235/235 [==============================] - 21s 91ms/step - loss: 7.6997 - accuracy: 0.7989 - MulticlassKR: 2.2113 - val_loss: 6.5351 - val_accuracy: 0.8436 - val_MulticlassKR: 2.4005\n"
"235/235 [==============================] - 21s 89ms/step - loss: 2.9442 - accuracy: 0.8982 - MulticlassKR: 1.5552 - val_loss: 1.9326 - val_accuracy: 0.9199 - val_MulticlassKR: 1.7427\n"
]
}
],
Expand All @@ -525,18 +525,18 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 17,
"id": "5bfe7641-f7af-4ac3-9ff4-fca3300448ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model accuracy: 0.9244\n",
"Model MulticlassKR: 1.7646\n",
"Loss' minimum margin: 1.0\n",
"Loss' alpha: 30.0\n"
"Model accuracy: 0.9444\n",
"Model MulticlassKR: 0.4765\n",
"Loss' minimum margin: 0.1\n",
"Loss' alpha: 50.0\n"
]
}
],
Expand All @@ -550,18 +550,18 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 18,
"id": "7ce20cb8-797b-4e03-8600-d0030adfccc1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model accuracy: 0.8436\n",
"Model MulticlassKR: 2.4005\n",
"Loss' minimum margin: 3.0\n",
"Loss' alpha: 10.0\n"
"Model accuracy: 0.9199\n",
"Model MulticlassKR: 1.7427\n",
"Loss' minimum margin: 1.0\n",
"Loss' alpha: 30.0\n"
]
}
],
Expand Down

0 comments on commit c625d2e

Please sign in to comment.