From 5b4fa4a898ae23dbfeb065bbc73225410d1a6368 Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Mon, 13 Oct 2025 08:07:12 -0700
Subject: [PATCH] Update inductor_cpp_wrapper_tutorial.rst

Summary: Update the TorchInductor cpp-wrapper mode tutorial with up-to-date
information.
---
 .../inductor_cpp_wrapper_tutorial.rst | 189 ++++++++++--------
 1 file changed, 106 insertions(+), 83 deletions(-)

diff --git a/unstable_source/inductor_cpp_wrapper_tutorial.rst b/unstable_source/inductor_cpp_wrapper_tutorial.rst
index 4bcc9009075..9812e956fdb 100644
--- a/unstable_source/inductor_cpp_wrapper_tutorial.rst
+++ b/unstable_source/inductor_cpp_wrapper_tutorial.rst
@@ -1,4 +1,4 @@
-Inductor C++ Wrapper Tutorial
+TorchInductor C++ Wrapper Tutorial
 ==============================================================
 
 **Author**: `Chunyuan Wu `_, `Bin Bao `__, `Jiong Gong `__
@@ -10,85 +10,119 @@ Prerequisites:
 
 Introduction
 ------------
-Python, as the primary interface of PyTorch, is easy to use and efficient for development and debugging.
-The Inductor's default wrapper generates Python code to invoke generated kernels and external kernels.
-However, in deployments requiring high performance, Python, as an interpreted language, runs relatively slower compared to compiled languages.
+In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper
+code that manages memory allocation and kernel invocation. This design provides
+flexibility and ease of debugging, but the interpreted nature of Python
+introduces runtime overhead in performance-sensitive environments.
 
-We implemented an Inductor C++ wrapper by leveraging the PyTorch C++ APIs
-to generate pure C++ code that combines the generated and external kernels.
-This allows for the execution of each captured Dynamo graph in pure C++,
-thereby reducing the Python overhead within the graph.
+To address this limitation, TorchInductor includes a specialized mode that
+generates **C++ wrapper code** in place of the Python wrapper, enabling faster
+execution with minimal Python involvement.
 
-Enabling the API
-----------------
-This feature is still in prototype stage. To activate this feature, add the following to your code:
+Enabling the C++ wrapper mode
+-----------------------------
+To enable the C++ wrapper mode for TorchInductor, add the following config to your code:
 
 .. code:: python
 
     import torch._inductor.config as config
     config.cpp_wrapper = True
 
-This will speed up your models by reducing the Python overhead of the Inductor wrapper.
-
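+The same flag can also be set for a single compilation through the ``options``
+argument of ``torch.compile``. A minimal sketch (this assumes ``cpp_wrapper``
+is accepted as an Inductor option key like other Inductor configs, and ``fn``
+stands for any function you want to compile):
+
+.. code:: python
+
+    # Per-call alternative to mutating the global config; both forms
+    # toggle the same Inductor flag.
+    opt_fn = torch.compile(fn, options={"cpp_wrapper": True})
+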
 Example code
 ------------
-We will use the below frontend code as an example:
+We will use the following model code as an example:
 
 .. code:: python
 
     import torch
+    import torch._inductor.config as config
 
-    def fn(x):
-        return torch.tensor(list(range(2, 40, 2)), device=x.device) + x
+    config.cpp_wrapper = True
 
-    x = torch.randn(1)
-    opt_fn = torch.compile()(fn)
-    y = opt_fn(x)
+    def fn(x, y):
+        return (x + y).sum()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    x = torch.randn(128, 128, device=device)
+    y = torch.randn(128, 128, device=device)
+
+    opt_fn = torch.compile(fn)
+    result = opt_fn(x, y)
 
 **For CPU**
 
-The main part of Inductor-generated code with the default Python wrapper will look like this:
+The main part of TorchInductor-generated code with the default Python wrapper will look like this:
 
 .. code:: python
 
-    def call(args):
-        arg0_1, = args
-        args.clear()
-        assert_size_stride(arg0_1, (1, ), (1, ))
-        buf0 = empty_strided((19, ), (1, ), device='cpu', dtype=torch.float32)
-        cpp_fused_add_lift_fresh_0(c_void_p(constant0.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
-        del arg0_1
-        return (buf0, )
+    class Runner:
+        def __init__(self, partitions):
+            self.partitions = partitions
+
+        def call(self, args):
+            arg0_1, arg1_1 = args
+            args.clear()
+            assert_size_stride(arg0_1, (128, 128), (128, 1))
+            assert_size_stride(arg1_1, (128, 128), (128, 1))
+            buf0 = empty_strided_cpu((), (), torch.float32)
+            cpp_fused_add_sum_0(arg0_1, arg1_1, buf0)
+            del arg0_1
+            del arg1_1
+            return (buf0, )
 
 By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function
-``inductor_entry_cpp`` of the C++ extension ``module``:
+``inductor_entry_impl``:
 
 .. code:: python
 
-    std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
-        at::Tensor arg0_1 = args[0];
-        at::Tensor constant0 = args[1];
-        auto buf0 = at::empty_strided({19L, }, {1L, }, at::device(at::kCPU).dtype(at::kFloat));
-        cpp_fused_add_lift_fresh_0((long*)(constant0.data_ptr()), (float*)(arg0_1.data_ptr()), (float*)(buf0.data_ptr()));
+    cpp_wrapper_src = (
+    r'''
+    #include
+    extern "C" void cpp_fused_add_sum_0(const float* in_ptr0,
+                                        const float* in_ptr1,
+                                        float* out_ptr0);
+    CACHE_TORCH_DTYPE(float32);
+    CACHE_TORCH_DEVICE(cpu);
+
+    void inductor_entry_impl(
+        AtenTensorHandle*
+            input_handles, // array of input AtenTensorHandle; handles
+                           // are stolen; the array itself is borrowed
+        AtenTensorHandle*
+            output_handles // array for writing output AtenTensorHandle; handles
+                           // will be stolen by the caller; the array itself is
+                           // borrowed
+    ) {
+        py::gil_scoped_release_simple release;
+
+        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
+        auto arg0_1 = std::move(inputs[0]);
+        auto arg1_1 = std::move(inputs[1]);
+        static constexpr int64_t *int_array_0 = nullptr;
+        AtenTensorHandle buf0_handle;
+        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
+        RAIIAtenTensorHandle buf0(buf0_handle);
+        cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()));
         arg0_1.reset();
-        return {buf0};
-    }
-
-    module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'c2buojsvlqbywxe3itb43hldieh4jqulk72iswa2awalwev7hjn2', False)
-
-    def _wrap_func(f):
-        def g(args):
-            args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
-            constants_tensor = [constant0]
-            args_tensor.extend(constants_tensor)
-
-            return f(args_tensor)
-        return g
-    call = _wrap_func(module.inductor_entry_cpp)
+        arg1_1.reset();
+        output_handles[0] = buf0.release();
+    } // inductor_entry_impl
+    ...
+    '''
+    )
+
+    inductor_entry = CppWrapperCodeCache.load_pybinding(
+        argtypes=["std::vector<AtenTensorHandle>"],
+        main_code=cpp_wrapper_src,
+        device_type="cpu",
+        num_outputs=1,
+        kernel_code=None,
+    )
+
+    call = _wrap_func(inductor_entry)
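+
+To confirm which wrapper is in use for your own runs, you can dump the
+generated code with TorchInductor's output-code logging. A minimal sketch
+(equivalent to running the script with ``TORCH_LOGS="output_code"`` set in
+the environment):
+
+.. code:: python
+
+    import torch._logging
+
+    # Log the generated wrapper code for every compiled graph, i.e. the
+    # kind of listing shown above.
+    torch._logging.set_logs(output_code=True)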
 
 **For GPU**
 
 Based on the same example code, the generated code for GPU will look like this:
 
@@ -113,47 +147,36 @@
 With the C++ wrapper turned on, the below equivalent C++ code will be generated:
 
 .. code:: python
 
-    std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
-        at::Tensor arg0_1 = args[0];
-        at::Tensor constant0 = args[1];
-
-        at::cuda::CUDAGuard device_guard(0);
-        auto buf0 = at::empty_strided({19L, }, {1L, }, at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat));
-        // Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
-        if (triton_poi_fused_add_lift_fresh_0 == nullptr) {
-            triton_poi_fused_add_lift_fresh_0 = loadKernel("/tmp/torchinductor_user/mm/cmm6xjgijjffxjku4akv55eyzibirvw6bti6uqmfnruujm5cvvmw.cubin", "triton_poi_fused_add_lift_fresh_0_0d1d2d3");
-        }
-        CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(constant0.data_ptr());
-        CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());
-        CUdeviceptr var_2 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
-        auto var_3 = 19;
-        void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3};
-        cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0);
-        launchKernel(triton_poi_fused_add_lift_fresh_0, 1, 1, 1, 1, 0, kernel_args_var_0, stream0);
-        arg0_1.reset();
-        return {buf0};
-    }
-
-    module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'czbpeilh4qqmbyejdgsbpdfuk2ss5jigl2qjb7xs4gearrjvuwem', True)
+    inductor_entry = CppWrapperCodeCache.load_pybinding(
+        argtypes=["std::vector<AtenTensorHandle>"],
+        main_code=cpp_wrapper_src,
+        device_type="cuda",
+        num_outputs=1,
+        kernel_code=None,
+    )
 
     def _wrap_func(f):
         def g(args):
-            args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
-            constants_tensor = [constant0]
-            args_tensor.extend(constants_tensor)
+            input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
+            input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
+
+            args.clear()
+            del input_tensors
+
+            output_handles = f(input_handles)
+            output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
+            return output_tensors
 
-            return f(args_tensor)
         return g
-    call = _wrap_func(module.inductor_entry_cpp)
+
+    call = _wrap_func(inductor_entry)
 
 Conclusion
 ------------
-In this tutorial, we introduced a new C++ wrapper in TorchInductor to speed up your models with just two lines of code changes.
-We explained the motivation of this new feature and walked through the easy-to-use API to activate this experimental feature.
-Furthermore, we demonstrated the Inductor-generated code using the default Python wrapper and the new C++ wrapper on both CPU and GPU
-to visually showcase the difference between these two wrappers.
-
-This feature is still in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues `_.
+This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed
+to improve model performance with minimal code modification. We described the
+motivation for this feature, detailed the experimental API used to enable it,
+and compared the generated outputs of the default Python wrapper and the new
+C++ wrapper on both CPU and GPU backends to illustrate their distinctions.
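+
+If you want to gauge the wrapper overhead on your own workload, a simple timing
+comparison between the two modes can help. The sketch below is illustrative
+only: ``fn``, the tensor shapes, and the iteration count are placeholders, and
+it assumes ``cpp_wrapper`` can be passed through ``options`` as shown earlier.
+
+.. code:: python
+
+    import time
+
+    import torch
+    import torch._dynamo
+
+    def fn(x, y):
+        return (x + y).sum()
+
+    x = torch.randn(128, 128)
+    y = torch.randn(128, 128)
+
+    def bench(compiled_fn, iters=1000):
+        compiled_fn(x, y)  # warm-up call triggers compilation
+        start = time.perf_counter()
+        for _ in range(iters):
+            compiled_fn(x, y)
+        return (time.perf_counter() - start) / iters
+
+    t_python = bench(torch.compile(fn))
+    torch._dynamo.reset()  # drop compiled artifacts before switching modes
+    t_cpp = bench(torch.compile(fn, options={"cpp_wrapper": True}))
+    print(f"Python wrapper: {t_python:.2e} s/iter; C++ wrapper: {t_cpp:.2e} s/iter")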