-
Notifications
You must be signed in to change notification settings - Fork 517
[Bug Fix] Fix padding when running in NHWC #9729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9729
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 PendingAs of commit d013b83 with merge base 65ebabb ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch.
ebc14c9
to
3492e53
Compare
3492e53
to
d013b83
Compare
@pytorchbot cherry-pick --onto release/0.6 -c critical |
### Summary There is a bug when there is a constant_pad between two convolutions. In order to minimize permutes associated with memory format changes, we sometimes compute ops in NHWC. This is the case for ConstantPad when it is between two convs: ``` a = conv(a) a = constant_pad(a, paddings=[1, 2, 3, 4]) a = conv(a) ``` in this case we need to make sure the paddings given to constant_pad are also permuted to nhwc. ### Test plan python install_executorch.py --editable python -m unittest backends.xnnpack.test.ops.test_static_constant_pad.TestStaticConstantPad.test_fp32_static_constant_pad_nhwc (cherry picked from commit 7d35c68)
Cherry picking #9729The cherry pick PR is at #9816 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
### Summary There is a bug when there is a constant_pad between two convolutions. In order to minimize permutes associated with memory format changes, we sometimes compute ops in NHWC. This is the case for ConstantPad when it is between two convs: ``` a = conv(a) a = constant_pad(a, paddings=[1, 2, 3, 4]) a = conv(a) ``` in this case we need to make sure the paddings given to constant_pad are also permuted to nhwc. ### Test plan python install_executorch.py --editable python -m unittest backends.xnnpack.test.ops.test_static_constant_pad.TestStaticConstantPad.test_fp32_static_constant_pad_nhwc
Summary
There is a bug when there is a constant_pad between two convolutions. In order to minimize permutes associated with memory format changes, we sometimes compute ops in NHWC. This is the case for ConstantPad when it is between two convs:
in this case we need to make sure the paddings given to constant_pad are also permuted to nhwc.
Test plan
python install_executorch.py --editable
python -m unittest backends.xnnpack.test.ops.test_static_constant_pad.TestStaticConstantPad.test_fp32_static_constant_pad_nhwc