Commit 840e7e9

pytorchbot, Gasoonjia, and JacobSzwejbka authored

decompose conv1d to conv2d for cuda backend (#15092)

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #15017 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/58/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/58/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/58/orig
Differential Revision: [D84296877](https://our.internmc.facebook.com/intern/diff/D84296877/)
@diff-train-skip-merge

Co-authored-by: gasoonjia <[email protected]>
Co-authored-by: Jacob Szwejbka <[email protected]>
1 parent 5ed5097 commit 840e7e9
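
For context, the decomposition wired up by this commit is the stock conv1d_to_conv2d entry from torch._inductor.decomposition, applied through ExportedProgram.run_decompositions. Below is a minimal standalone sketch of the same idea, illustrative only and not the backend code itself; it assumes a PyTorch build that ships this decomposition (as imported in the diff).

# Minimal sketch: apply the same conv1d -> conv2d decomposition that the CUDA
# backend now runs during preprocess(). Standalone example, not backend code.
import torch
from torch._inductor.decomposition import conv1d_to_conv2d


class TinyConv1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


ep = torch.export.export(TinyConv1d().eval(), (torch.randn(1, 3, 10),))

# Same table the backend builds: route aten.conv1d through the 2D convolution path.
cuda_decomposition_table = {torch.ops.aten.conv1d.default: conv1d_to_conv2d}
ep = ep.run_decompositions(cuda_decomposition_table)

# After decomposition no aten.conv1d node should remain; the 1D convolution is
# re-expressed through the 2D path (conceptually unsqueeze -> conv2d -> squeeze).
ops = {n.target for n in ep.graph.nodes if n.op == "call_function"}
assert torch.ops.aten.conv1d.default not in ops
print(sorted(str(op) for op in ops))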

File tree

6 files changed (+57 −2 lines):
.github/workflows/cuda.yml
backends/cuda/cuda_backend.py
backends/cuda/tests/test_cuda_export.py
examples/models/__init__.py
examples/models/toy_model/__init__.py
examples/models/toy_model/model.py


.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        model: [linear, add, add_mul, resnet18]
+        model: [linear, add, add_mul, resnet18, conv1d]
     with:
       timeout: 90
       runner: linux.g5.4xlarge.nvidia.gpu

backends/cuda/cuda_backend.py

Lines changed: 9 additions & 0 deletions
@@ -24,9 +24,14 @@
 )
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
+from torch._inductor.decomposition import conv1d_to_conv2d
 from torch.export.passes import move_to_device_pass
 from torch.nn.attention import SDPBackend

+cuda_decomposition_table = {
+    torch.ops.aten.conv1d.default: conv1d_to_conv2d,
+}
+
 # exist fallback operators in et namespace;
 supported_fallback_kernels: Dict[str, Any] = {}

@@ -119,6 +124,10 @@ def preprocess(
         # replace slice_copy with slice
         ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)

+        cuda_edge_program = cuda_edge_program.run_decompositions(
+            cuda_decomposition_table
+        )
+
         edge_program_module = cuda_edge_program.module()

         # Grab all input placeholders from the graph

backends/cuda/tests/test_cuda_export.py

Lines changed: 19 additions & 0 deletions
@@ -251,3 +251,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.assertIsNotNone(
             edge_program_manager, "Mathematical operations export failed"
         )
+
+    def test_conv1d(self):
+        """Test CUDA export for 1D convolution."""
+
+        class Conv1dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.conv(x)
+
+        module = Conv1dModule()
+        module.eval()
+        inputs = (torch.randn(1, 3, 10),)
+
+        # Test export
+        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
+        self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")

examples/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ class Model(str, Enum):
     Add = "add"
     AddMul = "add_mul"
     Softmax = "softmax"
+    Conv1d = "conv1d"
     Dl3 = "dl3"
     Edsr = "edsr"
     EmformerTranscribe = "emformer_transcribe"
@@ -59,6 +60,7 @@ def __str__(self) -> str:
     str(Model.Add): ("toy_model", "AddModule"),
     str(Model.AddMul): ("toy_model", "AddMulModule"),
     str(Model.Softmax): ("toy_model", "SoftmaxModule"),
+    str(Model.Conv1d): ("toy_model", "Conv1dModule"),
     str(Model.Dl3): ("deeplab_v3", "DeepLabV3ResNet50Model"),
     str(Model.Edsr): ("edsr", "EdsrModel"),
     str(Model.EmformerTranscribe): ("emformer_rnnt", "EmformerRnntTranscriberModel"),

examples/models/toy_model/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -4,11 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from .model import AddModule, AddMulModule, LinearModule, MulModule, SoftmaxModule
+from .model import (
+    AddModule,
+    AddMulModule,
+    Conv1dModule,
+    LinearModule,
+    MulModule,
+    SoftmaxModule,
+)

 __all__ = [
     AddModule,
     AddMulModule,
+    Conv1dModule,
     LinearModule,
     MulModule,
     SoftmaxModule,

examples/models/toy_model/model.py

Lines changed: 17 additions & 0 deletions
@@ -88,3 +88,20 @@ def get_eager_model(self) -> torch.nn.Module:

     def get_example_inputs(self):
         return (torch.ones(2, 2),)
+
+
+class Conv1dModule(torch.nn.Module, EagerModelBase):
+    def __init__(self):
+        super().__init__()
+        self.conv1d = torch.nn.Conv1d(
+            in_channels=3, out_channels=16, kernel_size=3, padding=1
+        )
+
+    def forward(self, x):
+        return self.conv1d(x)
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return self
+
+    def get_example_inputs(self):
+        return (torch.randn(1, 3, 10),)
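
A quick sanity check of the new toy-model entry (a hypothetical usage sketch; the import path is assumed from the examples/models/toy_model/__init__.py change above, and is not part of the diff itself):

# Hypothetical usage sketch for the new Conv1dModule toy model.
import torch
from executorch.examples.models.toy_model import Conv1dModule  # assumed import path

wrapper = Conv1dModule()
model = wrapper.get_eager_model().eval()
(x,) = wrapper.get_example_inputs()  # shape (1, 3, 10)

with torch.no_grad():
    y = model(x)

# Conv1d(3 -> 16, kernel_size=3, padding=1) preserves the length dimension.
print(y.shape)  # expected: torch.Size([1, 16, 10])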
