diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 1faef9cceb05..b1be3778fa39 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -35,10 +35,15 @@ def _with_bias_activation_pattern(
     annotations: Dict[str, DFPattern],
     with_bias: bool = False,
     activation: str = None,
+    allow_reshape: bool = False,
 ) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
     if with_bias:
         annotations["bias"] = bias = wildcard()
-        out = is_op("relax.add")(out, bias)
+        if allow_reshape:
+            reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
+            out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
+        else:
+            out = is_op("relax.add")(out, bias)
 
     if activation:
         out = is_op(activation)(out)
@@ -50,6 +55,7 @@ def make_fused_bias_activation_pattern(
     op_name: str,
     with_bias: bool = False,
     activation: str = None,
+    allow_reshape: bool = False,
 ) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
     """
     A simple utility to create patterns for an operation fused with bias addition and activation.
@@ -80,7 +86,7 @@ def make_fused_bias_activation_pattern(
     out = is_op(op_name)(lhs, rhs)
     annotations = {"lhs": lhs, "rhs": rhs, "root": out}
 
-    return _with_bias_activation_pattern(out, annotations, with_bias, activation)
+    return _with_bias_activation_pattern(out, annotations, with_bias, activation, allow_reshape)
 
 
 def make_residual_block_pattern(
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 633c2c6790da..e621e5b596b6 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1119,7 +1119,9 @@ def _only_used_by(
     return ffi.only_used_by(lhs, rhs, index)  # type: ignore
 
 
-def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None):
+def make_fused_bias_activation_pattern(
+    op_name, with_bias=False, activation=None, allow_reshape=False
+):
     """
     A simple utility to create patterns for an operation fused with bias addition and activation.
 
@@ -1134,6 +1136,9 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
     activation: str
         The name of an activation Relax op, such as "relax.nn.relu"
 
+    allow_reshape: bool
+        Whether to allow a reshape of the bias before the bias addition (for the PyTorch frontend)
+
     Returns
     -------
     pattern: DFPattern
@@ -1145,7 +1150,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
 
     if with_bias:
         bias = wildcard()
-        out = is_op("relax.add")(out, bias)
+        if allow_reshape:
+            reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
+            out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
+        else:
+            out = is_op("relax.add")(out, bias)
 
     if activation:
         return is_op(activation)(out)
diff --git a/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
new file mode 100644
index 000000000000..4b45c2bf534c
--- /dev/null
+++ b/tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import torch
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.torch import from_fx
+from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
+from tvm.script import ir as I
+from tvm.script import relax as R
+
+
+def test_conv2d_bias_relu_fusion():
+    """Test PyTorch conv2d + bias + relu fusion with reshape pattern"""
+
+    class Conv2dBiasRelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(self.conv(x))
+
+    # Convert PyTorch model to Relax IR
+    model = Conv2dBiasRelu()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Apply fusion with modified pattern
+    patterns = [
+        (
+            "conv2d_bias_activation_with_reshape",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+            ),
+        )
+    ]
+
+    fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)
+
+    # Verify fusion occurred
+    fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]
+
+    assert len(fused_functions) == 1, "Expected exactly one fused function"
+
+    # Verify the fused function contains all operations
+    fused_func = fused_mod[fused_functions[0]]
+    assert hasattr(fused_func, "attrs"), "Fused function should have attributes"
+    assert "Composite" in fused_func.attrs, "Fused function should have Composite attribute"
+
+
+def test_conv2d_bias_relu_fusion_comparison():
+    """Compare fusion with and without allow_reshape option"""
+
+    class Conv2dBiasRelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(self.conv(x))
+
+    model = Conv2dBiasRelu()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Test with allow_reshape=False
+    old_patterns = [
+        (
+            "conv2d_bias_activation_old",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=False
+            ),
+        )
+    ]
+
+    old_fused_mod = relax.transform.FuseOpsByPattern(old_patterns, bind_constants=False)(mod)
+
+    # Test with allow_reshape=True
+    new_patterns = [
+        (
+            "conv2d_bias_activation_new",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+            ),
+        )
+    ]
+
+    new_fused_mod = relax.transform.FuseOpsByPattern(new_patterns, bind_constants=False)(mod)
+
+    # Both should create fused functions
+    old_fused_functions = [name for name in old_fused_mod.functions.keys() if "fused" in str(name)]
+    new_fused_functions = [name for name in new_fused_mod.functions.keys() if "fused" in str(name)]
+
+    assert len(old_fused_functions) >= 1, "Old pattern should create at least one fused function"
+    assert len(new_fused_functions) >= 1, "New pattern should create at least one fused function"
+
+
+def test_conv2d_no_fusion_case():
+    """Test case where fusion should not occur"""
+
+    class Conv2dNoBias(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=False)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    model = Conv2dNoBias()
+    graph_model = torch.fx.symbolic_trace(model)
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
+
+    # Apply fusion pattern
+    patterns = [
+        (
+            "conv2d_bias_activation",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
+            ),
+        )
+    ]
+
+    fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)
+
+    # No fusion should occur
+    fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]
+
+    assert len(fused_functions) == 0, "No fusion should occur for conv2d without bias and relu"
+
+
+if __name__ == "__main__":
+    test_conv2d_bias_relu_fusion()
+    test_conv2d_bias_relu_fusion_comparison()
+    test_conv2d_no_fusion_case()
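
A minimal usage sketch mirroring the new test, since the PyTorch/FX frontend reshapes the 1-D bias to a broadcastable shape before the add, so the plain conv2d/add/relu pattern does not match and allow_reshape=True is needed; the composite name "conv2d_bias_relu_reshape" is illustrative, everything else uses only the APIs touched above:

    import torch
    from tvm import relax
    from tvm.relax.frontend.torch import from_fx
    from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern

    # Trace conv2d + bias + relu with FX and import into Relax; the importer
    # inserts a reshape of the bias before the add, which allow_reshape tolerates.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 6, 3, bias=True), torch.nn.ReLU())
    with torch.no_grad():
        mod = from_fx(torch.fx.symbolic_trace(model), [([1, 3, 10, 10], "float32")])

    # Reshape-tolerant pattern: the whole conv2d -> add(reshape(bias)) -> relu
    # chain is offered to FuseOpsByPattern as a single composite.
    pattern = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
    )
    mod = relax.transform.FuseOpsByPattern(
        [("conv2d_bias_relu_reshape", pattern)], bind_constants=False
    )(mod)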