diff --git a/README.md b/README.md index 616871ec..c26ff4ba 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ Congratulations! You are done! Now you can train your model with your favorite f - Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb). - Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) - Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)). + - Export trained model to ONNX - [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) ### 📦 Models diff --git a/examples/convert_to_onnx.ipynb b/examples/convert_to_onnx.ipynb new file mode 100644 index 00000000..abd063a0 --- /dev/null +++ b/examples/convert_to_onnx.ipynb @@ -0,0 +1,223 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# to make onnx export work\n", + "!pip install onnx onnxruntime" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See complete tutorial in Pytorch docs:\n", + " - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import onnx\n", + "import onnxruntime\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import segmentation_models_pytorch as smp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create random model (or load your own model)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n", + "model = model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Export the model to ONNX" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# dynamic_axes is used to specify the variable length axes. it can be just batch size\n", + "dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n", + "\n", + "onnx_model_name = \"unet_resnet34.onnx\"\n", + "\n", + "onnx_model = torch.onnx.export(\n", + " model, # model being run\n", + " torch.randn(1, 3, 224, 224), # model input\n", + " onnx_model_name, # where to save the model (can be a file or file-like object)\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=17, # the ONNX version to export\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names=[\"input\"], # the model's input names\n", + " output_names=[\"output\"], # the model's output names\n", + " dynamic_axes={ # variable length axes\n", + " \"input\": dynamic_axes,\n", + " \"output\": dynamic_axes,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# check with onnx first\n", + "onnx_model = onnx.load(onnx_model_name)\n", + "onnx.checker.check_model(onnx_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run with onnxruntime" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n", + " 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n", + " [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n", + " 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n", + " [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n", + " -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n", + " ...,\n", + " [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n", + " -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n", + " [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n", + " -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n", + " [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n", + " -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n", + " \n", + " \n", + " [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n", + " 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n", + " [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n", + " 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n", + " [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n", + " -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n", + " ...,\n", + " [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n", + " -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n", + " [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n", + " 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n", + " [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n", + " -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n", + " dtype=float32)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create sample with different batch size, height and width\n", + "# from what we used in export above\n", + "sample = torch.randn(2, 3, 512, 512)\n", + "\n", + "ort_session = onnxruntime.InferenceSession(\n", + " onnx_model_name, providers=[\"CPUExecutionProvider\"]\n", + ")\n", + "\n", + "# compute ONNX Runtime output prediction\n", + "ort_inputs = {\"input\": sample.numpy()}\n", + "ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n", + "ort_outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Verify it's the same as for pytorch model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exported model has been tested with ONNXRuntime, and the result looks good!\n" + ] + } + ], + "source": [ + "# compute PyTorch output prediction\n", + "with torch.no_grad():\n", + " torch_out = model(sample)\n", + "\n", + "# compare ONNX Runtime and PyTorch results\n", + "np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n", + "\n", + "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index c8209226..5cde6004 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -26,7 +26,9 @@ # Suppress the specific SyntaxWarning for `pretrainedmodels` warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning) -warnings.filterwarnings("ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning) # for python >= 3.12 +warnings.filterwarnings( + "ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning +) # for python >= 3.12 def create_model( diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 1b4b3d61..29c7dd2a 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -33,7 +33,8 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - self.check_input_shape(x) + if not torch.jit.is_tracing(): + self.check_input_shape(x) features = self.encoder(x) decoder_output = self.decoder(*features)