Commit aef6e56

Add onnx utils and export code, tweak padding and conv2d_same for better dynamic export with recent PyTorch

parent 80b247d

6 files changed: +399, −15 lines

onnx_export.py (new file, +86 lines)

""" ONNX export script

Export PyTorch models as ONNX graphs.

This export script originally started as an adaptation of code snippets found at
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
flags are currently required.

Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
caffe2 compatibility, but they produce a model that isn't as fast running on the ONNX runtime.

Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
Please do your research and search the ONNX and PyTorch issue trackers before asking me. Thanks.

Copyright 2020 Ross Wightman
"""
import argparse

import timm
from timm.utils.onnx import onnx_export

parser = argparse.ArgumentParser(description='PyTorch ONNX Export')
parser.add_argument('output', metavar='ONNX_FILE',
                    help='output model filename')
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
                    help='model architecture (default: mobilenetv3_large_100)')
parser.add_argument('--opset', type=int, default=None,
                    help='ONNX opset to use (default: 10)')
parser.add_argument('--keep-init', action='store_true', default=False,
                    help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
parser.add_argument('--aten-fallback', action='store_true', default=False,
                    help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
parser.add_argument('--dynamic-size', action='store_true', default=False,
                    help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.')
parser.add_argument('--check-forward', action='store_true', default=False,
                    help='Do a full check of torch vs onnx forward after export.')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number of classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to checkpoint (default: none)')


def main():
    args = parser.parse_args()

    args.pretrained = True
    if args.checkpoint:
        args.pretrained = False

    print("==> Creating PyTorch {} model".format(args.model))
    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
    # for models using SAME padding
    model = timm.create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        exportable=True,
    )

    onnx_export(
        model,
        args.output,
        opset=args.opset,
        dynamic_size=args.dynamic_size,
        aten_fallback=args.aten_fallback,
        keep_initializers=args.keep_init,
        check_forward=args.check_forward,
    )


if __name__ == '__main__':
    main()
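
A typical invocation is: python onnx_export.py ./mobilenetv3_100.onnx --model mobilenetv3_large_100. The heavy lifting happens in timm.utils.onnx.onnx_export; as a rough sketch, a dynamic-size export reduces to the plain torch.onnx.export API along these lines (illustrative only, not the actual timm implementation; the input0/output0 names and opset 12 are arbitrary choices here):

import torch
import timm

# sketch of a dynamic-size ONNX export; timm.utils.onnx.onnx_export layers
# input-size handling, an optional forward check, and caffe2-compat flags on top
model = timm.create_model('mobilenetv3_large_100', pretrained=True, exportable=True)
model.eval()

example = torch.randn(1, 3, 224, 224)  # NCHW example input used for tracing
torch.onnx.export(
    model,
    example,
    'mobilenetv3_large_100.onnx',
    opset_version=12,
    input_names=['input0'],
    output_names=['output0'],
    # mark batch and spatial dims symbolic so the graph accepts varying sizes
    dynamic_axes={
        'input0': {0: 'batch', 2: 'height', 3: 'width'},
        'output0': {0: 'batch'},
    },
)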

onnx_validate.py (new file, +110 lines)

""" ONNX-runtime validation script

This script was created to verify accuracy and performance of exported ONNX
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
pipeline for a fair comparison against the originals.

Copyright 2020 Ross Wightman
"""
import argparse
import numpy as np
import onnxruntime
from timm.data import create_loader, resolve_data_config, create_dataset
from timm.utils import AverageMeter
import time

parser = argparse.ArgumentParser(description='ONNX Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
                    help='path to onnx model/weights file')
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
                    help='path to output optimized onnx graph')
parser.add_argument('--profile', action='store_true', default=False,
                    help='Enable profiler output.')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
                    help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')


def main():
    args = parser.parse_args()
    args.gpu_id = 0

    # Set graph optimization level
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    if args.profile:
        sess_options.enable_profiling = True
    if args.onnx_output_opt:
        sess_options.optimized_model_filepath = args.onnx_output_opt

    session = onnxruntime.InferenceSession(args.onnx_input, sess_options)

    data_config = resolve_data_config(vars(args))
    loader = create_loader(
        create_dataset('', args.data),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct']
    )

    input_name = session.get_inputs()[0].name

    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    for i, (input, target) in enumerate(loader):
        # run the net and return prediction
        output = session.run([], {input_name: input.data.numpy()})
        output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy_np(output, target.numpy())
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(
                f'Test: [{i}/{len(loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {input.size(0) / batch_time.avg:.3f}/s, '
                f'{100 * batch_time.avg / input.size(0):.3f} ms/sample) \t'
                f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'
            )

    print(f' * Prec@1 {top1.avg:.3f} ({100 - top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100. - top5.avg:.3f})')


def accuracy_np(output, target):
    max_indices = np.argsort(output, axis=1)[:, ::-1]
    top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
    top1 = 100 * np.equal(max_indices[:, 0], target).mean()
    return top1, top5


if __name__ == '__main__':
    main()
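
The accuracy_np helper sorts each sample's class scores in descending order and checks the target index against the top-1 / top-5 predictions. A quick standalone check of that logic on toy scores (the helper is repeated here so the snippet runs on its own):

import numpy as np

def accuracy_np(output, target):
    max_indices = np.argsort(output, axis=1)[:, ::-1]
    top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
    top1 = 100 * np.equal(max_indices[:, 0], target).mean()
    return top1, top5

# two samples, six classes: sample 0 ranks its target 1st, sample 1 ranks it 3rd
scores = np.array([
    [0.1, 0.9, 0.0, 0.0, 0.0, 0.0],  # target 1 is the argmax
    [0.5, 0.1, 0.3, 0.0, 0.4, 0.0],  # target 2 is only third-highest
])
targets = np.array([1, 2])
print(accuracy_np(scores, targets))  # (50.0, 100.0)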

timm/layers/conv2d_same.py (+76, −8)

@@ -7,12 +7,22 @@
 import torch.nn.functional as F
 from typing import Tuple, Optional
 
-from .padding import pad_same, get_padding_value
+from .config import is_exportable, is_scriptable
+from .padding import pad_same, pad_same_arg, get_padding_value
+
+
+_USE_EXPORT_CONV = False
 
 
 def conv2d_same(
-        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
-        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
+        x,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        stride: Tuple[int, int] = (1, 1),
+        padding: Tuple[int, int] = (0, 0),
+        dilation: Tuple[int, int] = (1, 1),
+        groups: int = 1,
+):
     x = pad_same(x, weight.shape[-2:], stride, dilation)
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
 
@@ -21,21 +31,79 @@ class Conv2dSame(nn.Conv2d):
     """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1, bias=True):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            bias=True,
+    ):
         super(Conv2dSame, self).__init__(
-            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+            in_channels, out_channels, kernel_size,
+            stride, 0, dilation, groups, bias,
+        )
+
+    def forward(self, x):
+        return conv2d_same(
+            x, self.weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups,
+        )
+
+
+class Conv2dSameExport(nn.Conv2d):
+    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+
+    NOTE: This does not currently work with torch.jit.script
+    """
+
+    # pylint: disable=unused-argument
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            bias=True,
+    ):
+        super(Conv2dSameExport, self).__init__(
+            in_channels, out_channels, kernel_size,
+            stride, 0, dilation, groups, bias,
+        )
+        self.pad = None
+        self.pad_input_size = (0, 0)
 
     def forward(self, x):
-        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        input_size = x.size()[-2:]
+        if self.pad is None:
+            pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
+            self.pad = nn.ZeroPad2d(pad_arg)
+            self.pad_input_size = input_size
+
+        x = self.pad(x)
+        return F.conv2d(
+            x, self.weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups,
+        )
 
 
 def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
     padding = kwargs.pop('padding', '')
     kwargs.setdefault('bias', False)
     padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
     if is_dynamic:
-        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+        if _USE_EXPORT_CONV and is_exportable():
+            # older PyTorch ver needed this to export same padding reasonably
+            assert not is_scriptable()  # Conv2dSameExport does not work with jit
+            return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
+        else:
+            return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
     else:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
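
The invariant both wrappers provide is TF 'SAME' behavior: output spatial size equals ceil(input / stride) for any input size, with asymmetric padding applied when needed. A quick check of Conv2dSame against that invariant (assuming timm is installed and exposes Conv2dSame from timm.layers, as current versions do):

import torch
from timm.layers import Conv2dSame

conv = Conv2dSame(3, 8, kernel_size=3, stride=2, bias=False)
for size in (224, 225, 299):
    out = conv(torch.randn(1, 3, size, size))
    # TF 'SAME': output dim is ceil(size / stride), independent of input parity
    assert out.shape[-1] == -(-size // 2), out.shape
    print(size, '->', tuple(out.shape[-2:]))  # (112, 112), (113, 113), (150, 150)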

timm/layers/padding.py (+29, −6)

@@ -5,6 +5,7 @@
 import math
 from typing import List, Tuple
 
+import torch
 import torch.nn.functional as F
 
 
@@ -15,21 +16,43 @@ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
 
 
 # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
-def get_same_padding(x: int, k: int, s: int, d: int):
-    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
+    if isinstance(x, torch.Tensor):
+        return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
+    else:
+        return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
 
 
 # Can SAME padding for given args be done statically?
 def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
     return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
 
 
+def pad_same_arg(
+        input_size: List[int],
+        kernel_size: List[int],
+        stride: List[int],
+        dilation: List[int] = (1, 1),
+) -> List[int]:
+    ih, iw = input_size
+    kh, kw = kernel_size
+    pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
+    pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
+    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+
+
 # Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
+def pad_same(
+        x,
+        kernel_size: List[int],
+        stride: List[int],
+        dilation: List[int] = (1, 1),
+        value: float = 0,
+):
     ih, iw = x.size()[-2:]
-    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
-    if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
+    pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
+    pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
+    x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
     return x
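
To make the arithmetic concrete: for one 224-px dimension with kernel 3, stride 2, dilation 1, the total pad is max((ceil(224 / 2) - 1) * 2 + (3 - 1) * 1 + 1 - 224, 0) = 1, split 0 on the leading edge and 1 on the trailing edge. A standalone repeat of the two helpers showing that split (copied from above so it runs without timm):

import math

def get_same_padding(x, kernel_size, stride, dilation):
    # total pad so that conv output size = ceil(x / stride)
    return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)

def pad_same_arg(input_size, kernel_size, stride, dilation=(1, 1)):
    ih, iw = input_size
    kh, kw = kernel_size
    pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
    pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
    # [left, right, top, bottom], the ordering nn.ZeroPad2d expects
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]

print(pad_same_arg((224, 224), (3, 3), (2, 2)))  # [0, 1, 0, 1] -> pad right/bottom only
print(pad_same_arg((224, 224), (3, 3), (1, 1)))  # [1, 1, 1, 1] -> symmetric, could be static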

timm/layers/pos_embed_sincos.py (+3, −1)

@@ -8,6 +8,8 @@
 import torch
 from torch import nn as nn
 
+from .trace_utils import _assert
+
 
 def pixel_freq_bands(
     num_bands: int,
@@ -425,7 +427,7 @@ def __init__(
     def get_embed(self, shape: Optional[List[int]] = None):
         if self.bands is not None:
             # rebuild embeddings every call, use if target shape changes
-            assert shape is not None
+            _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
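
The assert -> _assert swap matters for export: a bare Python assert is dropped during tracing and can break under torch.jit scripting, while a wrapper over torch._assert stays representable in the traced graph. A sketch of what such a shim looks like (the actual timm.layers.trace_utils implementation may differ):

import torch

try:
    # graph/trace-friendly assert available in recent PyTorch
    from torch import _assert
except ImportError:
    # fallback for older PyTorch: a plain assert when running eagerly
    def _assert(condition: bool, message: str):
        assert condition, message

# usage mirrors the diff above: condition first, then the message
shape = [14, 14]
_assert(shape is not None, 'valid shape needed')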
