I'm trying to convert a PyTorch model to a TFLite model, but I'm having trouble providing the dummy model input to the `torch.onnx.export()` method. Do you know what the dummy input should be?

Here is what my code looks like:
```python
import torch.nn as nn
import torch.onnx
import torchvision
import torch
from onmt.model_builder import build_base_model
from onmt.utils.parse import ArgumentParser

# Load the OpenNMT checkpoint onto the CPU so this also works on a machine without a GPU
checkpoint = torch.load('model_step_1600.pt', map_location='cpu')
model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
ArgumentParser.update_model_opts(model_opt)
ArgumentParser.validate_model_opts(model_opt)
vocab = checkpoint['vocab']
fields = vocab
model = build_base_model(model_opt, fields, None, checkpoint)
model.eval()

# This is the part in question: an image-shaped (N, C, H, W) float tensor,
# which a translation model (expecting token indices) most likely cannot consume.
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
# dummy_input = torch.from_numpy(X_test[0].reshape(1, -1)).float().to(device)
torch.onnx.export(model, dummy_input, 'model_simple.onnx')
```
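A likely source of the problem: `torch.onnx.export(model, args, f)` executes the model with `args`, passing a tuple positionally to `forward`, so `args` has to match the model's `forward` signature. An OpenNMT-py translation model's `forward` takes several arguments (in the 2.x releases, roughly `forward(src, tgt, lengths, ...)` with `src` documented as a padded `LongTensor` of size `(len, batch, features)`), not one image-shaped float tensor. The sketch below is only an illustration under those assumptions. `ToySeq2Seq`, `forward_arg_names`, `make_dummy_inputs` and `export_to_onnx` are hypothetical names, not OpenNMT-py APIs, and the sketch follows this pattern: inspect the forward signature, build integer dummy tensors in the assumed shapes, check that one forward pass runs, then hand the tuple to `torch.onnx.export`. The argument list and tensor layout differ between OpenNMT-py versions, so check `NMTModel.forward` in your installed version.

```python
import inspect

import torch
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    """Hypothetical stand-in with an OpenNMT-style forward(src, tgt, lengths)."""

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.out = nn.Linear(hidden, vocab_size)

    def forward(self, src, tgt, lengths):
        # Assumed layout: src (src_len, batch, 1), tgt (tgt_len, batch, 1), lengths (batch,)
        src_emb = self.embed(src.squeeze(-1))  # (src_len, batch, hidden)
        # Toy "encoder": sum over time divided by the given lengths
        # (padding is not masked, to keep the sketch short).
        memory = src_emb.sum(dim=0) / lengths.unsqueeze(-1).to(src_emb.dtype)
        dec = self.embed(tgt.squeeze(-1)) + memory  # broadcast over tgt_len
        return self.out(dec)  # (tgt_len, batch, vocab_size)


def forward_arg_names(model):
    """List the parameter names of model.forward (what the dummy tuple must cover)."""
    return list(inspect.signature(model.forward).parameters)


def make_dummy_inputs(src_len=7, tgt_len=5, batch=1, vocab_size=100):
    # Token indices must be integers (LongTensor), not float noise from torch.randn.
    src = torch.randint(0, vocab_size, (src_len, batch, 1), dtype=torch.long)
    tgt = torch.randint(0, vocab_size, (tgt_len, batch, 1), dtype=torch.long)
    lengths = torch.full((batch,), src_len, dtype=torch.long)
    return src, tgt, lengths


def export_to_onnx(model, dummy_inputs, path="model_simple.onnx"):
    model.eval()
    with torch.no_grad():
        model(*dummy_inputs)  # fails fast if the tuple does not match forward()
    # The tuple is unpacked into positional arguments of model.forward during export.
    torch.onnx.export(model, dummy_inputs, path,
                      input_names=["src", "tgt", "lengths"],
                      output_names=["logits"])
```

For the real checkpoint you would first print `forward_arg_names(model)`, set `vocab_size` no larger than the source and target vocabulary sizes, and then call `export_to_onnx(model, make_dummy_inputs(...))`. Running one forward pass before exporting gives a clearer error than the exporter does when the inputs do not match.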