From 4bb447758c79cd9c26743ce44a95c45bd044887c Mon Sep 17 00:00:00 2001
From: eesoymilk
Date: Sat, 9 Mar 2024 21:37:34 +0800
Subject: [PATCH 1/4] Bug Fix: Always index into outputs whose operator is
 supposed to have sequence-like output. Included op types are: ['If', 'Loop',
 'Scan', 'SequenceMap', 'Split']

---
 onnx2torch/converter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnx2torch/converter.py b/onnx2torch/converter.py
index 891fd793..8ef6927a 100644
--- a/onnx2torch/converter.py
+++ b/onnx2torch/converter.py
@@ -124,7 +124,7 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
                 torch_input_node = torch_nodes[onnx_input_node.unique_name]
 
                 # Get only one needed output of torch_input_node by index
-                if len(onnx_input_node.output_values) > 1:
+                if len(onnx_input_node.output_values) > 1 or onnx_input_node._proto.op_type in ('If', 'Loop', 'Scan', 'SequenceMap', 'Split'):
                     index = onnx_input_node.output_values.index(value_name)
                     torch_input_node = torch_graph.call_function(getitem, args=(torch_input_node, index))
                     torch_nodes[name + '_split_output'] = torch_input_node

From 04265ca5aa607bf62af6cf66787892b645dced1e Mon Sep 17 00:00:00 2001
From: eesoymilk
Date: Tue, 2 Apr 2024 19:31:42 +0800
Subject: [PATCH 2/4] Fix pylint issue: access to private member

---
 onnx2torch/converter.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/onnx2torch/converter.py b/onnx2torch/converter.py
index 8ef6927a..94cceb75 100644
--- a/onnx2torch/converter.py
+++ b/onnx2torch/converter.py
@@ -124,7 +124,13 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
                 torch_input_node = torch_nodes[onnx_input_node.unique_name]
 
                 # Get only one needed output of torch_input_node by index
-                if len(onnx_input_node.output_values) > 1 or onnx_input_node._proto.op_type in ('If', 'Loop', 'Scan', 'SequenceMap', 'Split'):
+                if len(onnx_input_node.output_values) > 1 or onnx_input_node.proto.op_type in (
+                    'If',
+                    'Loop',
+                    'Scan',
+                    'SequenceMap',
+                    'Split'
+                ):
                     index = onnx_input_node.output_values.index(value_name)
                     torch_input_node = torch_graph.call_function(getitem, args=(torch_input_node, index))
                     torch_nodes[name + '_split_output'] = torch_input_node

From a67b7df67bf19b6c444f1bda76114074b010f557 Mon Sep 17 00:00:00 2001
From: eesoymilk
Date: Tue, 2 Apr 2024 19:37:34 +0800
Subject: [PATCH 3/4] Refactor: adhere to black format

---
 onnx2torch/converter.py | 79 +++++++++++++++++++++++++++++++----------
 1 file changed, 60 insertions(+), 19 deletions(-)

diff --git a/onnx2torch/converter.py b/onnx2torch/converter.py
index 94cceb75..75d70b50 100644
--- a/onnx2torch/converter.py
+++ b/onnx2torch/converter.py
@@ -17,7 +17,9 @@ def _remove_initializers_from_input(model: ModelProto) -> ModelProto:
     graph_inputs = model.graph.input
-    graph_inputs_mapping = {one_input.name: one_input for one_input in graph_inputs}
+    graph_inputs_mapping = {
+        one_input.name: one_input for one_input in graph_inputs
+    }
 
     for initializer in model.graph.initializer:
         if initializer.name in graph_inputs_mapping:
@@ -29,10 +31,14 @@ class InitializersContainer(nn.Module):
     """Module for storing initializers in torch fx graph."""
 
-    def add_initializer(self, name: str, initializer: torch.Tensor) -> None:  # pylint: disable=missing-docstring
+    def add_initializer(
+        self, name: str, initializer: torch.Tensor
+    ) -> None:  # pylint: disable=missing-docstring
         self.register_buffer(name, initializer)
 
-    def forward(self, *args, **kwargs):  # pylint: disable=missing-function-docstring
+    def forward(
+        self, *args, **kwargs
+    ):  # pylint: disable=missing-function-docstring
         raise RuntimeError('Got unexpected "forward" on constant container')
@@ -72,10 +78,15 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
     onnx_model = safe_shape_inference(onnx_model_or_path)
     if onnx_model.ir_version < 3:
-        raise NotImplementedError('Onnx IR is too old (minimal supported version is 3).')
+        raise NotImplementedError(
+            'Onnx IR is too old (minimal supported version is 3).'
+        )
 
     onnx_model = _remove_initializers_from_input(onnx_model)
-    opset_import = {opsetid_proto.domain: opsetid_proto.version for opsetid_proto in onnx_model.opset_import}
+    opset_import = {
+        opsetid_proto.domain: opsetid_proto.version
+        for opsetid_proto in onnx_model.opset_import
+    }
 
     onnx_graph = OnnxGraph(onnx_model.graph)  # pylint: disable=no-member
     torch_graph = fx.Graph()
@@ -89,7 +100,9 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
     for input_value, name in enumerate(onnx_graph.input_values, 1):
         if save_input_names:
             if not name.isidentifier():
-                raise ValueError(f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.')
+                raise ValueError(
+                    f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.'
+                )
 
             placeholder_name = name
         else:
@@ -124,53 +137,81 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
                 torch_input_node = torch_nodes[onnx_input_node.unique_name]
 
                 # Get only one needed output of torch_input_node by index
-                if len(onnx_input_node.output_values) > 1 or onnx_input_node.proto.op_type in (
+                if len(
+                    onnx_input_node.output_values
+                ) > 1 or onnx_input_node.proto.op_type in (
                     'If',
                     'Loop',
                     'Scan',
                     'SequenceMap',
-                    'Split'
+                    'Split',
                 ):
                     index = onnx_input_node.output_values.index(value_name)
-                    torch_input_node = torch_graph.call_function(getitem, args=(torch_input_node, index))
+                    torch_input_node = torch_graph.call_function(
+                        getitem, args=(torch_input_node, index)
+                    )
                     torch_nodes[name + '_split_output'] = torch_input_node
 
                 args.append(torch_input_node)
             elif value_type == ValueType.GRAPH_INITIALIZER:
                 # The name of pytorch buffer must not contain '.'(dot)
-                len_torch_initializers = sum(1 for _ in torch_initializers.buffers())
+                len_torch_initializers = sum(
+                    1 for _ in torch_initializers.buffers()
+                )
                 torch_buffer_name = f'onnx_initializer_{len_torch_initializers}'
                 if value_name not in torch_nodes:
                     torch_initializers.add_initializer(
                         torch_buffer_name,
                         onnx_graph.initializers[value_name].to_torch(),
                     )
-                    torch_nodes[torch_buffer_name] = torch_graph.get_attr(f'initializers.{torch_buffer_name}')
+                    torch_nodes[torch_buffer_name] = torch_graph.get_attr(
+                        f'initializers.{torch_buffer_name}'
+                    )
                 args.append(torch_nodes[torch_buffer_name])
             elif value_type == ValueType.EMPTY:
                 args.append(None)
             else:
-                raise RuntimeError(f'Got unexpected input value type ({value_type})')
+                raise RuntimeError(
+                    f'Got unexpected input value type ({value_type})'
+                )
 
         # Collect kwargs if there are some skipped args
         kwargs = {}
         if None in args:
             first_skipped_arg = args.index(None)
-            forward_args = tuple(inspect.signature(torch_module.forward).parameters.keys())
+            forward_args = tuple(
+                inspect.signature(torch_module.forward).parameters.keys()
+            )
             forward_args = forward_args[first_skipped_arg : len(args)]
-            args, kwargs_values = args[:first_skipped_arg], args[first_skipped_arg:]
-            kwargs.update({name: value for name, value in zip(forward_args, kwargs_values) if value is not None})
-
-        torch_nodes[name] = torch_graph.call_module(module_name=name, args=tuple(args), kwargs=kwargs)
+            args, kwargs_values = (
+                args[:first_skipped_arg],
+                args[first_skipped_arg:],
+            )
+            kwargs.update(
+                {
+                    name: value
+                    for name, value in zip(forward_args, kwargs_values)
+                    if value is not None
+                }
+            )
+
+        torch_nodes[name] = torch_graph.call_module(
+            module_name=name, args=tuple(args), kwargs=kwargs
+        )
 
     # Create output nodes
-    onnx_output_nodes = [onnx_graph.value_as_node_output(value_name)[0] for value_name in onnx_graph.output_values]
+    onnx_output_nodes = [
+        onnx_graph.value_as_node_output(value_name)[0]
+        for value_name in onnx_graph.output_values
+    ]
     # Delete duplicates and save order
     onnx_output_nodes = list(OrderedDict.fromkeys(onnx_output_nodes))
-    torch_output_nodes = [torch_nodes[onnx_node.unique_name] for onnx_node in onnx_output_nodes]
+    torch_output_nodes = [
+        torch_nodes[onnx_node.unique_name] for onnx_node in onnx_output_nodes
+    ]
     if len(torch_output_nodes) == 1:
         torch_output_nodes = torch_output_nodes[0]
     torch_graph.output(torch_output_nodes)

From e4e501592321757778f55c121103337e07c93e03 Mon Sep 17 00:00:00 2001
From: eesoymilk
Date: Tue, 2 Apr 2024 19:38:00 +0800
Subject: [PATCH 4/4] Refactor: adhere to black format

---
 onnx2torch/converter.py | 77 ++++++++++------------------------------
 1 file changed, 18 insertions(+), 59 deletions(-)

diff --git a/onnx2torch/converter.py b/onnx2torch/converter.py
index 75d70b50..37dc6d03 100644
--- a/onnx2torch/converter.py
+++ b/onnx2torch/converter.py
@@ -17,9 +17,7 @@ def _remove_initializers_from_input(model: ModelProto) -> ModelProto:
     graph_inputs = model.graph.input
-    graph_inputs_mapping = {
-        one_input.name: one_input for one_input in graph_inputs
-    }
+    graph_inputs_mapping = {one_input.name: one_input for one_input in graph_inputs}
 
     for initializer in model.graph.initializer:
         if initializer.name in graph_inputs_mapping:
@@ -31,14 +29,10 @@ class InitializersContainer(nn.Module):
     """Module for storing initializers in torch fx graph."""
 
-    def add_initializer(
-        self, name: str, initializer: torch.Tensor
-    ) -> None:  # pylint: disable=missing-docstring
+    def add_initializer(self, name: str, initializer: torch.Tensor) -> None:  # pylint: disable=missing-docstring
         self.register_buffer(name, initializer)
 
-    def forward(
-        self, *args, **kwargs
-    ):  # pylint: disable=missing-function-docstring
+    def forward(self, *args, **kwargs):  # pylint: disable=missing-function-docstring
         raise RuntimeError('Got unexpected "forward" on constant container')
@@ -78,15 +72,10 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
     onnx_model = safe_shape_inference(onnx_model_or_path)
     if onnx_model.ir_version < 3:
-        raise NotImplementedError(
-            'Onnx IR is too old (minimal supported version is 3).'
-        )
+        raise NotImplementedError('Onnx IR is too old (minimal supported version is 3).')
 
     onnx_model = _remove_initializers_from_input(onnx_model)
-    opset_import = {
-        opsetid_proto.domain: opsetid_proto.version
-        for opsetid_proto in onnx_model.opset_import
-    }
+    opset_import = {opsetid_proto.domain: opsetid_proto.version for opsetid_proto in onnx_model.opset_import}
 
     onnx_graph = OnnxGraph(onnx_model.graph)  # pylint: disable=no-member
     torch_graph = fx.Graph()
@@ -100,9 +89,7 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
     for input_value, name in enumerate(onnx_graph.input_values, 1):
         if save_input_names:
             if not name.isidentifier():
-                raise ValueError(
-                    f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.'
-                )
+                raise ValueError(f'Input name "{name}" cannot be used as name of placeholder in fx.GraphModule.')
 
             placeholder_name = name
         else:
@@ -137,81 +124,53 @@ def convert(  # pylint: disable=too-many-locals, too-many-branches, too-many-sta
                 torch_input_node = torch_nodes[onnx_input_node.unique_name]
 
                 # Get only one needed output of torch_input_node by index
-                if len(
-                    onnx_input_node.output_values
-                ) > 1 or onnx_input_node.proto.op_type in (
+                if len(onnx_input_node.output_values) > 1 or onnx_input_node.proto.op_type in (
                     'If',
                     'Loop',
                     'Scan',
                     'SequenceMap',
                     'Split',
                 ):
                     index = onnx_input_node.output_values.index(value_name)
-                    torch_input_node = torch_graph.call_function(
-                        getitem, args=(torch_input_node, index)
-                    )
+                    torch_input_node = torch_graph.call_function(getitem, args=(torch_input_node, index))
                     torch_nodes[name + '_split_output'] = torch_input_node
 
                 args.append(torch_input_node)
             elif value_type == ValueType.GRAPH_INITIALIZER:
                 # The name of pytorch buffer must not contain '.'(dot)
-                len_torch_initializers = sum(
-                    1 for _ in torch_initializers.buffers()
-                )
+                len_torch_initializers = sum(1 for _ in torch_initializers.buffers())
                 torch_buffer_name = f'onnx_initializer_{len_torch_initializers}'
                 if value_name not in torch_nodes:
                     torch_initializers.add_initializer(
                         torch_buffer_name,
                         onnx_graph.initializers[value_name].to_torch(),
                     )
-                    torch_nodes[torch_buffer_name] = torch_graph.get_attr(
-                        f'initializers.{torch_buffer_name}'
-                    )
+                    torch_nodes[torch_buffer_name] = torch_graph.get_attr(f'initializers.{torch_buffer_name}')
                 args.append(torch_nodes[torch_buffer_name])
             elif value_type == ValueType.EMPTY:
                 args.append(None)
             else:
-                raise RuntimeError(
-                    f'Got unexpected input value type ({value_type})'
-                )
+                raise RuntimeError(f'Got unexpected input value type ({value_type})')
 
         # Collect kwargs if there are some skipped args
         kwargs = {}
         if None in args:
             first_skipped_arg = args.index(None)
-            forward_args = tuple(
-                inspect.signature(torch_module.forward).parameters.keys()
-            )
+            forward_args = tuple(inspect.signature(torch_module.forward).parameters.keys())
             forward_args = forward_args[first_skipped_arg : len(args)]
-            args, kwargs_values = (
-                args[:first_skipped_arg],
-                args[first_skipped_arg:],
-            )
-            kwargs.update(
-                {
-                    name: value
-                    for name, value in zip(forward_args, kwargs_values)
-                    if value is not None
-                }
-            )
-
-        torch_nodes[name] = torch_graph.call_module(
-            module_name=name, args=tuple(args), kwargs=kwargs
-        )
+            args, kwargs_values = args[:first_skipped_arg], args[first_skipped_arg:]
+            kwargs.update({name: value for name, value in zip(forward_args, kwargs_values) if value is not None})
+
+        torch_nodes[name] = torch_graph.call_module(module_name=name, args=tuple(args), kwargs=kwargs)
 
     # Create output nodes
-    onnx_output_nodes = [
-        onnx_graph.value_as_node_output(value_name)[0]
-        for value_name in onnx_graph.output_values
-    ]
+    onnx_output_nodes = [onnx_graph.value_as_node_output(value_name)[0] for value_name in onnx_graph.output_values]
     # Delete duplicates and save order
     onnx_output_nodes = list(OrderedDict.fromkeys(onnx_output_nodes))
-    torch_output_nodes = [
-        torch_nodes[onnx_node.unique_name] for onnx_node in onnx_output_nodes
-    ]
+    torch_output_nodes = [torch_nodes[onnx_node.unique_name] for onnx_node in onnx_output_nodes]
     if len(torch_output_nodes) == 1:
         torch_output_nodes = torch_output_nodes[0]
     torch_graph.output(torch_output_nodes)