[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion #18173
Open: kimm240 wants to merge 2 commits into apache:main from kimm240:conv2d-reshape-add-relu-fusion-lint (base: main)
+193 −0
python/tvm/relax/transform/fuse_conv2d_reshape_add_relu.py (115 additions, 0 deletions)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module provides a TVM Relax pass for fusing the Conv2d-Reshape-Add-ReLU pattern."""

import tvm
from tvm import IRModule, relax
from tvm.relax.dpl.pattern import is_op, wildcard


# @tvm.transform.module_pass turns this class into an IRModule-level pass.
# opt_level=0 means the pass can run at any optimization level;
# name="FuseConv2dReshapeAddRelu" is the descriptive name of the pass.
@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
class FuseConv2dReshapeAddRelu:
    """A Relax pass that fuses the Conv2d-Reshape-Add-ReLU pattern into a composite function."""

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Transform the input IRModule by applying the Conv2d-Reshape-Add-ReLU fusion.

        Parameters
        ----------
        mod : IRModule
            The input IRModule to be transformed.
        _ctx : tvm.transform.PassContext
            The pass context (unused in this pass but required by the interface).

        Returns
        -------
        IRModule
            The transformed IRModule with the fused pattern.
        """
        # FuseOpsByPattern identifies the given operator patterns in the IRModule
        # and fuses each match into a single composite function. The patterns are
        # given as tuples: ("composite_function_name", pattern_root, annotations,
        # check_function). The name "dnnl.conv2d_reshape_add_relu" marks the fused
        # function as suitable for the DNNL backend.
        mod = relax.transform.FuseOpsByPattern(
            [("dnnl.conv2d_reshape_add_relu", *_conv2d_reshape_add_relu_pattern())],
            # bind_constants=False means constants in the pattern (such as shapes)
            # are not treated as part of the pattern to be matched, allowing for
            # more flexibility.
            bind_constants=False,
        )(mod)
        # Return the transformed IRModule.
        return mod


def _conv2d_reshape_add_relu_pattern():
    """Define the Conv2d-Reshape-Add-ReLU pattern using TVM's declarative pattern language (DPL)."""
    # Wildcard placeholders match any Relax expression.
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    shape = wildcard()  # Target shape of the reshape operation.

    # The sequence of operations in the pattern:
    # 1. Convolution (relax.nn.conv2d). varg_default_wildcard=True means any
    #    remaining arguments (such as strides and padding) are also matched by
    #    wildcards, making the pattern more general.
    conv_out = is_op("relax.nn.conv2d")(data, weight, varg_default_wildcard=True)
    # 2. Reshape (relax.reshape) applied to the bias tensor with any target shape.
    reshaped_bias = is_op("relax.reshape")(bias, shape)
    # 3. Addition (relax.add) of the convolution output and the reshaped bias.
    add_out = is_op("relax.add")(conv_out, reshaped_bias)
    # 4. ReLU (relax.nn.relu) applied to the output of the addition.
    relu_out = is_op("relax.nn.relu")(add_out)

    # Annotations map internal names to the matched Relax expressions.
    # They are useful for debugging and for custom check functions.
    annotations = {
        "conv_out": conv_out,
        "reshaped_bias": reshaped_bias,
        "add_out": add_out,
        "relu_out": relu_out,
    }

    # The check function is executed after a potential match is found and can add
    # further conditions for the fusion; returning True accepts every structural match.
    def _check(_context):
        """A check function for the pattern (currently always returns True)."""
        return True

    # Return the root of the pattern (the final ReLU), the annotations, and the check function.
    return relu_out, annotations, _check
```
New test file (75 additions, 0 deletions):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relax
from tvm.relax.transform import FuseConv2dReshapeAddRelu
from tvm.script import relax as R


def test_transform_pass():
    # Define the initial IRModule.
    @tvm.script.ir_module
    class TestModule:
        @R.function
        def main(
            data: R.Tensor((1, 3, 224, 224), dtype="float32"),
            weight: R.Tensor((64, 3, 3, 3), dtype="float32"),
            bias: R.Tensor((64,), dtype="float32"),
        ):
            with R.dataflow():
                conv_out = R.nn.conv2d(data, weight)
                bias_reshaped = R.reshape(bias, [1, 64, 1, 1])
                bias_add = R.add(conv_out, bias_reshaped)
                relu_out = R.nn.relu(bias_add)
                R.output(relu_out)
            return relu_out

    print(TestModule)

    # Step 1: Apply the FuseConv2dReshapeAddRelu pass.
    # This pass identifies the fusion pattern (conv2d-reshape-add-relu) and
    # encapsulates it into a new Relax function with the "Composite" attribute.
    fused_mod = FuseConv2dReshapeAddRelu()(TestModule)
    print("=== IR after Step 1 (FuseConv2dReshapeAddRelu) ===")
    print(fused_mod)

    # Step 2: Apply Sequential passes including MergeCompositeFunctions.
    # MergeCompositeFunctions takes functions marked with "Composite" and
    # transforms them into functions with a "Codegen" attribute, indicating
    # they should be offloaded to an external backend (e.g. DNNL).
    final_mod = tvm.ir.transform.Sequential(
        [
            relax.transform.FuseConv2dReshapeAddRelu(),
            relax.transform.MergeCompositeFunctions(),
        ]
    )(TestModule)

    print("=== IR after Final Fusion (Sequential Passes) ===")
    print(final_mod)

    # Check the attributes of the functions in the final module to confirm
    # that the "Codegen" attribute was added to the fused function.
    print("=== Function Attributes in Final IR ===")
    for name, func in final_mod.functions.items():
        if hasattr(func, "attrs") and func.attrs:
            print(f"Function {name} attributes:", func.attrs)


if __name__ == "__main__":
    test_transform_pass()
```
I am wondering if transform.FuseOps will fuse them; I guess it might work.
@yongwww
Excellent point! However, after checking the actual implementation, I've confirmed that the generic FuseOps cannot handle this specific pattern.
Summary
The generic relax.transform.FuseOps pass is currently unable to fuse the common conv2d + bias + activation pattern when the model is imported from PyTorch. The root cause is that the PyTorch frontend generates a conv2d -> reshape -> add sequence for the bias term, which the existing pattern matcher used by FuseOps does not recognize. This leaves a critical, common pattern unoptimized.
The Pattern Generated by the PyTorch Frontend
When handling a torch.nn.Conv2d layer with bias=True, the PyTorch frontend consistently generates a reshape + add pattern for the bias. This is not specific to Conv2d; it is standard behavior for other convolution types as well (a sketch of such an import follows the list below):
Conv1d: See test_frontend_from_exported_program.py:1752-1753
Conv2d: See test_frontend_from_fx.py:269-270
Conv3d: See test_frontend_from_exported_program.py:3822-3823
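For reference, here is a minimal sketch of how this sequence arises on export, assuming the tvm.relax.frontend.torch.from_exported_program entry point available in recent TVM; the model, shapes, and printed IR below are illustrative and not taken from the linked tests.

```python
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class ConvBiasRelu(torch.nn.Module):
    """A Conv2d layer with bias=True followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


exported = export(ConvBiasRelu().eval(), (torch.randn(1, 3, 224, 224),))
mod = from_exported_program(exported)
# Printing the imported module shows the bias handled as a separate
# reshape + add, i.e. conv2d -> reshape -> add -> relu.
mod.show()
```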
Limitation of TVM's Current Pattern Matching
The pattern designed to fuse bias and activation, make_fused_bias_activation_pattern, is defined in pattern.py:1179-1181. This function is currently implemented to match only a simple relax.add operation following the convolution. It cannot see past the reshape operation inserted by the frontend, thus failing to match the sequence.
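For context, the helper's structure is roughly the following. This is a simplified paraphrase under the assumption that its signature is make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None); the file cited above remains the authoritative source.

```python
from tvm.relax.dpl.pattern import is_op, wildcard


def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None):
    # Simplified paraphrase: the bias feeds relax.add directly, so the
    # relax.reshape inserted by the PyTorch frontend is not part of the pattern.
    lhs = wildcard()
    rhs = wildcard()
    out = is_op(op_name)(lhs, rhs)  # e.g. "relax.nn.conv2d"
    annotations = {"root": out, "lhs": lhs, "rhs": rhs}

    if with_bias:
        bias = wildcard()
        out = is_op("relax.add")(out, bias)  # plain add only, no reshape in between
        annotations["bias"] = bias

    if activation:
        out = is_op(activation)(out)  # e.g. "relax.nn.relu"

    return out, annotations
```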
Proof by Code: A Reproducible Example
The following test case demonstrates that FuseOps fails to fuse this pattern.
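The original test case is not reproduced in this view; the following is a minimal sketch along the lines described, mirroring the shapes used in the new test file.

```python
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class ConvBiasReluModule:
    @R.function
    def main(
        data: R.Tensor((1, 3, 224, 224), dtype="float32"),
        weight: R.Tensor((64, 3, 3, 3), dtype="float32"),
        bias: R.Tensor((64,), dtype="float32"),
    ):
        with R.dataflow():
            conv_out = R.nn.conv2d(data, weight)
            bias_reshaped = R.reshape(bias, [1, 64, 1, 1])
            bias_add = R.add(conv_out, bias_reshaped)
            relu_out = R.nn.relu(bias_add)
            R.output(relu_out)
        return relu_out


# Apply the generic fusion pass and compare the printed IR; per the report
# above, the four calls remain separate operations in the output.
fused = relax.transform.FuseOps()(ConvBiasReluModule)
print(ConvBiasReluModule)
print(fused)
```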
Execution Results
Converted IR (Before FuseOps): A sequence of four separate operations is generated: conv2d → reshape → add → relu.
IR After FuseOps: The IR remains completely unchanged, confirming that the fusion failed.
This failure is a direct result of the pattern in pattern.py:1179-1181 matching only relax.add and not the reshape + add sequence.
Conclusion and Proposal
The generic FuseOps pass cannot handle this frontend-specific pattern, leaving a common PyTorch model structure (conv2d + bias + relu) unoptimized.
Therefore, a specialized pass like FuseConv2dReshapeAddRelu is essential to correctly identify and fuse this pattern. This targeted pass is necessary to bridge the gap between the PyTorch frontend's IR generation and TVM's optimization capabilities, unlocking performance for a wide range of models.
Maybe we could extend FuseOps to handle this - that way, other cases could benefit as well.
Just a moment, I'll get to it.