Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/tvm/relax/backend/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
wildcard,
GlobalVarPattern,
TuplePattern,
# make_fused_bias_activation_pattern,
)


Expand All @@ -35,10 +36,15 @@ def _with_bias_activation_pattern(
annotations: Dict[str, DFPattern],
with_bias: bool = False,
activation: str = None,
allow_reshape: bool = False,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
if with_bias:
annotations["bias"] = bias = wildcard()
out = is_op("relax.add")(out, bias)
if allow_reshape:
reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
else:
out = is_op("relax.add")(out, bias)

if activation:
out = is_op(activation)(out)
Expand All @@ -50,6 +56,7 @@ def make_fused_bias_activation_pattern(
op_name: str,
with_bias: bool = False,
activation: str = None,
allow_reshape: bool = False,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
"""
A simple utility to create patterns for an operation fused with bias addition and activation.
Expand Down Expand Up @@ -80,7 +87,7 @@ def make_fused_bias_activation_pattern(
out = is_op(op_name)(lhs, rhs)
annotations = {"lhs": lhs, "rhs": rhs, "root": out}

return _with_bias_activation_pattern(out, annotations, with_bias, activation)
return _with_bias_activation_pattern(out, annotations, with_bias, activation, allow_reshape)


def make_residual_block_pattern(
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,9 @@ def _only_used_by(
return ffi.only_used_by(lhs, rhs, index) # type: ignore


def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None):
def make_fused_bias_activation_pattern(
op_name, with_bias=False, activation=None, allow_reshape=False
):
"""
A simple utility to create patterns for an operation fused with bias addition and activation.

Expand All @@ -1134,6 +1136,9 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
activation: str
The name of an activation Relax op, such as "relax.nn.relu"

allow_reshape: bool
Whether to allow reshape operation before bias addition (for PyTorch frontend)

Returns
-------
pattern: DFPattern
Expand All @@ -1145,7 +1150,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None

if with_bias:
bias = wildcard()
out = is_op("relax.add")(out, bias)
if allow_reshape:
reshaped_bias = is_op("relax.reshape")(bias, wildcard(), varg_default_wildcard=True)
out = is_op("relax.add")(out, reshaped_bias, varg_default_wildcard=True)
else:
out = is_op("relax.add")(out, bias)

if activation:
return is_op(activation)(out)
Expand Down
161 changes: 161 additions & 0 deletions tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import torch
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_fx
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
from tvm.script import ir as I
from tvm.script import relax as R


def test_conv2d_bias_relu_fusion():
"""Test PyTorch conv2d + bias + relu fusion with reshape pattern"""

class Conv2dBiasRelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(self.conv(x))

# Convert PyTorch model to Relax IR
model = Conv2dBiasRelu()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]

with torch.no_grad():
mod = from_fx(graph_model, input_info)

# Apply fusion with modified pattern
patterns = [
(
"conv2d_bias_activation_with_reshape",
make_fused_bias_activation_pattern(
"relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
),
)
]

fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)

# Verify fusion occurred
fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]

assert len(fused_functions) == 1, "Expected exactly one fused function"

# Verify the fused function contains all operations
fused_func = fused_mod[fused_functions[0]]
assert hasattr(fused_func, "attrs"), "Fused function should have attributes"
assert "Composite" in fused_func.attrs, "Fused function should have Composite attribute"


def test_conv2d_bias_relu_fusion_comparison():
"""Compare fusion with and without allow_reshape option"""

class Conv2dBiasRelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(self.conv(x))

model = Conv2dBiasRelu()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]

with torch.no_grad():
mod = from_fx(graph_model, input_info)

# Test with allow_reshape=False
old_patterns = [
(
"conv2d_bias_activation_old",
make_fused_bias_activation_pattern(
"relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=False
),
)
]

old_fused_mod = relax.transform.FuseOpsByPattern(old_patterns, bind_constants=False)(mod)

# Test with allow_reshape=True
new_patterns = [
(
"conv2d_bias_activation_new",
make_fused_bias_activation_pattern(
"relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
),
)
]

new_fused_mod = relax.transform.FuseOpsByPattern(new_patterns, bind_constants=False)(mod)

# Both should create fused functions
old_fused_functions = [name for name in old_fused_mod.functions.keys() if "fused" in str(name)]
new_fused_functions = [name for name in new_fused_mod.functions.keys() if "fused" in str(name)]

assert len(old_fused_functions) >= 1, "Old pattern should create at least one fused function"
assert len(new_fused_functions) >= 1, "New pattern should create at least one fused function"


def test_conv2d_no_fusion_case():
"""Test case where fusion should not occur"""

class Conv2dNoBias(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 3, bias=False)

def forward(self, x):
return self.conv(x)

model = Conv2dNoBias()
graph_model = torch.fx.symbolic_trace(model)
input_info = [([1, 3, 10, 10], "float32")]

with torch.no_grad():
mod = from_fx(graph_model, input_info)

# Apply fusion pattern
patterns = [
(
"conv2d_bias_activation",
make_fused_bias_activation_pattern(
"relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
),
)
]

fused_mod = relax.transform.FuseOpsByPattern(patterns, bind_constants=False)(mod)

# No fusion should occur
fused_functions = [name for name in fused_mod.functions.keys() if "fused" in str(name)]

assert len(fused_functions) == 0, "No fusion should occur for conv2d without bias and relu"


if __name__ == "__main__":
test_conv2d_bias_relu_fusion()
test_conv2d_bias_relu_fusion_comparison()
test_conv2d_no_fusion_case()