Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add onnx tutorial #990

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb).
- Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb)
- Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)).
- Export trained model to ONNX - [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)

### 📦 Models <a name="models"></a>

Expand Down
223 changes: 223 additions & 0 deletions examples/convert_to_onnx.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# to make onnx export work\n",
"!pip install onnx onnxruntime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See complete tutorial in Pytorch docs:\n",
" - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"import onnxruntime\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import segmentation_models_pytorch as smp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create random model (or load your own model)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n",
"model = model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Export the model to ONNX"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# dynamic_axes is used to specify the variable length axes. it can be just batch size\n",
"dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n",
"\n",
"onnx_model_name = \"unet_resnet34.onnx\"\n",
"\n",
"onnx_model = torch.onnx.export(\n",
" model, # model being run\n",
" torch.randn(1, 3, 224, 224), # model input\n",
" onnx_model_name, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=17, # the ONNX version to export\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=[\"input\"], # the model's input names\n",
" output_names=[\"output\"], # the model's output names\n",
" dynamic_axes={ # variable length axes\n",
" \"input\": dynamic_axes,\n",
" \"output\": dynamic_axes,\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# check with onnx first\n",
"onnx_model = onnx.load(onnx_model_name)\n",
"onnx.checker.check_model(onnx_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run with onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n",
" 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n",
" [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n",
" 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n",
" [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n",
" -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n",
" ...,\n",
" [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n",
" -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n",
" [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n",
" -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n",
" [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n",
" -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n",
" \n",
" \n",
" [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n",
" 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n",
" [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n",
" 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n",
" [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n",
" -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n",
" ...,\n",
" [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n",
" -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n",
" [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n",
" 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n",
" [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n",
" -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n",
" dtype=float32)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create sample with different batch size, height and width\n",
"# from what we used in export above\n",
"sample = torch.randn(2, 3, 512, 512)\n",
"\n",
"ort_session = onnxruntime.InferenceSession(\n",
" onnx_model_name, providers=[\"CPUExecutionProvider\"]\n",
")\n",
"\n",
"# compute ONNX Runtime output prediction\n",
"ort_inputs = {\"input\": sample.numpy()}\n",
"ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n",
"ort_outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verify it's the same as for pytorch model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exported model has been tested with ONNXRuntime, and the result looks good!\n"
]
}
],
"source": [
"# compute PyTorch output prediction\n",
"with torch.no_grad():\n",
" torch_out = model(sample)\n",
"\n",
"# compare ONNX Runtime and PyTorch results\n",
"np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n",
"\n",
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 3 additions & 1 deletion segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

# Suppress the specific SyntaxWarning for `pretrainedmodels`
warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning)
warnings.filterwarnings("ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning) # for python >= 3.12
warnings.filterwarnings(
"ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning
) # for python >= 3.12


def create_model(
Expand Down
3 changes: 2 additions & 1 deletion segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

self.check_input_shape(x)
if not torch.jit.is_tracing():
self.check_input_shape(x)

features = self.encoder(x)
decoder_output = self.decoder(*features)
Expand Down