-
Notifications
You must be signed in to change notification settings - Fork 143
/
export_model.py
72 lines (58 loc) · 3 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from .sparse_conv import SparseConvONNX, SparseConvTransposeONNX
def export(num_inp_points, num_out_points, max_grid_extent,
           in_channels, filters, kernel_size, transpose,
           model_path='model.onnx'):
    """Export a randomly initialised sparse convolution to ONNX with test data.

    The NumPy and PyTorch global RNGs are re-seeded with fixed values on every
    call, so repeated calls with the same arguments generate the same
    positions, features and kernel.

    Args:
        num_inp_points: Number of input positions to draw. Duplicate positions
            are dropped after drawing, so the actual count may be smaller.
        num_out_points: Number of output positions to draw (the same
            de-duplication applies). When falsy (``None`` or ``0``) the input
            positions are reused as the output positions.
        max_grid_extent: Exclusive upper bound of the integer grid coordinate
            drawn for each of the 3 axes.
        in_channels: Number of feature channels generated per input point;
            also forwarded to the layer constructor.
        filters: Forwarded to the layer constructor as ``filters``.
        kernel_size: Forwarded to the layer constructor as ``kernel_size``.
        transpose: If true, build ``SparseConvTransposeONNX``; otherwise build
            ``SparseConvONNX``.
        model_path: Destination of the exported ONNX file. Defaults to
            ``'model.onnx'`` (the previously hard-coded location), so existing
            callers are unaffected.

    Returns:
        A tuple ``(inputs, ref)``. ``inputs`` is the list
        ``[features, inp_pos, out_pos]`` of NumPy arrays that were fed to the
        layer (the scalar voxel size, fixed to 1.0, is not included). ``ref``
        is the layer's PyTorch output converted to a NumPy array.
    """
    # Fixed seeds keep the generated model and reference data reproducible.
    np.random.seed(324)
    torch.manual_seed(32)

    layer_cls = SparseConvTransposeONNX if transpose else SparseConvONNX
    sparse_conv = layer_cls(in_channels=in_channels,
                            filters=filters,
                            kernel_size=kernel_size,
                            use_bias=False,
                            normalize=False)

    def gen_pos(num_points):
        """Draw unique integer grid cells and add a random fraction in [0, 1)."""
        grid = np.random.randint(0, max_grid_extent, [num_points, 3])
        # np.unique(axis=0) removes repeated rows, so fewer than num_points
        # positions may remain.
        grid = np.unique(grid, axis=0).astype(np.float32)
        return torch.tensor(grid) + torch.rand(grid.shape, dtype=torch.float32)

    inp_pos = gen_pos(num_inp_points)
    out_pos = gen_pos(num_out_points) if num_out_points else inp_pos

    # One feature row per input position that survived de-duplication.
    features = torch.randn([inp_pos.shape[0], in_channels])
    voxel_size = torch.tensor(1.0)

    sparse_conv.eval()

    # Overwrite the kernel with random normal values; the existing "offset"
    # entry is passed back unchanged.
    state = sparse_conv.state_dict()
    new_kernel = torch.randn(state["kernel"].shape)
    sparse_conv.load_state_dict({"kernel": new_kernel,
                                 "offset": state["offset"]})

    with torch.no_grad():
        torch.onnx.export(sparse_conv, (features, inp_pos, out_pos, voxel_size), model_path,
                          input_names=['input', 'input1', 'input2', 'voxel_size'],
                          output_names=['output'],
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

    # Reference result from PyTorch for comparison against the exported model.
    ref = sparse_conv(features, inp_pos, out_pos, voxel_size)

    return [features.detach().numpy(), inp_pos.detach().numpy(),
            out_pos.detach().numpy()], ref.detach().numpy()
if __name__ == "__main__":
    # Command-line entry point: parse the layer / test-data configuration and
    # run export() with it.
    parser = argparse.ArgumentParser(description='Generate ONNX model and test data')
    # These options are marked required because export() uses them directly as
    # numpy/torch shape arguments or hands them to the layer constructor; an
    # omitted flag would otherwise arrive there as None instead of producing a
    # usage message here.
    # NOTE(review): the layer constructor is not visible from this file - the
    # assumption is that it cannot work with filters/kernel_size of None.
    parser.add_argument('--num_inp_points', type=int, required=True,
                        help='number of input points to generate')
    # Optional: export() reuses the input positions when this is not given.
    parser.add_argument('--num_out_points', type=int,
                        help='number of output points to generate; '
                             'the input points are reused when omitted')
    parser.add_argument('--max_grid_extent', type=int, required=True,
                        help='exclusive upper bound of the generated grid coordinates')
    parser.add_argument('--in_channels', type=int, required=True,
                        help='number of input feature channels')
    parser.add_argument('--filters', type=int, required=True,
                        help='value passed to the layer as "filters"')
    parser.add_argument('--kernel_size', type=int, nargs='+', required=True,
                        help='one or more integers passed to the layer as "kernel_size"')
    parser.add_argument('--transpose', action='store_true',
                        help='export the transposed sparse convolution')
    args = parser.parse_args()

    export(args.num_inp_points, args.num_out_points, args.max_grid_extent,
           args.in_channels, args.filters, args.kernel_size, args.transpose)