
Commit f9ffb32

Gasoonjia authored and facebook-github-bot committed
cuda partitioner supported (#14477)
Summary: This diff introduces a partitioner for the CUDA delegate, driven by the AOTI (AOTInductor) library. The partitioner partitions the input into exactly one partitioned graph containing all operators from the input graph, and it keeps all operators (except the ops that cannot be handled by the aoti-cuda lib) away from ExecuTorch operator decomposition. Operators are instead decomposed in the CUDA backend using the aoti-cuda-specific decomposition table.

Reviewed By: larryliu0820

Differential Revision: D82987193
1 parent fd9f176 · commit f9ffb32
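For context, here is a minimal sketch of how this partitioner could be wired into an ExecuTorch lowering flow. It assumes the public to_edge_transform_and_lower entry point and reuses the CompileSpec key from the tests below; ToyModule is a hypothetical example module, and a registered CudaBackend is required for the lowering step to actually succeed:

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export

class ToyModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + 1.0)

# Export, then hand the whole graph to the CUDA/AOTInductor delegate as one partition.
ep = export(ToyModule(), (torch.randn(4),), strict=True)
partitioner = CudaPartitioner([CompileSpec("cuda_compile_options", b"")])
edge = to_edge_transform_and_lower(ep, partitioner=[partitioner])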

File tree

6 files changed: 266 additions, 0 deletions

backends/cuda/TARGETS

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
    name = "cuda_partitioner",
    srcs = [
        "cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend:utils",
    ],
)

backends/cuda/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

backends/cuda/cuda_partitioner.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from torch.export.exported_program import ExportedProgram


@final
class CudaPartitioner(Partitioner):
    """
    CUDA partitioner for AOTInductor backend integration.

    This partitioner creates a single partition containing all operators from the input graph.
    It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
    AOTInductor's CUDA-specific decomposition table.

    Only operators that cannot be handled by the aoti-cuda library will be excluded from
    the partition and fall back to ExecuTorch's default or custom handling.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
        """

        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph.nodes:
            if node.op != "call_function":
                continue
            tag = "tag0"
            node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return a list of operations that should not be decomposed, leaving them to the AOT compiler.
        Currently we skip ATen decomposition for all ops and let the CUDA backend handle them.
        """
        do_not_decompose = set()

        for node in ep.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, torch._ops.OpOverload
            ):
                do_not_decompose.add(node.target)
        return list(do_not_decompose), None
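To make the tagging behavior concrete, here is a small sketch (not part of the diff) that mirrors what the tests below verify: every call_function node ends up with the single "tag0" tag, and ops_to_not_decompose returns every OpOverload in the graph so core ATen decomposition is skipped. The Sequential module here is an arbitrary example:

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export

ep = export(
    torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()),
    (torch.randn(2, 8),),
    strict=True,
)
partitioner = CudaPartitioner([CompileSpec("cuda_compile_options", b"")])
result = partitioner.partition(ep)

# All call_function nodes share one delegation tag.
tags = {
    node.meta["delegation_tag"]
    for node in result.tagged_exported_program.graph.nodes
    if node.op == "call_function"
}
assert tags == {"tag0"}

# Every OpOverload in the graph is excluded from ATen decomposition.
ops, _filter = partitioner.ops_to_not_decompose(ep)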

backends/cuda/tests/TARGETS

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_unittest(
    name = "test_cuda_partitioner",
    srcs = [
        "test_cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cuda:cuda_partitioner",
        "//executorch/exir:lib",
        "//executorch/exir/backend:compile_spec_schema",
    ],
)

backends/cuda/tests/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
backends/cuda/tests/test_cuda_partitioner.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import PartitionResult
from torch.export import export


class TestCudaPartitioner(unittest.TestCase):
    """
    Test CUDA partitioner functionality.

    After CUDA partitioning, there should be exactly one partitioned graph that contains
    all operators from the input graph. This means all operators should be tagged with
    the same delegation tag, indicating they will all be executed by the CUDA backend.
    """

    def setUp(self):
        """Set up test environment."""
        # Skip tests if CUDA is not available
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def _get_partition_result(
        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
    ) -> PartitionResult:
        """Helper method to get partition result for a given module."""
        # Export the model
        exported_program = export(module, inputs, strict=True)

        # Create partitioner and compile specs
        compile_specs = [CompileSpec("cuda_compile_options", b"")]
        partitioner = CudaPartitioner(compile_specs)

        # Get partition result
        partition_result = partitioner.partition(exported_program)

        # Verify partition result structure
        self.assertIsNotNone(partition_result)
        self.assertTrue(hasattr(partition_result, "tagged_exported_program"))
        self.assertTrue(hasattr(partition_result, "partition_tags"))

        return partition_result

    def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool:
        """Check if the graph is fully partitioned (all operators have the same tag)."""
        tagged_nodes = []
        untagged_ops = []

        for node in partition_result.tagged_exported_program.graph.nodes:
            if node.op == "call_function":
                if hasattr(node, "meta") and "delegation_tag" in node.meta:
                    tagged_nodes.append(node)
                else:
                    untagged_ops.append(node)

        # Check if we have any tagged nodes
        if not tagged_nodes:
            return False

        # Check if all tagged nodes have the same tag
        first_tag = tagged_nodes[0].meta["delegation_tag"]
        all_same_tag = all(
            node.meta.get("delegation_tag") == first_tag for node in tagged_nodes
        )

        # Should have no untagged operations for full partitioning
        fully_partitioned = len(untagged_ops) == 0 and all_same_tag

        return fully_partitioned

    def test_simple_add_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Simple element-wise addition should result in a single graph with all ops tagged identically.
        """

        class AddModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        module = AddModule()
        inputs = (torch.randn(3, 4), torch.randn(3, 4))

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_conv2d_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Conv2D operation should result in a single graph with all ops tagged identically.
        """

        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x)

        module = Conv2dModule()
        inputs = (torch.randn(1, 3, 32, 32),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_linear_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Linear layer operation should result in a single graph with all ops tagged identically.
        """

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        module = LinearModule()
        inputs = (torch.randn(8, 128),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )
