
Commit 42f30cc

Gasoonjia authored and facebook-github-bot committed
cuda export supported (pytorch#14478)
Summary: This diff introduces the CUDA backend, which compiles the partitioned model graph to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators, libtorch-free. The compiled model can be executed on CUDA devices using the ExecuTorch runtime.

Reviewed By: angelayi, larryliu0820

Differential Revision: D82987410
1 parent 16ced4e commit 42f30cc
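
For orientation, a minimal end-to-end usage sketch of the new backend follows. It is not part of this commit: the toy module, file name, and empty compile-spec list are illustrative, and it assumes the standard ExecuTorch export flow (torch.export.export, to_edge, to_backend, to_executorch) together with the CudaPartitioner touched below.

    import torch

    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
    from executorch.exir import to_edge


    class SmallModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x @ x)


    model = SmallModel().eval()
    example_inputs = (torch.randn(8, 8),)

    # Export the model and convert it to an edge program.
    exported = torch.export.export(model, example_inputs)
    edge = to_edge(exported)

    # Delegate supported subgraphs to the CUDA backend; CudaBackend.preprocess then
    # AOTInductor-compiles the partitioned graph into a libtorch-free shared object.
    edge = edge.to_backend(CudaPartitioner([]))

    # Serialize to a .pte file; the compiled .so travels as the "aoti_cuda_blob" named data.
    executorch_program = edge.to_executorch()
    with open("small_model.pte", "wb") as f:
        f.write(executorch_program.buffer)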

File tree

6 files changed: +465 -4 lines changed


backends/cuda/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/_serialize:lib",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 runtime.python_library(
     name = "cuda_partitioner",
     srcs = [

backends/cuda/cuda_backend.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import os
+import typing
+
+from typing import Any, Dict, final, List, Optional, Set
+
+import torch
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_details import (
+    BackendDetails,
+    ExportedProgram,
+    PreprocessResult,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
+from torch.export.passes import move_to_device_pass
+
+
+# Fallback operators that already exist in the et namespace
+supported_fallback_kernels: Dict[str, Any] = {}
+
+# Fallback kernels that are required but not yet supported
+missing_fallback_kernels: Set[str] = set()
+
+
+# Context manager for a no-fallback guarantee: it collects the fallback kernels
+# generated during AOTI compile so the caller can raise if any are unsupported.
+@contextlib.contextmanager
+def collect_unsupported_fallback_kernels():
+    original_generate_c_shim_extern_kernel_call = (
+        CppWrapperCpu.generate_c_shim_extern_kernel_call
+    )
+    original_generate_fallback_kernel_with_runtime_lookup_aot = (
+        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
+    )
+
+    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
+        self,
+        kernel: str,
+        args: list[str],
+        device: str,
+        *,
+        debug_args: Optional[list[str]] = None,
+    ):
+        if kernel not in supported_fallback_kernels:
+            missing_fallback_kernels.add(kernel)
+
+        original_generate_c_shim_extern_kernel_call(
+            self, kernel, args, device, debug_args=debug_args
+        )
+
+    def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
+        self,
+        op_overload,
+        raw_args,
+        output_args,
+        raw_outputs,
+    ):
+        # Extract kernel name for collection
+        kernel_name = getattr(op_overload, "_name", str(op_overload))
+        if kernel_name not in supported_fallback_kernels:
+            missing_fallback_kernels.add(kernel_name)
+
+        original_generate_fallback_kernel_with_runtime_lookup_aot(
+            self, op_overload, raw_args, output_args, raw_outputs
+        )
+
+    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
+    )
+    CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+        generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
+    )
+    try:
+        yield
+    finally:
+        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+            original_generate_c_shim_extern_kernel_call
+        )
+        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+            original_generate_fallback_kernel_with_runtime_lookup_aot
+        )
+
+
+@final
+@experimental(
+    "This API and all of cuda backend related functionality are experimental."
+)
+class CudaBackend(BackendDetails):
+    """
+    CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate
+    optimized CUDA kernels for the model's operators, libtorch-free. The compiled model can be executed on CUDA devices
+    using the ExecuTorch runtime.
+    """
+
+    @staticmethod
+    def preprocess(
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        # Move the edge_program from CPU to CUDA for aoti compile
+        cuda_edge_program = move_to_device_pass(edge_program, "cuda")
+
+        edge_program_module = cuda_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = cuda_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in cuda_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Create pseudo user inputs using torch.randn and metadata from input placeholders
+        faked_user_inputs = []
+        for placeholder in user_input_placeholders:
+            if isinstance(placeholder, torch.Tensor):
+                # Generate a fake input with the same shape and dtype, on CUDA
+                fake_input = torch.randn(
+                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
+                )
+                faked_user_inputs.append(fake_input)
+
+        faked_user_inputs = tuple(faked_user_inputs)
+
+        options: dict[str, typing.Any] = {
+            # Embed CUDA kernel binaries directly into the compiled shared object
+            "aot_inductor.embed_kernel_binary": True,
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Package model constants and other generated files directly in the shared object (.so) file
+            "aot_inductor.package_constants_in_so": True,
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # Tune GEMM (General Matrix Multiply) operations with the TRITON backend only, to avoid operators from libtorch
+            "max_autotune_gemm_backends": "TRITON",
+            # Tune convolution operations with the TRITON backend only, to avoid operators from libtorch
+            "max_autotune_conv_backends": "TRITON",
+        }
+
+        with collect_unsupported_fallback_kernels():
+            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
+            if len(missing_fallback_kernels) > 0:
+                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+                raise RuntimeError(
+                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                    "Please add them to the AOTI backend."
+                )
+
+        # pyre-ignorep[6]: Incompatible parameter type
+        with open(so_path, "rb") as f:
+            so_data = f.read()
+
+        named_data_store = NamedDataStore()
+        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+
+        # Clean up the generated so file; it has been packaged into the NamedDataStore
+        # pyre-ignorep[6]: Incompatible parameter type
+        os.remove(so_path)
+
+        return PreprocessResult(
+            processed_bytes=b"",
+            debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
+        )
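
To make the core step of preprocess concrete, here is a small standalone sketch of the AOTInductor call it performs, reusing the same options dict from the diff above. It is illustrative only: the toy module and input shapes are made up, and it assumes a CUDA-capable machine with Triton available.

    import torch


    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.silu(x) * 2.0


    example_inputs = (torch.randn(4, 16, device="cuda"),)
    exported = torch.export.export(Toy().to("cuda").eval(), example_inputs)

    options = {
        "aot_inductor.embed_kernel_binary": True,
        "aot_inductor.link_libtorch": False,
        "aot_inductor.package_constants_in_so": True,
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
        "max_autotune_conv_backends": "TRITON",
    }

    # Returns the path of a self-contained shared object holding the compiled CUDA kernels;
    # CudaBackend.preprocess reads this file back and stores it as the "aoti_cuda_blob" named data.
    so_path = torch._inductor.aot_compile(exported.module(), example_inputs, options=options)
    print(so_path)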

backends/cuda/cuda_partitioner.py

Lines changed: 6 additions & 1 deletion
@@ -7,17 +7,22 @@
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
     Partitioner,
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir._warnings import experimental
 from torch.export.exported_program import ExportedProgram
 
 
 @final
+@experimental(
+    "This API and all of cuda backend related functionality are experimental."
+)
 class CudaPartitioner(Partitioner):
     """
     CUDA partitioner for AOTInductor backend integration.
@@ -31,7 +36,7 @@ class CudaPartitioner(Partitioner):
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)
+        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

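A note on the one-line change to __init__ above: to_backend resolves the delegate through the backend id string, which must match the name under which CudaBackend is registered (its class name), so deriving the id from CudaBackend.__name__ instead of a hard-coded string keeps the two in sync. A minimal check, assuming DelegationSpec exposes its first field as backend_id (per its definition in executorch.exir.backend.partitioner):

    from executorch.backends.cuda.cuda_backend import CudaBackend
    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner

    # No compile specs are needed for this sketch.
    partitioner = CudaPartitioner([])
    assert partitioner.delegation_spec.backend_id == CudaBackend.__name__ == "CudaBackend"
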
backends/cuda/tests/TARGETS

Lines changed: 21 additions & 0 deletions
@@ -1,8 +1,28 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu")
 
 oncall("executorch")
 
+python_unittest_remote_gpu(
+    name = "test_cuda_export",
+    srcs = [
+        "test_cuda_export.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cuda:cuda_backend",
+        "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+    keep_gpu_sections = True,
+)
+
 python_unittest(
     name = "test_cuda_partitioner",
     srcs = [
@@ -14,6 +34,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/backends/cuda:cuda_backend",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
     ],
