diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index c6923e82df0c..a4f95c3ee778 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -239,6 +239,7 @@ class NodeImporter: "_p", "_b", "_nv_map", + "_none_value", ] def __init__( @@ -259,6 +260,7 @@ def __init__( self._p = parent_op self._b = block self._nv_map: Dict[str, Value] = {} + self._none_value: Optional[Value] = None @classmethod def define_function( @@ -366,8 +368,8 @@ def import_all(self, func=True): Operation.create(name="torch.operator_terminator", operands=outputs) def get_none(self): - if "" in self._nv_map: - return self._nv_map[""] + if self._none_value is not None: + return self._none_value with InsertionPoint(self._b), Location.name("onnx_importer.none"): nne = Operation.create( @@ -376,7 +378,7 @@ def get_none(self): operands=[], attributes={}, ).results[0] - self._nv_map[""] = nne + self._none_value = nne return nne def import_node(self, node: onnx.NodeProto): @@ -396,6 +398,12 @@ def import_node(self, node: onnx.NodeProto): input_values = [] input_type_protos = [] for input_name in node.input: + # ONNX uses the empty string for omitted optional inputs; it must not + # be confused with _nv_map[""], which may hold a real tensor named "". + if input_name == "": + input_values.append(self.get_none()) + input_type_protos.append(onnx.TypeProto()) + continue try: input_values.append(self._nv_map[input_name]) # Missing optional arguments will have empty types @@ -447,7 +455,8 @@ def import_node(self, node: onnx.NodeProto): self.import_regions(node.attribute, custom_op) for output_name, output_value in zip(output_names, custom_op.results): - self._nv_map[output_name] = output_value + if output_name != "": + self._nv_map[output_name] = output_value def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]): attrs = {} diff --git a/test/python/onnx_importer/test_empty_string_optional_inputs.py b/test/python/onnx_importer/test_empty_string_optional_inputs.py new file mode 100644 index 000000000000..c401526805a3 --- /dev/null +++ b/test/python/onnx_importer/test_empty_string_optional_inputs.py @@ -0,0 +1,70 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# RUN: %PYTHON %s + +"""Regression for NodeImporter: ONNX input name '' means omitted optional. + +The importer must not conflate that with _nv_map[""] when an earlier node binds +a real tensor to the empty-string output name (see onnx_importer empty-string +collision fix). +""" + +import unittest + +import onnx +from onnx import TensorProto, helper + +from _torch_mlir_config import configure_context, ir, onnx_importer + + +def _minimal_collision_model() -> onnx.ModelProto: + """Identity writes to output ""; second node lists '', '' as omitted inputs.""" + inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 2]) + out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 2]) + n1 = helper.make_node("Identity", ["x"], [""]) + n2 = helper.make_node( + "ReproEmptyStringCollision", + ["x", "", ""], + ["y"], + domain="zmc.repro", + ) + graph = helper.make_graph([n1, n2], "g", [inp], [out]) + return helper.make_model( + graph, + opset_imports=[ + helper.make_opsetid("", 21), + helper.make_opsetid("zmc.repro", 1), + ], + ) + + +class EmptyStringOptionalInputsTest(unittest.TestCase): + def test_optional_slots_use_constant_none_not_prior_tensor(self): + model = _minimal_collision_model() + ctx = ir.Context() + configure_context(ctx) + mi = onnx_importer.ModelInfo(model) + m = mi.create_module(context=ctx).operation + onnx_importer.NodeImporter.define_function(mi.main_graph, m).import_all() + asm = m.get_asm() + lines = [ + ln.strip() for ln in asm.splitlines() if "ReproEmptyStringCollision" in ln + ] + self.assertEqual( + len(lines), + 1, + msg="expected exactly one onnx.ReproEmptyStringCollision operator line", + ) + line = lines[0] + # Correct: trailing optionals are torch.constant.none uses (printed as %none, %none). + self.assertGreaterEqual( + line.count("%none"), + 2, + msg=f"expected at least two %none operands for omitted inputs, got:\n{line}", + ) + + +if __name__ == "__main__": + unittest.main()