Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class NodeImporter:
"_p",
"_b",
"_nv_map",
"_none_value",
]

def __init__(
Expand All @@ -259,6 +260,7 @@ def __init__(
self._p = parent_op
self._b = block
self._nv_map: Dict[str, Value] = {}
self._none_value: Optional[Value] = None

@classmethod
def define_function(
Expand Down Expand Up @@ -366,8 +368,8 @@ def import_all(self, func=True):
Operation.create(name="torch.operator_terminator", operands=outputs)

def get_none(self):
if "" in self._nv_map:
return self._nv_map[""]
if self._none_value is not None:
return self._none_value

with InsertionPoint(self._b), Location.name("onnx_importer.none"):
nne = Operation.create(
Expand All @@ -376,7 +378,7 @@ def get_none(self):
operands=[],
attributes={},
).results[0]
self._nv_map[""] = nne
self._none_value = nne
return nne

def import_node(self, node: onnx.NodeProto):
Expand All @@ -396,6 +398,12 @@ def import_node(self, node: onnx.NodeProto):
input_values = []
input_type_protos = []
for input_name in node.input:
# ONNX uses the empty string for omitted optional inputs; it must not
# be confused with _nv_map[""], which may hold a real tensor named "".
if input_name == "":
input_values.append(self.get_none())
input_type_protos.append(onnx.TypeProto())
continue
try:
input_values.append(self._nv_map[input_name])
# Missing optional arguments will have empty types
Expand Down Expand Up @@ -447,7 +455,8 @@ def import_node(self, node: onnx.NodeProto):
self.import_regions(node.attribute, custom_op)

for output_name, output_value in zip(output_names, custom_op.results):
self._nv_map[output_name] = output_value
if output_name != "":
self._nv_map[output_name] = output_value

def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
attrs = {}
Expand Down
70 changes: 70 additions & 0 deletions test/python/onnx_importer/test_empty_string_optional_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# RUN: %PYTHON %s

"""Regression for NodeImporter: ONNX input name '' means omitted optional.

The importer must not conflate that with _nv_map[""] when an earlier node binds
a real tensor to the empty-string output name (see onnx_importer empty-string
collision fix).
"""

import unittest

import onnx
from onnx import TensorProto, helper

from _torch_mlir_config import configure_context, ir, onnx_importer


def _minimal_collision_model() -> onnx.ModelProto:
"""Identity writes to output ""; second node lists '', '' as omitted inputs."""
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 2])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 2])
n1 = helper.make_node("Identity", ["x"], [""])
n2 = helper.make_node(
"ReproEmptyStringCollision",
["x", "", ""],
["y"],
domain="zmc.repro",
)
graph = helper.make_graph([n1, n2], "g", [inp], [out])
return helper.make_model(
graph,
opset_imports=[
helper.make_opsetid("", 21),
helper.make_opsetid("zmc.repro", 1),
],
)


class EmptyStringOptionalInputsTest(unittest.TestCase):
def test_optional_slots_use_constant_none_not_prior_tensor(self):
model = _minimal_collision_model()
ctx = ir.Context()
configure_context(ctx)
mi = onnx_importer.ModelInfo(model)
m = mi.create_module(context=ctx).operation
onnx_importer.NodeImporter.define_function(mi.main_graph, m).import_all()
asm = m.get_asm()
lines = [
ln.strip() for ln in asm.splitlines() if "ReproEmptyStringCollision" in ln
]
self.assertEqual(
len(lines),
1,
msg="expected exactly one onnx.ReproEmptyStringCollision operator line",
)
line = lines[0]
# Correct: trailing optionals are torch.constant.none uses (printed as %none, %none).
self.assertGreaterEqual(
line.count("%none"),
2,
msg=f"expected at least two %none operands for omitted inputs, got:\n{line}",
)


if __name__ == "__main__":
unittest.main()
Loading