# deploy.py
# Export a PyTorch checkpoint to ONNX, simplify it, and optionally convert it to a blob.
import blobconverter
import onnx
import torch
import argparse
from onnxsim import simplify
def deploy(model_path: str, new_model: str):
    """
    Export a PyTorch checkpoint to ONNX and write a simplified copy of it.

    Two files are written, relative to the current working directory:
    '<new_model>.onnx' (the raw export) and '<new_model>-sim.onnx'
    (the result of running onnx-simplifier on the raw export).

    :param model_path: the path to the model checkpoint, such as 'checkpoints/checkpoint_ssd300.pt'
    :param new_model: the name of the ONNX model, such as 'ssd300'
    :return: None
    :raises AssertionError: if the simplified ONNX model could not be validated
    """
    # Use the GPU when there is one, otherwise fall back to the CPU so the
    # export also works on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model checkpoint
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases may default to weights_only=True,
    # which can reject a checkpoint that stores a whole model object --
    # confirm against the installed version.
    checkpoint = torch.load(model_path, map_location=device)
    model = checkpoint['model']
    model.eval()

    onnx_path = new_model + '.onnx'
    simplified_path = new_model + '-sim' + '.onnx'

    # Export to ONNX
    input_names = ['input']
    output_names = ['boxes', 'scores']
    # Random 1x3x300x300 tensor handed to the exporter as the example input.
    dummy_input = torch.randn(1, 3, 300, 300).to(device)
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
                      input_names=input_names, output_names=output_names)

    # Simplify ONNX
    # Simplify the file that was just exported; the path must follow
    # `new_model` rather than a fixed 'ssd300.onnx'.
    simple_model, check = simplify(onnx_path)
    # Explicit raise (instead of `assert`) so the check is not stripped when
    # Python runs with -O; the exception type is kept for existing callers.
    if not check:
        raise AssertionError("Simplified ONNX model could not be validated")
    onnx.save(simple_model, simplified_path)
# Optional: Deploy model to Blob
def deploy_blob(model_name: str, output_dir: str):
    """
    Convert the simplified ONNX model to a blob with blobconverter.

    :param model_name: the name of the ONNX model, such as 'ssd300'; the file
        '<model_name>-sim.onnx' is the one that gets converted
    :param output_dir: the path to the output directory, such as 'models'
    :return: None
    """
    print('Deploying model to Blob...')

    # File name of the simplified ONNX model for this model name.
    simplified_onnx = ''.join((model_name, '-sim', '.onnx'))

    # Conversion settings forwarded unchanged to blobconverter.
    conversion_options = {
        'output_dir': output_dir,
        'data_type': 'FP16',
        'use_cache': True,
        'shaves': 6,
    }
    blobconverter.from_onnx(model=simplified_onnx, **conversion_options)
def str2bool(value: str) -> bool:
    """
    Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string as True, so
    ``--deploy_blob False`` would still enable the option; this converter
    interprets the text instead.

    :param value: the raw command-line string (a bool is passed through unchanged)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Deploy model to OpenVINO')
    parser.add_argument('--model', type=str, default="checkpoints/checkpoint_ssd300.pt",
                        help='the path to the model checkpoint')
    parser.add_argument('--new_model', default="ssd300", type=str,
                        help='the name of the ONNX model')
    # str2bool rather than bool: bool('False') is True, which made it
    # impossible to switch the blob step off from the command line.
    parser.add_argument('--deploy_blob', default=True, type=str2bool,
                        help='deploy model to Blob')
    args = parser.parse_args()

    # Always export to ONNX; the blob conversion is optional.
    deploy(args.model, args.new_model)
    if args.deploy_blob:
        deploy_blob(args.new_model, 'models')