189 changes: 106 additions & 83 deletions unstable_source/inductor_cpp_wrapper_tutorial.rst
TorchInductor C++ Wrapper Tutorial
==============================================================

**Author**: `Chunyuan Wu <https://github.com/chunyuan-w>`_, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__
Introduction
------------

In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper
code that manages memory allocation and kernel invocation. This design provides
flexibility and ease of debugging, but the interpreted nature of Python
introduces runtime overhead in performance-sensitive environments.

To address this limitation, TorchInductor includes a specialized mode that
generates **C++ wrapper code** in place of the Python wrapper, enabling faster
execution with minimal Python involvement.


Enabling the C++ wrapper mode
-----------------------------

To enable the C++ wrapper mode in TorchInductor, add the following configuration to your code:

.. code:: python

    import torch._inductor.config as config
    config.cpp_wrapper = True

This can speed up your models by reducing the Python overhead of the TorchInductor wrapper.
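
If you prefer not to change the global config, the same flag can also be passed per compilation through the ``options`` argument of ``torch.compile``. The following is a minimal sketch under the assumption that ``options`` keys are forwarded to the corresponding Inductor config entries:

.. code:: python

    import torch

    # Assumption: the "cpp_wrapper" key in options maps to
    # torch._inductor.config.cpp_wrapper for the Inductor backend.
    opt_fn = torch.compile(lambda a, b: (a + b).sum(),
                           options={"cpp_wrapper": True})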


Example code
------------

We will use the following model code as an example:

.. code:: python

    import torch
    import torch._inductor.config as config

    config.cpp_wrapper = True

    def fn(x, y):
        return (x + y).sum()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(128, 128, device=device)
    y = torch.randn(128, 128, device=device)

    opt_fn = torch.compile(fn)
    result = opt_fn(x, y)
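
To see the wrapper code that TorchInductor generates for this example, you can enable the ``output_code`` logging artifact before calling the compiled function. This is a minimal sketch; it assumes ``torch._logging.set_logs`` supports the ``output_code`` artifact, which is equivalent to running with ``TORCH_LOGS="output_code"`` set in the environment:

.. code:: python

    import torch

    # Print the generated wrapper/kernel source to the log when compilation runs.
    torch._logging.set_logs(output_code=True)

    compiled = torch.compile(lambda a, b: (a + b).sum())
    compiled(torch.randn(128, 128), torch.randn(128, 128))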


**For CPU**

The main part of the TorchInductor-generated code with the default Python wrapper will look like this:

.. code:: python

    class Runner:
        def __init__(self, partitions):
            self.partitions = partitions

        def call(self, args):
            arg0_1, arg1_1 = args
            args.clear()
            assert_size_stride(arg0_1, (128, 128), (128, 1))
            assert_size_stride(arg1_1, (128, 128), (128, 1))
            buf0 = empty_strided_cpu((), (), torch.float32)
            cpp_fused_add_sum_0(arg0_1, arg1_1, buf0)
            del arg0_1
            del arg1_1
            return (buf0, )

By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function
``inductor_entry_impl``:

.. code:: python

    cpp_wrapper_src = (
    r'''
    #include <torch/csrc/inductor/cpp_wrapper/cpu.h>
    extern "C" void cpp_fused_add_sum_0(const float* in_ptr0,
                                        const float* in_ptr1,
                                        float* out_ptr0);
    CACHE_TORCH_DTYPE(float32);
    CACHE_TORCH_DEVICE(cpu);

    void inductor_entry_impl(
        AtenTensorHandle*
            input_handles, // array of input AtenTensorHandle; handles
                           // are stolen; the array itself is borrowed
        AtenTensorHandle*
            output_handles // array for writing output AtenTensorHandle; handles
                           // will be stolen by the caller; the array itself is
                           // borrowed)
    ) {
        py::gil_scoped_release_simple release;

        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
        auto arg0_1 = std::move(inputs[0]);
        auto arg1_1 = std::move(inputs[1]);
        static constexpr int64_t *int_array_0=nullptr;
        AtenTensorHandle buf0_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
        RAIIAtenTensorHandle buf0(buf0_handle);
        cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()));
        arg0_1.reset();
        arg1_1.reset();
        output_handles[0] = buf0.release();
    } // inductor_entry_impl
    ...
    '''
    )

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cpu",
        num_outputs=1,
        kernel_code=None,
    )

    call = _wrap_func(inductor_entry)
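
The practical effect of removing the Python wrapper overhead can be estimated with a simple timing comparison. The sketch below is illustrative only: it assumes that flipping ``config.cpp_wrapper`` between compilations (after ``torch._dynamo.reset()``) is enough to trigger recompilation, and the measured speedup depends heavily on the model size and hardware.

.. code:: python

    import time

    import torch
    import torch._dynamo
    import torch._inductor.config as config

    def fn(x, y):
        return (x + y).sum()

    def bench(use_cpp_wrapper, iters=1000):
        torch._dynamo.reset()              # drop previously compiled artifacts
        config.cpp_wrapper = use_cpp_wrapper
        compiled = torch.compile(fn)
        x = torch.randn(128, 128)
        y = torch.randn(128, 128)
        compiled(x, y)                     # warm-up call triggers compilation
        start = time.perf_counter()
        for _ in range(iters):
            compiled(x, y)
        return time.perf_counter() - start

    print("Python wrapper:", bench(False))
    print("C++ wrapper:   ", bench(True))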

**For GPU**

With the C++ wrapper turned on, the equivalent C++ wrapper code below will be generated:

.. code:: python

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cuda",
        num_outputs=1,
        kernel_code=None,
    )

    def _wrap_func(f):
        def g(args):
            input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
            input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)

            args.clear()
            del input_tensors

            output_handles = f(input_handles)
            output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
            return output_tensors

        return g

    call = _wrap_func(inductor_entry)
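
As a quick sanity check on GPU, the result produced through the C++ wrapper path can be compared against eager execution. This is a minimal sketch that reuses ``fn`` and ``opt_fn`` from the example above and assumes a CUDA device is available:

.. code:: python

    if torch.cuda.is_available():
        a = torch.randn(128, 128, device="cuda")
        b = torch.randn(128, 128, device="cuda")
        # eager reference vs. the compiled function using the C++ wrapper
        torch.testing.assert_close(opt_fn(a, b), fn(a, b))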


Conclusion
------------

This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed
to improve model performance with minimal code modification. We described the
motivation for this feature, detailed the experimental API used to enable it,
and compared the generated outputs of the default Python wrapper and the new
C++ wrapper on both CPU and GPU backends to illustrate their distinctions.
Reviewer comment (Contributor): Maybe add links to some TorchInductor documentation or related tutorials.