Trying to convert pth to onnx, but getting an error #898

Open
caizhaoxin opened this issue Jun 23, 2022 · 1 comment


@caizhaoxin

I want to convert a .pth checkpoint to ONNX format. This is my code:

import torch
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
import numpy as np
from torch.autograd import Variable

def load_model(model, pretrained_path):
    print('Loading pretrained model from {}'.format(pretrained_path))
    pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(pretrained_dict, strict=False)
    return model

output_onnx = './output.onnx'
raw_weights = './faster_rcnn_1_10_2504.pth'
pascal_classes = np.asarray(['background',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor'])

# load weights
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
for k in checkpoint.keys():
    print(k)
net.load_state_dict(checkpoint['model'])

# initialize the tensor holders here
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)

# ship to cuda
im_data = im_data.cuda()
im_info = im_info.cuda()
num_boxes = num_boxes.cuda()
gt_boxes = gt_boxes.cuda()

# make variable
im_data = Variable(im_data, volatile=True)
im_info = Variable(im_info, volatile=True)
num_boxes = Variable(num_boxes, volatile=True)
gt_boxes = Variable(gt_boxes, volatile=True)

net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)

input_names = ["input0"]
output_names = ["output0"]
inputs = torch.randn(1, 3, 300, 300).to(device)

# output model
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)

But when I run it, I get this error. How can I fix it? Thanks!

Traceback (most recent call last):
  File "pth2onnx.py", line 67, in <module>
    torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 276, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 701, in _export
    dynamic_axes=dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 459, in _model_to_graph
    use_new_jit_passes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 130, in forward
    self._force_outplace,
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_trace.py", line 116, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 3 required positional arguments: 'im_info', 'gt_boxes', and 'num_boxes'
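
A note on the last line of the traceback: the tracer ends up calling forward(*args), and the script hands torch.onnx.export a single tensor, so forward() only receives im_data. The sketch below reproduces that call pattern in plain Python, without PyTorch. FakeDetector and export_like are hypothetical stand-ins written for this illustration, not part of any library, and the argument order is assumed from the error message.

```python
# Pure-Python sketch of the call that fails inside the tracer; no PyTorch needed.
# FakeDetector and export_like are hypothetical stand-ins for illustration only.


class FakeDetector:
    """Stand-in with the same forward() arity that the TypeError reports."""

    def forward(self, im_data, im_info, gt_boxes, num_boxes):
        return (im_data, im_info, gt_boxes, num_boxes)


def export_like(model, args):
    """Mimic how the tracer invokes the model: forward(*args)."""
    if not isinstance(args, tuple):
        args = (args,)  # a lone input becomes a one-element tuple
    return model.forward(*args)


net = FakeDetector()

# One input, as in the script above: three positional arguments are missing.
try:
    export_like(net, "im_data")
    single_input_error = None
except TypeError as exc:
    single_input_error = str(exc)

# All four inputs packed into one tuple, in forward() order: the call goes through.
result = export_like(net, ("im_data", "im_info", "gt_boxes", "num_boxes"))
```

With the real model, the analogous change would probably be to pass a tuple such as (im_data, im_info, gt_boxes, num_boxes) as the second argument to torch.onnx.export, with each tensor given a realistic shape instead of the size-1 placeholders. That should only get past this particular TypeError. Whether the rest of the network traces and exports cleanly is not something the traceback shows.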

@sisinote

Did you finish the ONNX conversion?
