diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
index 7d58c1c180822..a28d4a32778fc 100644
--- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
+++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
@@ -33,6 +33,16 @@ def fuse(
               |                                                 |
               +-------------------------------------------------+
 
+        Or, using Mul instead of Pow:
+
+          +----------------------+
+          |                      |
+          |                      v
+      [Root] --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
+                 (axis=2 or -1)  |    (in0=in1) (axis=2 or -1) (E-6 or E-12 or 0)  ^
+                                 |                                                 |
+                                 +-------------------------------------------------+
+
         It also handles cases of duplicated sub nodes exported from older version of PyTorch:
 
           +----------------------+
@@ -40,7 +50,7 @@ def fuse(
          |  +-------> Sub-----------------------------------------------+
          |  |                                                           |
          |  |                                                           v
-      [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
+      [Root] --> ReduceMean --> Sub --> (Pow or Mul) --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
          |                      ^
          |                      |
          +----------------------+
@@ -70,10 +80,9 @@ def fuse(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
-                (
-                    ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
-                    [1, 0, 0, 0, 0, 0],
-                ),
+                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
+                (["Sqrt", "Add", "ReduceMean", "Mul", "Sub"], [1, 0, 0, 0, 0]),
+                (["Sqrt", "Add", "ReduceMean", "Mul", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
@@ -90,8 +99,10 @@ def fuse(
            # Skip fusion since epsilon value is not expected.
            return
 
-        pow_node = parent_nodes[3]
-        if self.find_constant_input(pow_node, 2.0) != 1:
+        pow_or_mul_node = parent_nodes[3]
+        if pow_or_mul_node.op_type == "Pow" and self.find_constant_input(pow_or_mul_node, 2.0) != 1:
+            return
+        elif pow_or_mul_node.op_type == "Mul" and pow_or_mul_node.input[0] != pow_or_mul_node.input[1]:
            return
 
        mul_node = input_name_to_nodes[div_node.output[0]][0]
diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
index 7e0a8496b8bfb..70f8ca127e184 100644
--- a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
+++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
@@ -55,15 +55,26 @@ def build_model(self, shape, scale_val, bias_val):
        bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
        two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")
 
-        m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"])
-        m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"])
-        m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"])
-        m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"])
-        m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"])
-        m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"])
-        m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"])
-        m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"])
-        m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"])
+        m0_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m0_rm0_out"])
+        m0_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m0_rm0_out"], ["m0_sub_out"])
+        m0_pow_node = onnx.helper.make_node("Pow", ["m0_sub_out", "two_const"], ["m0_pow_out"])
+        m0_rm1_node = onnx.helper.make_node("ReduceMean", ["m0_pow_out", "axes_const"], ["m0_rm1_out"])
+        m0_add0_node = onnx.helper.make_node("Add", ["m0_rm1_out", "eps_const"], ["m0_add0_out"])
+        m0_sqrt_node = onnx.helper.make_node("Sqrt", ["m0_add0_out"], ["m0_sqrt_out"])
+        m0_div_node = onnx.helper.make_node("Div", ["m0_sub_out", "m0_sqrt_out"], ["m0_div_out"])
+        m0_mul_node = onnx.helper.make_node("Mul", ["m0_div_out", "scale_const"], ["m0_mul_out"])
+        m0_add1_node = onnx.helper.make_node("Add", ["m0_mul_out", "bias_const"], ["m0_add1_out"])
+
+        # Alternate ReduceMean sequence
+        m1_rm0_node = onnx.helper.make_node("ReduceMean", ["m0_add1_out", "axes_const"], ["m1_rm0_out"])
+        m1_sub_node = onnx.helper.make_node("Sub", ["m0_add1_out", "m1_rm0_out"], ["m1_sub_out"])
+        m1_mul0_node = onnx.helper.make_node("Mul", ["m1_sub_out", "m1_sub_out"], ["m1_mul0_out"])
+        m1_rm1_node = onnx.helper.make_node("ReduceMean", ["m1_mul0_out", "axes_const"], ["m1_rm1_out"])
+        m1_add0_node = onnx.helper.make_node("Add", ["m1_rm1_out", "eps_const"], ["m1_add0_out"])
+        m1_sqrt_node = onnx.helper.make_node("Sqrt", ["m1_add0_out"], ["m1_sqrt_out"])
+        m1_div_node = onnx.helper.make_node("Div", ["m1_sub_out", "m1_sqrt_out"], ["m1_div_out"])
+        m1_mul1_node = onnx.helper.make_node("Mul", ["m1_div_out", "scale_const"], ["m1_mul1_out"])
+        m1_add1_node = onnx.helper.make_node("Add", ["m1_mul1_out", "bias_const"], ["output"])
 
        graph = onnx.helper.make_graph(
            [
@@ -76,15 +87,24 @@ def build_model(self, shape, scale_val, bias_val):
                l2_clip_node,
                l2_expand_node,
                l2_div_node,
-                m_rm0_node,
-                m_sub_node,
-                m_pow_node,
-                m_rm1_node,
-                m_add0_node,
-                m_sqrt_node,
-                m_div_node,
-                m_mul_node,
-                m_add1_node,
+                m0_rm0_node,
+                m0_sub_node,
+                m0_pow_node,
+                m0_rm1_node,
+                m0_add0_node,
+                m0_sqrt_node,
+                m0_div_node,
+                m0_mul_node,
+                m0_add1_node,
+                m1_rm0_node,
+                m1_sub_node,
+                m1_mul0_node,
+                m1_rm1_node,
+                m1_add0_node,
+                m1_sqrt_node,
+                m1_div_node,
+                m1_mul1_node,
+                m1_add1_node,
            ],
            "qnn_f32_model",
            [root_inp],
@@ -119,8 +139,8 @@ def test_all_fusions(self):
 
        fused_model = onnx.load_model("model.qnn_pp.onnx")
 
-        # 3 fused Ops: Gelu, LpNorm, LayerNorm
-        self.assertEqual(len(fused_model.graph.node), 3)
+        # 4 fused Ops: Gelu, LpNorm, LayerNorm of two patterns
+        self.assertEqual(len(fused_model.graph.node), 4)
        expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
        for node in fused_model.graph.node:
            self.assertIn(node.op_type, expected_op_types)
@@ -167,8 +187,8 @@ def test_external_data(self):
 
        fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False)
 
-        # 3 fused Ops: Gelu, LpNorm, LayerNorm
-        self.assertEqual(len(fused_model.graph.node), 3)
+        # 4 fused Ops: Gelu, LpNorm, LayerNorm of two patterns
+        self.assertEqual(len(fused_model.graph.node), 4)
        expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
        for node in fused_model.graph.node:
            self.assertIn(node.op_type, expected_op_types)