generated from xinetzone/xyzstyle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
新文件: doc/tutorial/onnx-op/dynamo-export.ipynb
新文件: doc/tutorial/onnx-op/export.ipynb
- Loading branch information
liuxinwei
committed
Apr 23, 2024
1 parent
e9fbc95
commit 9c0124b
Showing
7 changed files
with
664 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ data/ | |
params/ | ||
attack/ | ||
*.pt | ||
*.sarif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# PyTorch 转 ONNX \n", | ||
"\n", | ||
"参考:[PyTorch 转换为 ONNX](https://pytorch.org/tutorials//beginner/onnx/export_simple_model_to_onnx_tutorial.html)\n", | ||
"\n", | ||
"在 PyTorch 2.1 版本中,有两种 ONNX 导出工具。\n", | ||
"\n", | ||
"- {func}`torch.onnx.dynamo_export` 是最新的(仍处于测试阶段)基于 TorchDynamo 技术的导出器,该技术与 PyTorch 2.0 一同发布。\n", | ||
"- {func}`torch.onnx.export` 是基于 TorchScript 后端的,自 PyTorch 1.2.0 以来一直可用。\n", | ||
"\n", | ||
"由于 ONNX 导出器使用 `onnx` 和 `onnxscript` 将 PyTorch 算子转换为 ONNX 算子,需要安装:\n", | ||
"\n", | ||
"```bash\n", | ||
"pip install onnx onnxscript\n", | ||
"```\n", | ||
"\n", | ||
"下面以简单的分类器为例展开。\n", | ||
"\n", | ||
"## 简单的分类器模型导出准备" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"\n", | ||
"\n", | ||
"class MyModel(nn.Module):\n", | ||
"\n", | ||
" def __init__(self):\n", | ||
" super(MyModel, self).__init__()\n", | ||
" self.conv1 = nn.Conv2d(1, 6, 5)\n", | ||
" self.conv2 = nn.Conv2d(6, 16, 5)\n", | ||
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", | ||
" self.fc2 = nn.Linear(120, 84)\n", | ||
" self.fc3 = nn.Linear(84, 10)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", | ||
" x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n", | ||
" x = torch.flatten(x, 1)\n", | ||
" x = F.relu(self.fc1(x))\n", | ||
" x = F.relu(self.fc2(x))\n", | ||
" x = self.fc3(x)\n", | ||
" return x" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 将模型导出为 ONNX 格式\n", | ||
"\n", | ||
"实例化模型并创建随机的 32x32 输入。接下来,可以将模型导出为 ONNX 格式。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/media/pc/data/tmp/cache/conda/envs/py311/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:130: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"torch_model = MyModel()\n", | ||
"torch_input = torch.randn(1, 1, 32, 32)\n", | ||
"onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"不需要对模型进行任何代码更改。生成的 ONNX 模型存储在二进制 protobuf 文件 `torch.onnx.ONNXProgram` 中。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 将 ONNX 模型保存到文件中\n", | ||
"\n", | ||
"尽管在许多应用中将导出的模型加载到内存中是有用的,但我们可以将其保存到磁盘上,代码如下:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"onnx_program.save(\"my_image_classifier.onnx\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"您可以将 ONNX 文件重新加载到内存中,并使用以下代码检查其格式是否正确:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import onnx\n", | ||
"onnx_model = onnx.load(\"my_image_classifier.onnx\")\n", | ||
"onnx.checker.check_model(onnx_model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 使用 ONNX Runtime 执行 ONNX 模型\n", | ||
"\n", | ||
"最后一步是使用 ONNX Runtime 执行 ONNX 模型,但在我们这样做之前,让我们先安装 ONNX Runtime。\n", | ||
"\n", | ||
"```bash\n", | ||
"pip install onnxruntime\n", | ||
"```\n", | ||
"\n", | ||
"ONNX 标准不支持 PyTorch 支持的所有数据结构和类型,所以需要在将输入喂给 ONNX Runtime 之前,将 PyTorch 的输入适配为 ONNX 格式。在我们的示例中,输入恰好是相同的,但在更复杂的模型中,它可能比原始的 PyTorch 模型有更多的输入。\n", | ||
"\n", | ||
"ONNX Runtime 需要额外的步骤,该步骤涉及将所有 PyTorch 张量转换为 Numpy(在 CPU 上),并在字典中包装它们,其中键是字符串,表示输入名称,值为 `numpy` 张量。\n", | ||
"\n", | ||
"现在我们可以创建 ONNX Runtime 推理会话,使用处理过的输入执行 ONNX 模型并获取输出。在这个教程中,ONNX Runtime 是在 CPU 上执行的,但它也可以在 GPU上 执行。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Input length: 1\n", | ||
"Sample input: (tensor([[[[-0.5305, -0.6818, 2.2350, ..., -0.2503, 0.4694, 1.3666],\n", | ||
" [ 0.7013, 0.0179, -1.2689, ..., 0.4369, 0.5982, -0.6541],\n", | ||
" [ 0.8644, 0.8552, 0.4100, ..., -0.8513, 0.4207, 0.4363],\n", | ||
" ...,\n", | ||
" [ 0.4400, -0.3064, -1.9848, ..., 0.0462, 0.7269, 1.3543],\n", | ||
" [ 1.5511, -0.6354, 0.9151, ..., 0.2501, -0.0140, -0.3875],\n", | ||
" [-1.2229, -0.8693, 1.0505, ..., 0.0598, 0.7852, 0.1350]]]]),)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import onnxruntime\n", | ||
"\n", | ||
"onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input)\n", | ||
"print(f\"Input length: {len(onnx_input)}\")\n", | ||
"print(f\"Sample input: {onnx_input}\")\n", | ||
"\n", | ||
"ort_session = onnxruntime.InferenceSession(\"./my_image_classifier.onnx\", providers=['CPUExecutionProvider'])\n", | ||
"\n", | ||
"def to_numpy(tensor):\n", | ||
" return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n", | ||
"\n", | ||
"onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}\n", | ||
"\n", | ||
"onnxruntime_outputs = ort_session.run(None, onnxruntime_input)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 将PyTorch的结果与ONNX Runtime的结果进行比较\n", | ||
"\n", | ||
"确定导出模型是否良好的最佳方式是通过与 PyTorch 的数值评估,这是我们的真实来源。\n", | ||
"\n", | ||
"为此,我们需要使用相同的输入执行 PyTorch 模型,并将结果与 ONNX Runtime 的结果进行比较。在比较结果之前,我们需要将 PyTorch 的输出转换为匹配 ONNX 的格式。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"PyTorch and ONNX Runtime output matched!\n", | ||
"Output length: 1\n", | ||
"Sample output: [array([[-0.02100155, -0.13608684, -0.14742026, -0.04622332, -0.01618233,\n", | ||
" -0.07353653, -0.11702952, -0.02780916, 0.09021657, 0.02800114]],\n", | ||
" dtype=float32)]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"torch_outputs = torch_model(torch_input)\n", | ||
"torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)\n", | ||
"\n", | ||
"assert len(torch_outputs) == len(onnxruntime_outputs)\n", | ||
"for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):\n", | ||
" torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))\n", | ||
"\n", | ||
"print(\"PyTorch and ONNX Runtime output matched!\")\n", | ||
"print(f\"Output length: {len(onnxruntime_outputs)}\")\n", | ||
"print(f\"Sample output: {onnxruntime_outputs}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "xin", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.