diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index eeebbb63d..4fc6fda24 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -17,7 +17,7 @@ import pathlib from typing import Callable -from onnxscript import ir, optimizer +from onnxscript import ir, optimizer, version_converter from onnxscript.function_libs.torch_lib import registration from onnxscript.ir import _external_data @@ -51,18 +51,10 @@ def optimize(model: ir.Model) -> ir.Model: def convert_version(model: ir.Model, target_version: int) -> ir.Model: """Convert the model to the specified ONNX opset version.""" - # model_version = model.opset_import.get("") - # if model_version == target_version: - # # No conversion needed - # return model - - # # FIXME(justinchuby): version_converter does not support functions - # proto = ir.serde.serialize_model(model) - # proto = onnx.version_converter.convert_version(proto, target_version) - # return ir.serde.deserialize_model(proto) - # TODO(justinchuby): This function needs to be carefully implemented - # to handle large models. For now, we just return the model. - del target_version # Unused + # Internal flag. Will go away. + enabled = os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1" + if enabled: + version_converter.convert_version(model, target_version) return model diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py index ec929a1d8..2e228e552 100644 --- a/onnxscript/_framework_apis/torch_2_6.py +++ b/onnxscript/_framework_apis/torch_2_6.py @@ -10,7 +10,10 @@ "get_torchlib_ops", "optimize", "save_model_with_external_data", + "torchlib_opset", ] +from typing import TYPE_CHECKING + from onnxscript import ir, optimizer from onnxscript._framework_apis.torch_2_5 import ( check_model, @@ -19,8 +22,24 @@ save_model_with_external_data, ) +if TYPE_CHECKING: + from onnxscript.onnx_opset._impl.opset18 import Opset18 + def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" optimizer.optimize_ir(model) return model + + +def torchlib_opset() -> Opset18: + """Return the default opset for torchlib.""" + import onnxscript # pylint: disable=import-outside-toplevel + + return onnxscript.opset18 # type: ignore + + +def torchlib_opset_version() -> int: + """Return the default opset version for torchlib.""" + + return torchlib_opset().version diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a955583e9..63f692954 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4391,7 +4391,10 @@ def aten_instance_norm( ), "running_mean and running_var must be provided when use_input_stats is False" batch_size = op.Shape(input, start=0, end=1) - bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0)) + bn_input = op.Reshape( + input, + op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0), + ) weight = op.Tile(weight, batch_size) bias = op.Tile(bias, batch_size) running_mean = op.Tile(running_mean, batch_size) @@ -5225,9 +5228,8 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: if IsScalar(self): result = self else: - if IsScalar(dim): - dim = op.Unsqueeze(dim, axes=0) - result = op.ReduceMean(self, dim, keepdims=keepdim) + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + result = op.ReduceMean(self, dims, keepdims=keepdim) return result diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py new file mode 100644 index 000000000..299373f9c --- /dev/null +++ b/onnxscript/version_converter/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +__all__ = [ + # Functions + "convert_version", +] + +from onnxscript import ir +from onnxscript.optimizer import _inliner +from onnxscript.version_converter import _version_converter + + +def convert_version(model: ir.Model, target_version: int) -> None: + """Convert the model to the specified ONNX opset version.""" + + # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. + # Hence, we inline all the functions. + _inliner.inline(model) + _version_converter.convert_version(model, target_version) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py new file mode 100644 index 000000000..28a590bb2 --- /dev/null +++ b/onnxscript/version_converter/_version_converter.py @@ -0,0 +1,314 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Convert the model to the specified ONNX opset version.""" + +from __future__ import annotations + +import dataclasses +import functools +import logging +from typing import Callable, Sequence, Union + +import onnxscript.ir.convenience as ir_convenience +import onnxscript.rewriter.pattern as orp +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +CURRENT_MAX_ONNX_OPSET = 23 + + +class VersionConverterError(RuntimeError): + """Raised when an node's version cannot be upgraded/downgraded successfully.""" + + +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + + +# A version-adapter function takes a node, a RewriterContext and returns +# a Replacement for the node or None (if no replacement is needed). + +ReturnValue = Union[Sequence[ir.Value], ir.Value, None] +AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue] + + +class AdapterRegistry: + """A class that maintains a registry of adapters for ops.""" + + def __init__(self): + self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {} + + def lookup_adapters( + self, + domain: str, + opname: str, + original_version: int, + up_conversion: bool = True, + ) -> AdapterFunction | None: + adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion)) + if adapter_func is not None: + return adapter_func + return None + + def register( + self, opname: str, domain: str = "", node_version=None, up_conversion=True + ) -> Callable[[AdapterFunction], AdapterFunction]: + """Register an adapter based on the domain, operator type, node version and whether to upgrade/downgrade node version""" + + def decorator(function: AdapterFunction) -> AdapterFunction: + @functools.wraps(function) + def wrapped_function(*args, **kwargs): + return function(*args, **kwargs) + + self.op_adapters[(domain, opname, node_version, up_conversion)] = function + return wrapped_function + + return decorator + + +registry: AdapterRegistry = AdapterRegistry() + +register = registry.register + + +def _get_input(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + + +def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, int): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> str | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, str): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +## Op-specific adapters + +# Opset 19 -> 20 + + +@register("DFT", node_version=19, up_conversion=True) +def dft_19_20(node: ir.Node, op): + input = node.inputs[0] + inverse = _get_int_attribute(node, "inverse", 0) + onesided = _get_int_attribute(node, "onesided", 0) + axis = _get_int_attribute(node, "axis", None) + if axis is not None: + axis_value = op.Constant(value_int=axis) + return op.DFT(input, axis_value, inverse=inverse, onesided=onesided) + return None + + +@register("GridSample", node_version=19, up_conversion=True) +def gridsample_19_20(node: ir.Node, op): + x = node.inputs[0] + grid = node.inputs[1] + align_corners = _get_int_attribute(node, "align_corners", 0) + mode = _get_str_attribute(node, "mode", "linear") + padding_mode = _get_str_attribute(node, "padding_mode", "zeros") + if mode == "bilinear": + return op.GridSample( + x, grid, align_corners=align_corners, mode="linear", padding_mode=padding_mode + ) + elif mode == "bicubic": + return op.GridSample( + x, grid, align_corners=align_corners, mode="cubic", padding_mode=padding_mode + ) + return None + + +# Opset 20 -> 21 + + +@register("GroupNormalization", node_version=20, up_conversion=True) +def groupnormalization_20_21(node: ir.Node, op): + x = _get_input(node, 0) + scale = _get_input(node, 1) + bias = _get_input(node, 2) + if x is None or scale is None or bias is None: + raise VersionConverterError(f"Missing input for {node}") + + x_shape = x.shape + if x_shape is None: + raise VersionConverterError(f"Missing required shape for {x}") + num_channels = x_shape[1] + if not isinstance(num_channels, int): + return None + + scale_shape = scale.shape + bias_shape = bias.shape + if scale_shape is None or bias_shape is None: + return None + if not isinstance(scale_shape[0], int) or not isinstance(bias_shape[0], int): + return None + + num_groups = _get_int_attribute(node, "num_groups", None) + if num_groups is None: + raise VersionConverterError("Missing required attribute: num_groups") + if ( + num_groups != num_channels + and num_groups == scale_shape[0] + and num_groups == bias_shape[0] + ): + reshape_1_sizes = op.Constant(value_ints=[-1, 1]) + reshape_2_sizes = op.Constant(value_ints=[-1]) + c_div = int(num_channels / num_groups) + expand_sizes = op.Constant(value_ints=[1, c_div]) + + # Modify scale input + scale_reshape_1 = op.Reshape(scale, reshape_1_sizes) + scale_expand = op.Expand(scale_reshape_1, expand_sizes) + scale_reshape_2 = op.Reshape(scale_expand, reshape_2_sizes) + + # Modify bias input + bias_reshape_1 = op.Reshape(bias, reshape_1_sizes) + bias_expand = op.Expand(bias_reshape_1, expand_sizes) + bias_reshape_2 = op.Reshape(bias_expand, reshape_2_sizes) + + return op.GroupNormalization(x, scale_reshape_2, bias_reshape_2, num_groups=num_groups) + return None + + +class _VersionConverter: + opset_imports: dict[str, int] + model_version: int + + def __init__(self, target_version: int): + self.target_version = target_version + + def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None: + if up_conversion is True: + node.version = opset_version + 1 + else: + node.version = opset_version - 1 + + def process_node( + self, node: ir.Node, opset_version: int, up_conversion: bool = True + ) -> Replacement | None: + if node.domain not in {"", "ai.onnx"}: + return None + adapter = registry.lookup_adapters( + node.domain, node.op_type, opset_version, up_conversion + ) + if adapter is None: + return None + context = orp.RewriterContext() + output = adapter(node, context) + if output is not None: + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + return None + + def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: + logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + + ir_convenience.replace_nodes_and_values( + root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs + ) + + def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: + if isinstance(attr, ir.Attr): + if attr.type == ir.AttributeType.GRAPH: + self.visit_graph(attr.value) # type: ignore[arg-type] + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + self.visit_graph(graph) # type: ignore[arg-type] + + def visit_node( + self, + node: ir.Node, + root: ir.Graph | ir.Function, + opset_version: int, + up_conversion: bool = True, + ) -> None: + replacement = self.process_node(node, opset_version, up_conversion) + if replacement is None: + # No change. Process attributes. + for attr in node.attributes.values(): + self.visit_attribute(attr) + return None + else: + self.replace_node(node, replacement, root) + return None + + def visit_graph(self, graph: ir.Graph) -> None: + if self.target_version > CURRENT_MAX_ONNX_OPSET: + logger.warning( + "Conversion to target opset: %s not currently supported.", + self.target_version, + ) + return None + for node in graph: + up_conversion = True + if node.version is None: + node.version = self.model_version + # Iterate each node from current node version -> target version + # and updating node based on the correct adapter + # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1] + # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted + if self.target_version < node.version: + up_conversion = False + logger.warning( + "Target opset: %s less than %s, downstream version conversion not currently handled.", + self.target_version, + self.model_version, + ) + return None + for opset_version in range(node.version, self.target_version): + try: + self.visit_node(node, graph, opset_version, up_conversion) + self._upgrade_version(node, opset_version, up_conversion) + except VersionConverterError as e: + logger.warning( + "Skipping version conversion for node %s due to exception: %s", + node.op_type, + e, + ) + return None + + def visit_model(self, model: ir.Model) -> None: + self.opset_imports = model.opset_imports + model_version = self.opset_imports.get("") + if model_version is None: + model_version = model.opset_imports.get("ai.onnx") + if model_version is None: + return None + self.model_version = model_version + self.visit_graph(model.graph) + return None + + +def convert_version(model: ir.Model, target_version: int) -> None: + """Convert the model to the specified ONNX opset version.""" + version_converter = _VersionConverter(target_version=target_version) + version_converter.visit_model(model) diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py new file mode 100644 index 000000000..472ffe2e5 --- /dev/null +++ b/onnxscript/version_converter/_version_converter_test.py @@ -0,0 +1,332 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx.checker +import onnx.defs +import onnx.parser +import onnx.shape_inference + +from onnxscript import ir, version_converter + + +class ApapterCoverageTest(unittest.TestCase): + def get_all_unique_schema_versions(self) -> dict[str, list]: + """Collect all unique versions of ONNX standard domain ops""" + op_version_dict = {} + all_schemas = onnx.defs.get_all_schemas_with_history() + for schema in all_schemas: + if schema.name not in op_version_dict: + op_version_dict[schema.name] = [schema.since_version] + else: + if schema.since_version not in op_version_dict[schema.name]: + op_version_dict[schema.name].append(schema.since_version) + return op_version_dict + + # TODO(shubhambhokare1) : Using existing onnx testing suite to verify operator adapter's functionality + def test_upstream_coverage(self): + op_version_dict = self.get_all_unique_schema_versions() + op_upgrades = [] + for op_type in op_version_dict: # pylint: disable=consider-using-dict-items + for opset_version in op_version_dict[op_type]: + op_upgrades.append((op_type, opset_version)) + + adapter_list = version_converter._version_converter.registry.op_adapters # pylint: disable=protected-access + for adapter_sig in adapter_list: + adapter_info = list(adapter_sig) + domain, name, upgrade_version = ( + adapter_info[0], + adapter_info[1], + adapter_info[2] + 1, + ) + self.assertEqual(domain, "") + self.assertIn((name, upgrade_version), op_upgrades) + + def test_version_convert_non_standard_onnx_domain(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, None) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, None) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, None) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") + + +class VersionConverter18to17Test(unittest.TestCase): + def test_version_convert_compatible(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 17 + version_converter.convert_version(model, target_version=target_version) + + +class VersionConverter18to19Test(unittest.TestCase): + def test_version_convert_compatible(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 19 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 19) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 19) + self.assertEqual(model.graph.node(4).op_type, "MatMul") + self.assertEqual(model.graph.node(4).version, 19) + + +class VersionConverter19to20Test(unittest.TestCase): + def test_version_convert_compatible(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + dft = DFT (reshape_x) + shape_c = Constant() + output = Reshape (dft, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(2).op_type, "Constant") + self.assertEqual(model.graph.node(3).version, 20) + self.assertEqual(model.graph.node(3).op_type, "DFT") + self.assertEqual(model.graph.node(3).version, 20) + self.assertEqual(len(model.graph.node(3).inputs), 2) + + def test_version_convert_gridsample_linear(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") + + def test_version_convert_gridsample_cubic(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic") + + def test_version_convert_inline(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + output = foo(gridsample) + } + + + foo (x) => (dft) { + dft = DFT (x) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") + self.assertEqual(model.graph.node(6).op_type, "DFT") + self.assertEqual(model.graph.node(6).version, 20) + self.assertEqual(len(model.graph.node(6).inputs), 2) + + +class VersionConverter20to21Test(unittest.TestCase): + def test_version_groupnorm(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output) + { + groupnorm = GroupNormalization (input_x, scale, bias) + shape_c = Constant() + output = Reshape (groupnorm, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 21 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(3).op_type, "Reshape") + self.assertEqual(model.graph.node(3).version, 21) + self.assertEqual(model.graph.node(4).op_type, "Expand") + self.assertEqual(model.graph.node(4).version, 21) + self.assertEqual(model.graph.node(5).op_type, "Reshape") + self.assertEqual(model.graph.node(5).version, 21) + self.assertEqual(model.graph.node(6).op_type, "Reshape") + self.assertEqual(model.graph.node(6).version, 21) + self.assertEqual(model.graph.node(7).op_type, "Expand") + self.assertEqual(model.graph.node(7).version, 21) + self.assertEqual(model.graph.node(8).op_type, "Reshape") + self.assertEqual(model.graph.node(8).version, 21) + self.assertEqual(model.graph.node(9).op_type, "GroupNormalization") + self.assertEqual(model.graph.node(9).version, 21) + + def test_version_groupnorm_no_bias(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output) + { + groupnorm = GroupNormalization (input_x, scale) + shape_c = Constant() + output = Reshape (groupnorm, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 21 + version_converter.convert_version(model, target_version=target_version) + + self.assertEqual(model.graph.node(0).op_type, "GroupNormalization") + self.assertEqual(model.graph.node(0).version, 20) + + +class VersionConverter23to24Test(unittest.TestCase): + def test_version_convert_compatible(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + target_version = 24 + version_converter.convert_version(model, target_version=target_version) + + +if __name__ == "__main__": + unittest.main()