From a5d668ffd8012e9ae190c2f0b9a485021f8bf9ac Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 9 Jan 2026 13:42:46 -0500 Subject: [PATCH 1/8] Add fix for Resize mismatch error Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index d708c890f..1f8ad8c7c 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -68,7 +68,7 @@ class InitializerConsumerTracker: OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] # Mapping of op types to indices of inputs that should not be converted to low precision. -SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}} +SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1, 2}} # {X, roi (opt), scales (opt)} SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}} From a4eecb92dfa688b9c80a744fc17d98b94de5c7ac Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:25:59 -0500 Subject: [PATCH 2/8] Add unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/_test_utils/onnx/lib_test_models.py | 83 +++++++++++++++++++++++ tests/unit/onnx/autocast/test_autocast.py | 22 +++++- 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/tests/_test_utils/onnx/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py index 675fe03b3..f822fa130 100644 --- a/tests/_test_utils/onnx/lib_test_models.py +++ b/tests/_test_utils/onnx/lib_test_models.py @@ -924,3 +924,86 @@ def build_conv_isinf_model(opset_version=13): onnx.checker.check_model(model_inferred) return model_inferred + + +def build_conv_resize_model(): + # Define your model inputs and outputs + input_names = ["input_0"] + output_names = ["output_0"] + input_shapes = [(1, 288, 32, 32)] + output_shapes = [(1, 16, 64, 64)] + + inputs = [ + helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) + for input_name, input_shape in zip(input_names, input_shapes) + ] + outputs = [ + helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) + for output_name, output_shape in zip(output_names, output_shapes) + ] + + # Create the ONNX graph with the nodes + nodes = [ + helper.make_node( + op_type="Conv", + inputs=["input_0", "weights_1"], + outputs=["conv1_conv/Conv2D:0"], + name="conv1_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + ), + helper.make_node( + op_type="Resize", + inputs=[ + "conv1_conv/Conv2D:0", + "resize_roi_scales", + "resize_roi_scales", + "resize_sizes", + ], + outputs=["output_0"], + name="resize1_resize/Resize", + coordinate_transformation_mode="asymmetric", + cubic_coeff_a=-0.75, + mode="nearest", + nearest_mode="floor", + ), + ] + + # Create the ONNX initializers + initializers = [ + helper.make_tensor( + name="weights_1", + data_type=onnx.TensorProto.FLOAT, + dims=(16, 288, 1, 1), + vals=np.random.uniform(low=0.5, high=1.0, size=16 * 288 * 1 * 1), + ), + helper.make_tensor( + name="resize_roi_scales", + data_type=onnx.TensorProto.FLOAT, + dims=(0,), + vals=[], + ), + helper.make_tensor( + name="resize_sizes", + data_type=onnx.TensorProto.INT64, + dims=(4,), + vals=[1, 16, 64, 64], + ), + ] + + # Create the ONNX graph with the nodes and initializers + graph = helper.make_graph(nodes, "conv_resize", inputs, outputs, initializer=initializers) + + # Create the ONNX model + model = helper.make_model(graph) + model.opset_import[0].version = 13 + model.ir_version = 10 + + # Check the ONNX model + model_inferred = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model_inferred) + + return model_inferred diff --git a/tests/unit/onnx/autocast/test_autocast.py b/tests/unit/onnx/autocast/test_autocast.py index 3f987eaea..b4d790a23 100644 --- a/tests/unit/onnx/autocast/test_autocast.py +++ b/tests/unit/onnx/autocast/test_autocast.py @@ -20,7 +20,7 @@ import onnx import onnx_graphsurgeon as gs import pytest -from _test_utils.onnx.lib_test_models import build_conv_isinf_model +from _test_utils.onnx.lib_test_models import build_conv_isinf_model, build_conv_resize_model import modelopt.onnx.autocast.utils as utils import modelopt.onnx.utils as onnx_utils @@ -190,6 +190,26 @@ def test_conv_isinf_conversion(tmp_path, opset_version): assert assert_input_precision(isinf_nodes, dtype=supported_dtype) +def test_conv_resize_conversion(tmp_path): + onnx_model = build_conv_resize_model() + onnx_path = os.path.join(tmp_path, "conv_resize_model.onnx") + onnx.save(onnx_model, onnx_path) + + # Convert the model + converted_model = convert_to_mixed_precision(onnx_path=onnx_path) + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") + onnx.save(converted_model, output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(converted_model) + + # Check that Conv is converted + conv_nodes = [n for n in graph.nodes if "Conv" in n.op] + assert assert_input_precision(conv_nodes) + + @pytest.mark.parametrize("target_opset", [13, 17, 19, 21]) def test_opset_parameter(temp_model_path, target_opset): """Test that the opset parameter correctly sets the output model's opset version.""" From 50c88e9f51183e22d5b5d3778222c328622d1878 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 12 Jan 2026 10:47:25 -0500 Subject: [PATCH 3/8] Revert change in Resize indices Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 1f8ad8c7c..d708c890f 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -68,7 +68,7 @@ class InitializerConsumerTracker: OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] # Mapping of op types to indices of inputs that should not be converted to low precision. -SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1, 2}} # {X, roi (opt), scales (opt)} +SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}} SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}} From fbdbb26e54689eb195ce8e625d0ec2f88c06dd6f Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 12 Jan 2026 11:34:24 -0500 Subject: [PATCH 4/8] Add function to duplicate shared constants Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/graphsanitizer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 837e32f7c..f24acd52f 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -67,6 +67,7 @@ def sanitize(self) -> None: self.convert_opset() self.replace_layernorm_pattern() self.ensure_graph_name_exists() + self.duplicate_shared_constants() onnx_utils.name_onnx_nodes(self.model.graph) self.replace_custom_domain_nodes() self.sanitize_io_casts() @@ -254,6 +255,12 @@ def ensure_graph_name_exists(self) -> None: if not self.model.graph.name: self.model.graph.name = "model" + def duplicate_shared_constants(self) -> None: + """Duplicate constant tensors if they are shared.""" + self.model, is_duplicated_constant = onnx_utils.duplicate_shared_constants(self.model) + if is_duplicated_constant: + logger.info("Shared constants were detected and duplicated accordingly.") + def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: """Match the sequence of operations that constitute a LayerNorm. From 1480e17000858012254020d894d9f09d8a043fa0 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 12 Jan 2026 20:19:34 -0500 Subject: [PATCH 5/8] Improve testing function Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/_test_utils/onnx/lib_test_models.py | 2 ++ tests/unit/onnx/autocast/test_autocast.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/_test_utils/onnx/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py index f822fa130..ff97b5142 100644 --- a/tests/_test_utils/onnx/lib_test_models.py +++ b/tests/_test_utils/onnx/lib_test_models.py @@ -955,6 +955,8 @@ def build_conv_resize_model(): pads=[0, 0, 0, 0], strides=[1, 1], ), + # Note: resize_roi_scales is intentionally used for both roi and scales inputs + # to test the shared constant duplication fix (PR #757) helper.make_node( op_type="Resize", inputs=[ diff --git a/tests/unit/onnx/autocast/test_autocast.py b/tests/unit/onnx/autocast/test_autocast.py index b4d790a23..a28c367c8 100644 --- a/tests/unit/onnx/autocast/test_autocast.py +++ b/tests/unit/onnx/autocast/test_autocast.py @@ -174,7 +174,7 @@ def test_conv_isinf_conversion(tmp_path, opset_version): output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") onnx.save(converted_model, output_onnx_path) - # Load the output model and check QDQ node placements + # Load the output model graph = gs.import_onnx(converted_model) # Check that Conv is converted @@ -202,12 +202,16 @@ def test_conv_resize_conversion(tmp_path): output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") onnx.save(converted_model, output_onnx_path) - # Load the output model and check QDQ node placements + # Load the output model graph = gs.import_onnx(converted_model) - # Check that Conv is converted - conv_nodes = [n for n in graph.nodes if "Conv" in n.op] - assert assert_input_precision(conv_nodes) + # Check that Resize is correctly converted: + # - Data and ROI inputs (indices 0 and 1) should be FP16 + # - The remaining inputs (scales/sizes) should remain in their original types + resize_node = next(n for n in graph.nodes if n.op == "Resize") + assert all(inp.dtype == np.float16 for inp in resize_node.inputs[0:2]), ( + "Resize data and ROI inputs should be FP16" + ) @pytest.mark.parametrize("target_opset", [13, 17, 19, 21]) From 991667a3a67df8a52b7518e5a2d34c03d109109b Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 12 Jan 2026 20:21:13 -0500 Subject: [PATCH 6/8] nit: update comment Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/autocast/test_autocast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/onnx/autocast/test_autocast.py b/tests/unit/onnx/autocast/test_autocast.py index a28c367c8..f761e1e9e 100644 --- a/tests/unit/onnx/autocast/test_autocast.py +++ b/tests/unit/onnx/autocast/test_autocast.py @@ -207,7 +207,7 @@ def test_conv_resize_conversion(tmp_path): # Check that Resize is correctly converted: # - Data and ROI inputs (indices 0 and 1) should be FP16 - # - The remaining inputs (scales/sizes) should remain in their original types + # - The remaining inputs (scales/sizes) should be kept in their original precisions resize_node = next(n for n in graph.nodes if n.op == "Resize") assert all(inp.dtype == np.float16 for inp in resize_node.inputs[0:2]), ( "Resize data and ROI inputs should be FP16" From 878601215b6330d103a6c69d29152632dd552761 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 13 Jan 2026 12:52:45 -0500 Subject: [PATCH 7/8] Update value_info and init mapping after graph sanitizing Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/graphsanitizer.py | 2 +- modelopt/onnx/autocast/precisionconverter.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index f24acd52f..85f407a59 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -259,7 +259,7 @@ def duplicate_shared_constants(self) -> None: """Duplicate constant tensors if they are shared.""" self.model, is_duplicated_constant = onnx_utils.duplicate_shared_constants(self.model) if is_duplicated_constant: - logger.info("Shared constants were detected and duplicated accordingly.") + logger.warning("Shared constants were detected and duplicated accordingly.") def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: """Match the sequence of operations that constitute a LayerNorm. diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index d708c890f..eae589d8a 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -1419,6 +1419,11 @@ def _sanitize_model(self): graph_sanitizer.sanitize() self.model = graph_sanitizer.model + # Update value_info_map and initializer_map after sanitizing model + self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( + self.model + ) + def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}): """Create mapping of op types to indices of inputs that should not be converted to low precision.""" skip_inputs_map = {} From 7e9c6f4fccab4bce3f29a71b43a724b4321efa29 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 13 Jan 2026 19:36:44 -0500 Subject: [PATCH 8/8] Fix empty zero-point tensor issue Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py index bca898b0c..cab92483c 100755 --- a/modelopt/onnx/quantization/fp8.py +++ b/modelopt/onnx/quantization/fp8.py @@ -102,7 +102,7 @@ def _convert(node: onnx.NodeProto): ) zero_point = initializers[zero_point_idx] dtype = onnx.helper.tensor_dtype_to_np_dtype(zero_point.data_type) - vals = np.array(zero_point.int32_data, dtype=dtype).tobytes() + vals = np.array(zero_point.int32_data, dtype=dtype).tobytes() or zero_point.raw_data np_zero_point = onnx.helper.make_tensor( zero_point_name, onnx.TensorProto.FLOAT8E4M3FN, zero_point.dims, vals, raw=True