
Different outputs between pytorch, onnx and coreml #576

Open
marc-jv opened this issue Jun 16, 2020 · 1 comment
Labels
bug Unexpected behaviour that should be corrected (type)

Comments


marc-jv commented Jun 16, 2020

🐞Describe the bug

I'm training a neural network consisting of an embedding layer, a linear layer, a GRU, followed by another linear layer and a log-softmax. It trains properly and exports to ONNX and CoreML without errors. However, when I test the three models (.pth, .onnx and .mlmodel) by computing a forward step with a dummy input, the PyTorch and ONNX models yield the same output, but the CoreML model does not.
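
For reference, here is a minimal sketch of the architecture as described; the layer sizes and exact wiring are illustrative assumptions, not the real model (which I can't share):

import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
class DescribedNet(nn.Module):
    def __init__(self, vocab_size=60, emb_dim=10, hidden_size=8, out_size=100):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.fc_in = nn.Linear(emb_dim, emb_dim)  # the intermediate linear layer
        self.gru = nn.GRU(emb_dim, hidden_size)   # batch_first=False by default
        self.fc_out = nn.Linear(hidden_size, out_size)

    def forward(self, inputs, hidden):
        x = self.fc_in(self.embeddings(inputs))
        # GRU expects (seq_len, batch, features) with batch_first=False
        outputs, hidden = self.gru(x.transpose(1, 0), hidden)
        logits = self.fc_out(outputs.transpose(1, 0)[:, -1, :])
        return F.log_softmax(logits, dim=-1)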

Trace

This is what happens when I compare the PyTorch output against the CoreML output:

AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 98 / 100 (98%)
Max absolute difference: 0.53596115
Max relative difference: 0.1372321
 x: array([[-5.277452, -7.18906 , -3.870493, -5.237523, -5.697977, -4.906397,
        -5.19996 , -5.691385, -5.81255 , -5.201305, -5.705749, -3.767732,
        -5.457488, -6.338089, -6.073925, -5.906573, -5.663018, -6.57819 ,...
 y: array([[-5.169696, -6.974405, -3.716341, -5.348598, -6.07383 , -4.743478,
        -5.518433, -5.430327, -5.695125, -5.232084, -5.561243, -3.704498,
        -5.26874 , -6.728903, -6.277035, -5.768873, -5.754221, -6.781257,...
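
For context, assert_allclose flags an element when |actual - desired| > atol + rtol * |desired|; with rtol=1e-03 and atol=1e-05, a maximum absolute difference of ~0.536 is orders of magnitude outside tolerance, so this is a genuine numerical divergence rather than a borderline precision issue.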

Unfortunately I can't attach the models since they belong to my company.

System environment (please complete the following information):

  • coremltools version: 3.4
  • onnx-coreml version: 1.3
  • OS: macOS
  • macOS version (if applicable): 10.15.4
  • How you install python: virtualenv
  • python version: 3.7
  • any other relevant information:

Additional context

I deleted the GRU and adapted the network a bit, and then everything works, so I suspect the problem is with the GRU. Note that the network also receives the initial hidden state as an input, and that the batch_first flag is set to False (see the shape sketch below).
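
To make the shape convention concrete, here is a minimal, self-contained sketch (shapes only, not the actual model) of what batch_first=False implies for the GRU:

import torch as th
import torch.nn as nn

# With batch_first=False (the default), nn.GRU expects
#   input:  (seq_len, batch, input_size)
#   hidden: (num_layers, batch, hidden_size)
gru = nn.GRU(input_size=10, hidden_size=8)
x = th.zeros(60, 1, 10)   # 60 time steps, batch of 1
h0 = th.zeros(1, 1, 8)    # 1 layer, batch of 1
out, hn = gru(x, h0)
print(out.shape)          # torch.Size([60, 1, 8])
print(hn.shape)           # torch.Size([1, 1, 8])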

UPDATE

OK, so I simplified the network a lot and came up with the following script, which reproduces the same issue described above:

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import onnx
import onnxruntime
from onnx_coreml import convert


class NeuralNet(nn.Module):
    def __init__(self, input_size, output_size, embeddings_dim, hidden_layer):
        super(NeuralNet, self).__init__()

        self.embeddings = nn.Embedding(input_size, embeddings_dim)

        self.gru1 = nn.GRU(embeddings_dim, hidden_layer)

        self.fc_out = nn.Linear(hidden_layer, output_size)

    def forward(self, inputs, hidden):
        embedded = self.embeddings(inputs)

        outputs, hidden = self.gru1(embedded.transpose(1, 0), hidden)

        logits = self.fc_out(outputs.transpose(1, 0)[:, -1, :])

        log_outputs = F.log_softmax(logits, dim=-1)

        return log_outputs


def to_numpy(tensor, dtype=None):
    np_tensor = tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    if dtype:
        return np_tensor.astype(dtype)
    return np_tensor


def export_to_onnx(net, x, hidden, torch_out):
    onnx_file_path = "net.onnx"

    # Export the model
    th.onnx.export(net,  # model being run
                   (x, hidden),  # model input (or a tuple for multiple inputs)
                   onnx_file_path,  # where to save the model (can be a file or file-like object)
                   input_names=['input', 'hidden'],  # the model's input names
                   output_names=['output']  # the model's output names
                   )
    onnx_model = onnx.load(onnx_file_path)
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(onnx_file_path)

    # compute ONNX Runtime output prediction

    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x), ort_session.get_inputs()[1].name: to_numpy(hidden)}
    ort_outs = ort_session.run(None, ort_inputs)
    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")


def export_to_coreml(x, hidden, torch_out):
    onnx_file_path = "net.onnx"
    coreml_file_path = "net.mlmodel"

    model = convert(model=onnx_file_path, minimum_ios_deployment_target='13')
    model.save(coreml_file_path)

    # compute CoreML output prediction
    output = model.predict({"input": to_numpy(x, np.float32), "hidden": to_numpy(hidden)})[
        'output']
    # compare CoreML and PyTorch results
    np.testing.assert_allclose(to_numpy(torch_out), output, rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with CoreML, and the result looks good!")


if __name__ == "__main__":
    input_size = 60
    output_size = 100
    embeddings_dim = 10
    hidden_layers = 8

    net = NeuralNet(input_size, output_size, embeddings_dim, hidden_layers)

    x = th.zeros((1, 60), dtype=th.long)
    hidden = th.zeros(1, 1, hidden_layers)

    torch_out = net(x, hidden)

    export_to_onnx(net, x, hidden, torch_out)
    export_to_coreml(x, hidden, torch_out)

Regarding this line:

output = model.predict({"input": to_numpy(x, np.float32), "hidden": to_numpy(hidden)})[
        'output']

I'm not sure it makes sense to use an input of type Float, since the first layer is an embedding layer, but I couldn't figure out how to feed a Long (integer) input to CoreML.
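
As a sanity check (a sketch, not a fix), the converted model's declared input types can be inspected through its spec, which at least shows what dtype the CoreML side ended up expecting:

import coremltools

mlmodel = coremltools.models.MLModel("net.mlmodel")
spec = mlmodel.get_spec()
for inp in spec.description.input:
    # multiArrayType.dataType is an enum (FLOAT32, DOUBLE, INT32, ...)
    print(inp.name, inp.type.WhichOneof("Type"), inp.type.multiArrayType.dataType)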

UPDATE 2

As I said, if you remove the GRU everything works; that is, changing the network to:

class NeuralNet(nn.Module):
    def __init__(self, input_size, output_size, embeddings_dim, hidden_layer):
        super(NeuralNet, self).__init__()

        self.embeddings = nn.Embedding(input_size, embeddings_dim)

        #self.gru1 = nn.GRU(embeddings_dim, hidden_layer)

        self.fc_out = nn.Linear(embeddings_dim * input_size, output_size)

    def forward(self, inputs, hidden):
        embedded = self.embeddings(inputs)

        #outputs, hidden = self.gru1(embedded.transpose(1, 0), hidden)

        logits = self.fc_out(embedded.view(1, -1))

        log_outputs = F.log_softmax(logits, dim=-1)

        return log_outputs
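
A possible next step to narrow this down (a debugging sketch, not part of the original repro; file names and sizes are illustrative) would be to export a standalone GRU and compare its raw outputs across the three runtimes:

import numpy as np
import torch as th
import torch.nn as nn
import onnxruntime
from onnx_coreml import convert

# Standalone GRU with the same sizes as the repro above.
gru = nn.GRU(10, 8)
x = th.randn(60, 1, 10)
h0 = th.zeros(1, 1, 8)
torch_out, _ = gru(x, h0)

th.onnx.export(gru, (x, h0), "gru_only.onnx",
               input_names=['input', 'hidden'], output_names=['output', 'hn'])

sess = onnxruntime.InferenceSession("gru_only.onnx")
onnx_out = sess.run(None, {'input': x.numpy(), 'hidden': h0.numpy()})[0]

mlmodel = convert(model="gru_only.onnx", minimum_ios_deployment_target='13')
coreml_out = mlmodel.predict({'input': x.numpy(), 'hidden': h0.numpy()})['output']

np.testing.assert_allclose(torch_out.detach().numpy(), onnx_out, rtol=1e-03, atol=1e-05)
np.testing.assert_allclose(torch_out.detach().numpy(), coreml_out, rtol=1e-03, atol=1e-05)

If the second assertion fails here as well, that pins the mismatch on the converted GRU op itself rather than on the surrounding layers.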
marc-jv added the bug label on Jun 16, 2020
@francescojalvo

Did you have time to look into this @aseemw ? Thanks
