diff --git a/.gitignore b/.gitignore index 23e35b9d..cb2c2aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -295,3 +295,6 @@ _docs/ .gdb_history build/ + +*.cubin +*.fatbin diff --git a/cmake/Utils/EmbedCubin.cmake b/cmake/Utils/EmbedCubin.cmake index 26af6d82..4ad0357f 100644 --- a/cmake/Utils/EmbedCubin.cmake +++ b/cmake/Utils/EmbedCubin.cmake @@ -15,215 +15,132 @@ # specific language governing permissions and limitations # under the License. +# If CMAKE_CUDA_RUNTIME_LIBRARY is not set, we default it to Shared. This prevents static linking of +# cudart which requires exact driver version match. +if (NOT DEFINED CMAKE_CUDA_RUNTIME_LIBRARY) + set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) + message(STATUS "CMAKE_CUDA_RUNTIME_LIBRARY not set, defaulting to Shared. " + "If you want to use driver API only, set CMAKE_CUDA_RUNTIME_LIBRARY to None." + ) +endif () + +set(OBJECT_COPY_UTIL "${CMAKE_CURRENT_LIST_DIR}/ObjectCopyUtil.cmake") + # ~~~ -# tvm_ffi_generate_cubin( -# OUTPUT -# SOURCE -# [ARCH ] -# [OPTIONS ...] -# [DEPENDS ...] -# ) +# add_tvm_ffi_cubin( CUDA ) # -# Compiles a CUDA source file to CUBIN format using nvcc. +# Creates an object library that compiles CUDA source to CUBIN format. +# This function uses CMake's native CUDA support and respects CMAKE_CUDA_ARCHITECTURES. +# This is a compatibility util for cmake < 3.27, user can create +# cmake target with `CUDA_CUBIN_COMPILATION` for cmake >= 3.27. # # Parameters: -# OUTPUT: Path to the output CUBIN file (e.g., kernel.cubin) -# SOURCE: Path to the CUDA source file (e.g., kernel.cu) -# ARCH: Target GPU architecture (default: native for auto-detection) -# Examples: sm_75, sm_80, sm_86, compute_80, native -# OPTIONS: Additional nvcc compiler options (e.g., -O3, --use_fast_math) -# DEPENDS: Optional additional dependencies -# -# The function will: -# 1. Find the CUDA compiler (nvcc) -# 2. Compile the SOURCE to CUBIN with specified architecture and options -# 3. Create the output CUBIN file +# target_name: Name of the object library target +# CUDA: One CUDA source file # # Example: -# tvm_ffi_generate_cubin( -# OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin -# SOURCE src/kernel.cu -# ARCH native -# OPTIONS -O3 --use_fast_math -# ) +# add_tvm_ffi_cubin(my_kernel_cubin CUDA kernel.cu) # ~~~ - -# cmake-lint: disable=C0111,C0103 -function (tvm_ffi_generate_cubin) - # Parse arguments - set(options "") - set(oneValueArgs OUTPUT SOURCE ARCH) - set(multiValueArgs OPTIONS DEPENDS) - cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - # Validate required arguments - if (NOT ARG_OUTPUT) - message(FATAL_ERROR "tvm_ffi_generate_cubin: OUTPUT is required") - endif () - if (NOT ARG_SOURCE) - message(FATAL_ERROR "tvm_ffi_generate_cubin: SOURCE is required") +function (add_tvm_ffi_cubin target_name) + cmake_parse_arguments(ARG "" "CUDA" "" ${ARGN}) + if (NOT ARG_CUDA) + message(FATAL_ERROR "add_tvm_ffi_cubin: CUDA source is required") endif () - # Default architecture to native if not specified - if (NOT ARG_ARCH) - set(ARG_ARCH "native") - endif () - - # Ensure CUDA compiler is available - if (NOT CMAKE_CUDA_COMPILER) - message( - FATAL_ERROR - "tvm_ffi_generate_cubin: CMAKE_CUDA_COMPILER not found. Enable CUDA language in project()." 
- ) - endif () + add_library(${target_name} OBJECT ${ARG_CUDA}) + target_compile_options(${target_name} PRIVATE $<$:--cubin>) - # Get absolute paths - get_filename_component(ARG_SOURCE_ABS "${ARG_SOURCE}" ABSOLUTE) - get_filename_component(ARG_OUTPUT_ABS "${ARG_OUTPUT}" ABSOLUTE) - - # Build nvcc command - add_custom_command( - OUTPUT "${ARG_OUTPUT_ABS}" - COMMAND ${CMAKE_CUDA_COMPILER} --cubin -arch=${ARG_ARCH} ${ARG_OPTIONS} "${ARG_SOURCE_ABS}" -o - "${ARG_OUTPUT_ABS}" - DEPENDS "${ARG_SOURCE_ABS}" ${ARG_DEPENDS} - COMMENT "Compiling ${ARG_SOURCE} to CUBIN (arch: ${ARG_ARCH})" + add_custom_target( + ${target_name}_bin ALL + COMMAND ${CMAKE_COMMAND} -DOBJECTS="$" -DOUT_DIR="" -DEXT="cubin" + -P "${OBJECT_COPY_UTIL}" + DEPENDS ${target_name} + COMMENT "Generating .cubin files for ${target_name}" VERBATIM ) endfunction () # ~~~ -# tvm_ffi_embed_cubin( -# OUTPUT -# SOURCE -# CUBIN -# NAME -# [DEPENDS ...] -# ) +# add_tvm_ffi_fatbin( CUDA ) # -# Compiles a C++ source file and embeds a CUBIN file into it, creating a -# combined object file that can be linked into a shared library or executable. +# Creates an object library that compiles CUDA source to FATBIN format. +# This function uses CMake's native CUDA support and respects CMAKE_CUDA_ARCHITECTURES. +# This is a compatibility util for cmake < 3.27, user can create +# cmake target with `CUDA_FATBIN_COMPILATION` for cmake >= 3.27. # # Parameters: -# OUTPUT: Path to the output object file (e.g., lib_embedded_with_cubin.o) -# SOURCE: Path to the C++ source file that uses TVM_FFI_EMBED_CUBIN macro -# CUBIN: Path to the CUBIN file to embed (can be a file path or a custom target output) -# NAME: Name used in the TVM_FFI_EMBED_CUBIN macro (e.g., "env" for TVM_FFI_EMBED_CUBIN(env)) -# DEPENDS: Optional additional dependencies (e.g., custom targets) -# -# The function will: -# 1. Compile the SOURCE file to an intermediate object file -# 2. Use the tvm_ffi.utils.embed_cubin Python utility to merge the object file -# with the CUBIN data -# 3. 
Create symbols: __tvm_ffi__cubin_<name> and __tvm_ffi__cubin_<name>_end
+# target_name: Name of the object library target
+# CUDA: One CUDA source file
 #
 # Example:
-# tvm_ffi_embed_cubin(
-#   OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o
-#   SOURCE src/lib_embedded.cc
-#   CUBIN ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin
-#   NAME env
-# )
-#
-# add_library(lib_embedded SHARED ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o)
-# target_link_libraries(lib_embedded PRIVATE tvm_ffi_header CUDA::cudart)
-#
-# Note: The .note.GNU-stack section is automatically added to mark the stack as
-# non-executable, so you don't need to add linker options manually
+# add_tvm_ffi_fatbin(my_kernel_fatbin CUDA kernel.cu)
 # ~~~
-
-# cmake-lint: disable=C0111,C0103
-function (tvm_ffi_embed_cubin)
-  # Parse arguments
-  set(options "")
-  set(oneValueArgs OUTPUT SOURCE CUBIN NAME)
-  set(multiValueArgs DEPENDS)
-  cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-
-  # Validate required arguments
-  if (NOT ARG_OUTPUT)
-    message(FATAL_ERROR "tvm_ffi_embed_cubin: OUTPUT is required")
-  endif ()
-  if (NOT ARG_SOURCE)
-    message(FATAL_ERROR "tvm_ffi_embed_cubin: SOURCE is required")
-  endif ()
-  if (NOT ARG_CUBIN)
-    message(FATAL_ERROR "tvm_ffi_embed_cubin: CUBIN is required")
-  endif ()
-  if (NOT ARG_NAME)
-    message(FATAL_ERROR "tvm_ffi_embed_cubin: NAME is required")
+function (add_tvm_ffi_fatbin target_name)
+  cmake_parse_arguments(ARG "" "CUDA" "" ${ARGN})
+  if (NOT ARG_CUDA)
+    message(FATAL_ERROR "add_tvm_ffi_fatbin: CUDA source is required")
   endif ()

-  # Ensure Python is found (prefer virtualenv)
-  if (NOT Python_EXECUTABLE)
-    set(Python_FIND_VIRTUALENV FIRST)
-    find_package(
-      Python
-      COMPONENTS Interpreter
-      REQUIRED
-    )
-  endif ()
+  add_library(${target_name} OBJECT ${ARG_CUDA})
+  target_compile_options(${target_name} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--fatbin>)

-  # Get absolute paths
-  get_filename_component(ARG_SOURCE_ABS "${ARG_SOURCE}" ABSOLUTE)
-  get_filename_component(ARG_OUTPUT_ABS "${ARG_OUTPUT}" ABSOLUTE)
+  add_custom_target(
+    ${target_name}_bin ALL
+    COMMAND ${CMAKE_COMMAND} -DOBJECTS="$<TARGET_OBJECTS:${target_name}>" -DOUT_DIR=""
+            -DEXT="fatbin" -P "${OBJECT_COPY_UTIL}"
+    DEPENDS ${target_name}
+    COMMENT "Generating .fatbin files for ${target_name}"
+    VERBATIM
+  )
+endfunction ()

-  # Generate intermediate object file path
-  get_filename_component(OUTPUT_DIR "${ARG_OUTPUT_ABS}" DIRECTORY)
-  get_filename_component(OUTPUT_NAME "${ARG_OUTPUT_ABS}" NAME_WE)
-  set(INTERMEDIATE_OBJ "${OUTPUT_DIR}/${OUTPUT_NAME}_intermediate.o")
+# ~~~
+# tvm_ffi_embed_bin_into(<target_name>
+#   SYMBOL <symbol_name>
+#   BIN <bin_file>)
+#
+# Embeds one CUBIN/FATBIN into the given target under the specified symbol name, so it
+# can be loaded with `TVM_FFI_EMBED_CUBIN(<symbol_name>)`.
+# The target must contain exactly one object file, and only one CUBIN/FATBIN can be embedded.
+#
+# This design is chosen to integrate with CMake's target-based workflow.
+#
+# Parameters:
+#   target_name: Name of the target to embed the binary into
+#   symbol_name: Symbol name used in the TVM_FFI_EMBED_CUBIN macro.
+#   BIN: CUBIN or FATBIN file
+#
+# Example:
+#   tvm_ffi_embed_bin_into(lib_embedded SYMBOL env BIN "$<TARGET_OBJECTS:kernel_fatbin>")
+# ~~~
+function (tvm_ffi_embed_bin_into target_name)
+  cmake_parse_arguments(ARG "" "SYMBOL;BIN" "" ${ARGN})

-  # Get include directories from tvm_ffi header target
-  if (TARGET tvm_ffi::header)
-    set(TVM_FFI_HEADER_TARGET tvm_ffi::header)
-  elseif (TARGET tvm_ffi_header)
-    set(TVM_FFI_HEADER_TARGET tvm_ffi_header)
-  else ()
-    message(
-      FATAL_ERROR
-        "tvm_ffi_embed_cubin: required target 'tvm_ffi::header' or 'tvm_ffi_header' does not exist."
- ) + if (NOT ARG_BIN) + message(FATAL_ERROR "tvm_ffi_embed_bin_into: BIN is required") endif () - get_target_property(TVM_FFI_INCLUDES ${TVM_FFI_HEADER_TARGET} INTERFACE_INCLUDE_DIRECTORIES) - - # Convert list to -I flags - set(INCLUDE_FLAGS "") - foreach (inc_dir ${TVM_FFI_INCLUDES}) - list(APPEND INCLUDE_FLAGS "-I${inc_dir}") - endforeach () - - # Add CUDA include directories if CUDAToolkit is found - if (TARGET CUDA::cudart) - get_target_property(CUDA_INCLUDES CUDA::cudart INTERFACE_INCLUDE_DIRECTORIES) - foreach (inc_dir ${CUDA_INCLUDES}) - list(APPEND INCLUDE_FLAGS "-I${inc_dir}") - endforeach () + if (NOT ARG_SYMBOL) + message(FATAL_ERROR "tvm_ffi_embed_bin_into: SYMBOL is required") endif () - # Step 1: Compile source file to intermediate object file + set(intermediate_path "${CMAKE_CURRENT_BINARY_DIR}/${ARG_SYMBOL}_intermediate.o") + add_custom_command( - OUTPUT "${INTERMEDIATE_OBJ}" - COMMAND ${CMAKE_CXX_COMPILER} -c -fPIC -std=c++17 ${INCLUDE_FLAGS} "${ARG_SOURCE_ABS}" -o - "${INTERMEDIATE_OBJ}" - DEPENDS "${ARG_SOURCE_ABS}" - COMMENT "Compiling ${ARG_SOURCE} to intermediate object file" - VERBATIM + TARGET ${target_name} + PRE_LINK + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" + "${intermediate_path}" + COMMENT "Moving $ -> ${intermediate_path}" ) - # Step 2: Embed CUBIN into the object file using Python utility Note: The Python utility - # automatically adds .note.GNU-stack section add_custom_command( - OUTPUT "${ARG_OUTPUT_ABS}" - COMMAND ${Python_EXECUTABLE} -m tvm_ffi.utils.embed_cubin --output-obj "${ARG_OUTPUT_ABS}" - --input-obj "${INTERMEDIATE_OBJ}" --cubin "${ARG_CUBIN}" --name "${ARG_NAME}" - DEPENDS "${INTERMEDIATE_OBJ}" "${ARG_CUBIN}" ${ARG_DEPENDS} - COMMENT "Embedding CUBIN into object file (name: ${ARG_NAME})" + TARGET ${target_name} + PRE_LINK + COMMAND + ${Python_EXECUTABLE} -m tvm_ffi.utils.embed_cubin --output-obj + "$" --name "${ARG_SYMBOL}" --input-obj "${intermediate_path}" + --cubin "${ARG_BIN}" + COMMENT "Embedding CUBIN into object file (name: ${ARG_SYMBOL})" VERBATIM ) - - # Set a variable in parent scope so users can add dependencies - set(${ARG_NAME}_EMBEDDED_OBJ - "${ARG_OUTPUT_ABS}" - PARENT_SCOPE - ) endfunction () diff --git a/cmake/Utils/ObjectCopyUtil.cmake b/cmake/Utils/ObjectCopyUtil.cmake new file mode 100644 index 00000000..50281cf8 --- /dev/null +++ b/cmake/Utils/ObjectCopyUtil.cmake @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# We need this to simulate `CUDA_{CUBIN,FATBIN}_COMPILATION` in `add_tvm_ffi_{cubin,fatbin}`, to +# copy `a.cu.o` to `a.cubin`/`a.fatbin`. 
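+
+# The `add_tvm_ffi_{cubin,fatbin}` helpers run this script at build time through a
+# `${target}_bin` custom target, roughly as follows (illustrative invocation; the
+# paths are hypothetical):
+#
+#   cmake -DOBJECTS="/build/obj/kernel.cu.o" -DOUT_DIR="" -DEXT="cubin" -P ObjectCopyUtil.cmake
+#   -> copies /build/obj/kernel.cu.o to /build/obj/kernel.cubin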
+ +# Usage: cmake -DOBJECTS=;...; -DOUT_DIR= +# -DEXT= -P + +# Parameter: OBJECTS: semicolon-separated list of input object files; OUT_DIR: output directory, +# empty for the same directory as the object file EXT: extension to rename to + +string(REPLACE "\"" "" ext_strip "${EXT}") +string(REPLACE "\"" "" out_dir_strip "${OUT_DIR}") +foreach (obj_raw ${OBJECTS}) + string(REPLACE "\"" "" obj "${obj_raw}") + + # Extract filename: /path/to/kernel.cu.o -> kernel Note: CMake objects are usually named + # source.cu.o, so we strip extensions twice. + get_filename_component(fname ${obj} NAME_WE) + get_filename_component(fname ${fname} NAME_WE) + + # If OUT_DIR is provided, use it. Otherwise, use the object's directory. + if (NOT out_dir_strip STREQUAL "") + set(FINAL_DIR "${out_dir_strip}") + else () + get_filename_component(FINAL_DIR ${obj} DIRECTORY) + endif () + + message("Copying ${obj} -> ${FINAL_DIR}/${fname}.${ext_strip}") + execute_process( + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${obj}" "${FINAL_DIR}/${fname}.${ext_strip}" + ) +endforeach () diff --git a/docs/guides/cubin_launcher.rst b/docs/guides/cubin_launcher.rst index e3e7d968..8407eb52 100644 --- a/docs/guides/cubin_launcher.rst +++ b/docs/guides/cubin_launcher.rst @@ -18,16 +18,48 @@ CUBIN Launcher Guide ==================== -This guide demonstrates how to load and launch CUDA kernels from CUBIN (CUDA Binary) modules using TVM-FFI. The CUBIN launcher enables you to execute pre-compiled or runtime-compiled CUDA kernels efficiently through the CUDA Runtime API. +This guide demonstrates how to load and launch CUDA kernels from CUBIN (CUDA Binary) modules using TVM-FFI. The CUBIN launcher enables you to execute pre-compiled or runtime-compiled CUDA kernels efficiently through the CUDA Runtime API or Driver API. Overview -------- -TVM-FFI provides utilities for loading and launching CUDA kernels from CUBIN modules. The implementation is in ``tvm/ffi/extra/cuda/cubin_launcher.h`` and provides: +TVM-FFI provides utilities for loading and launching CUDA kernels from CUBIN modules. The implementation supports both **CUDA Runtime API** (default for CUDA >= 12.8) and **CUDA Driver API**. + +**Runtime API (CUDA >= 12.8):** + +- ``cudaLibraryLoadData()`` - Load CUBIN from memory buffer +- ``cudaLibraryGetKernel()`` - Get kernel handle by name +- ``cudaLaunchKernel()`` - Launch kernel with grid/block dimensions + +**Driver API:** + +- ``cuLibraryLoadData()`` - Load CUBIN from memory buffer +- ``cuLibraryGetKernel()`` - Get kernel handle by name +- ``cuLaunchKernel()`` - Launch kernel with grid/block dimensions + +**Customization:** + +By default, the implementation uses the Runtime API if compiled with CUDA >= 12.8, falling back to the Driver API for older versions. You can force the usage of the Driver API (or Runtime API) by defining the macro ``TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API`` (set to ``1`` for Driver API, ``0`` for Runtime API) before including the header. + +.. warning:: + + **CMAKE_CUDA_RUNTIME_LIBRARY and Driver API** + + When using CMake, the default behavior (if ``CMAKE_CUDA_RUNTIME_LIBRARY`` is not set) is to link against the CUDA Runtime Library (``cudart``). TVM-FFI's CMake utility automatically defaults this variable to ``Shared`` if it is undefined. This introduces a dependency on the CUDA runtime version, requiring the system's driver to be compatible with that runtime version. + + If you intend to use the Driver API only (e.g. 
by setting ``TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1``) to avoid this runtime dependency: + + 1. You must explicitly set ``CMAKE_CUDA_RUNTIME_LIBRARY`` to ``None`` in your CMake configuration to prevent linking ``cudart``. + 2. You must manually link your target against the CUDA Driver library (usually ``cuda`` on Linux/Windows or `CUDA::cuda_driver` provided by CMake's ``FindCUDAToolkit``). + + This ensures your application relies solely on the widely compatible CUDA Driver API (``libcuda.so.1``). + +The implementation is in ``tvm/ffi/extra/cuda/cubin_launcher.h`` and provides: - :cpp:class:`tvm::ffi::CubinModule`: RAII wrapper for loading CUBIN modules from memory - :cpp:class:`tvm::ffi::CubinKernel`: Handle for launching CUDA kernels with specified parameters -- :c:macro:`TVM_FFI_EMBED_CUBIN`: Macro for embedding CUBIN data at compile time +- :c:macro:`TVM_FFI_EMBED_CUBIN`: Macro for embedding CUBIN data at compile time (legacy / object-linking approach) +- :c:macro:`TVM_FFI_EMBED_CUBIN_FROM_BYTES`: Macro for embedding CUBIN data from byte arrays (manual embedding approach) - :c:macro:`TVM_FFI_EMBED_CUBIN_GET_KERNEL`: Macro for retrieving kernels from embedded CUBIN The CUBIN launcher supports: @@ -41,9 +73,9 @@ The CUBIN launcher supports: TVM-FFI provides convenient tools for embedding CUBIN data at build time: -- **CMake utilities** (``cmake/Utils/EmbedCubin.cmake``): Functions for compiling CUDA to CUBIN and embedding it into C++ code -- **Python utility** (``python -m tvm_ffi.utils.embed_cubin``): Command-line tool for embedding CUBIN into object files -- **Python API** (:py:func:`tvm_ffi.cpp.load_inline`): Runtime embedding via ``embed_cubin`` parameter +- **CMake utilities** (``cmake/Utils/EmbedCubin.cmake``): Functions for compiling CUDA to CUBIN/FATBIN and embedding it into C++ code or linking it. +- **Python utility** (``python -m tvm_ffi.utils.embed_cubin``): Command-line tool for embedding CUBIN into object files. +- **Python API** (:py:func:`tvm_ffi.cpp.load_inline`): Runtime embedding via ``embed_cubin`` parameter. Python Usage ------------ @@ -117,9 +149,35 @@ C++ Usage Embedding CUBIN at Compile Time ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The recommended approach in C++ is to embed CUBIN data directly into your shared library: +The most convenient way to embed CUBIN/FATBIN data in C++ is using the TVM-FFI build utilities. There are three main approaches: -.. literalinclude:: ../../examples/cubin_launcher/embedded_cubin/src/lib_embedded.cc +1. **Object Linking (Standard)**: Use CMake utilities to compile and link the CUBIN data. +2. **Header Inclusion (Portable)**: Convert CUBIN to a C header file using ``bin2c``. +3. **C++ Embedding (Modern)**: Use C++23 ``#embed`` (or compiler extensions). + +**Method 1: Object Linking (Standard)** + +This approach uses CMake utilities to compile and link the CUBIN data. It works across all supported compilers and handles the low-level details of object file generation and symbol naming. + +.. literalinclude:: ../../examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc + :language: cpp + :start-after: [example.begin] + :end-before: [example.end] + +**Method 2: Header Inclusion (Portable)** + +You can use tools like ``bin2c`` to generate a header file containing the byte array and include it. + +.. 
literalinclude:: ../../examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc + :language: cpp + :start-after: [example.begin] + :end-before: [example.end] + +**Method 3: C++ Embedding (Modern)** + +Using C++23 ``#embed`` (or compiler extensions like ``#embed`` in Clang/GCC) allows you to include the binary data directly. + +.. literalinclude:: ../../examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc :language: cpp :start-after: [example.begin] :end-before: [example.end] @@ -130,7 +188,7 @@ The recommended approach in C++ is to embed CUBIN data directly into your shared - Kernel arguments must be pointers to the actual values (use ``&`` for addresses) - :cpp:type:`tvm::ffi::dim3` supports 1D, 2D, or 3D configurations: ``dim3(x)``, ``dim3(x, y)``, ``dim3(x, y, z)`` - ``TVMFFIEnvGetStream`` retrieves the correct CUDA stream for the device -- Always check kernel launch results with :c:macro:`TVM_FFI_CHECK_CUDA_ERROR` (which checks CUDA Runtime API errors) +- Always check kernel launch results with :c:macro:`TVM_FFI_CHECK_CUDA_ERROR` (which checks CUDA Runtime API or Driver API errors depending on configuration) Loading CUBIN at Runtime ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -149,36 +207,37 @@ TVM-FFI provides CMake utility functions that simplify the CUBIN embedding proce **Using CMake Utilities:** -.. literalinclude:: ../../examples/cubin_launcher/embedded_cubin/CMakeLists.txt +.. literalinclude:: ../../examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/CMakeLists.txt :language: cmake :start-after: [cmake_example.begin] :end-before: [cmake_example.end] **Available CMake Functions:** -- ``tvm_ffi_generate_cubin()``: Compiles CUDA source to CUBIN using nvcc +- ``add_tvm_ffi_cubin( CUDA )``: + Creates an object library that compiles CUDA source to CUBIN format. + This is a compatibility wrapper; for CMake >= 3.27, you can use standard ``CUDA_CUBIN_COMPILATION`` property. + +- ``add_tvm_ffi_fatbin( CUDA )``: + Creates an object library that compiles CUDA source to FATBIN format. + This is a compatibility wrapper; for CMake >= 3.27, you can use standard ``CUDA_FATBIN_COMPILATION`` property. - - ``OUTPUT``: Path to output CUBIN file - - ``SOURCE``: Path to CUDA source file - - ``ARCH``: Target GPU architecture (default: ``native`` for auto-detection) - - ``OPTIONS``: Additional nvcc compiler options (optional) - - ``DEPENDS``: Additional dependencies (optional) +- ``tvm_ffi_embed_bin_into( SYMBOL BIN )``: + Embeds a CUBIN/FATBIN file into an existing object library target. + This works by linking the binary data into the target, allowing access via ``TVM_FFI_EMBED_CUBIN()``. -- ``tvm_ffi_embed_cubin()``: Compiles C++ source and embeds CUBIN data + - ``target``: The target to embed into (must be an object library or have object files). + - ``symbol``: Symbol name to use (must match ``TVM_FFI_EMBED_CUBIN(symbol)``). + - ``BIN``: Path to the CUBIN/FATBIN file (e.g., from ``$``). - - ``OUTPUT``: Path to output combined object file - - ``SOURCE``: Path to C++ source file with ``TVM_FFI_EMBED_CUBIN`` macro - - ``CUBIN``: Path to CUBIN file to embed - - ``NAME``: Symbol name used in ``TVM_FFI_EMBED_CUBIN(name)`` macro - - ``DEPENDS``: Additional dependencies (optional) +.. note:: -The utilities automatically handle: + When including ``cmake/Utils/EmbedCubin.cmake``, if ``CMAKE_CUDA_RUNTIME_LIBRARY`` is not set, it defaults to ``Shared``. + This prevents static linking of cudart, which requires an exact driver version match. 
+   If you intend to use the Driver API only (e.g., via ``TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1``),
+   you should explicitly set ``CMAKE_CUDA_RUNTIME_LIBRARY`` to ``None`` in your CMake configuration
+   before including this utility, so that the CUDA runtime library is not linked, and link your
+   target against the CUDA Driver library instead.

-- Compiling C++ source to intermediate object file
-- Creating CUBIN symbols with proper naming
-- Merging object files using ``ld -r``
-- Adding ``.note.GNU-stack`` section for security
-- Localizing symbols to prevent conflicts

 Embedding CUBIN with Python Utility
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -374,6 +433,7 @@ C++ Macros
 ~~~~~~~~~~

 - :c:macro:`TVM_FFI_EMBED_CUBIN`: Declare embedded CUBIN module
+- :c:macro:`TVM_FFI_EMBED_CUBIN_FROM_BYTES`: Load CUBIN from byte array
 - :c:macro:`TVM_FFI_EMBED_CUBIN_GET_KERNEL`: Get kernel from embedded module
 - :c:macro:`TVM_FFI_CHECK_CUDA_ERROR`: Check CUDA Runtime API result
@@ -389,7 +449,7 @@ Python Utilities

 - ``python -m tvm_ffi.utils.embed_cubin``: Command-line utility to embed CUBIN into object files

   - ``--output-obj PATH``: Output combined object file path
-  - ``--input-obj PATH``: Input object file containing C++ code with ``TVM_FFI_EMBED_CUBIN``
+  - ``--input-obj PATH``: Input compiled object file containing C++ code with ``TVM_FFI_EMBED_CUBIN``
   - ``--cubin PATH``: Input CUBIN file to embed
   - ``--name NAME``: Symbol name matching ``TVM_FFI_EMBED_CUBIN(name)`` macro
   - ``--verbose``: Print detailed command output (optional)
@@ -405,18 +465,6 @@ Python Utilities
 CMake Functions
 ~~~~~~~~~~~~~~~

-- ``tvm_ffi_generate_cubin()``: Compile CUDA source to CUBIN
-
-  - ``OUTPUT``: Path to output CUBIN file
-  - ``SOURCE``: Path to CUDA source file (.cu)
-  - ``ARCH``: Target architecture (default: ``native``)
-  - ``OPTIONS``: Additional nvcc compiler flags (optional)
-  - ``DEPENDS``: Additional dependencies (optional)
-
-- ``tvm_ffi_embed_cubin()``: Compile C++ source and embed CUBIN data
-
-  - ``OUTPUT``: Path to output combined object file
-  - ``SOURCE``: Path to C++ source file
-  - ``CUBIN``: Path to CUBIN file to embed
-  - ``NAME``: Symbol name matching ``TVM_FFI_EMBED_CUBIN(name)`` in source
-  - ``DEPENDS``: Additional dependencies (optional)
+- ``add_tvm_ffi_cubin(<target> CUDA <source>)``: Compile CUDA source to CUBIN
+- ``add_tvm_ffi_fatbin(<target> CUDA <source>)``: Compile CUDA source to FATBIN
+- ``tvm_ffi_embed_bin_into(<target> SYMBOL <symbol> BIN <bin_file>)``: Embed CUBIN/FATBIN into object target
diff --git a/examples/cubin_launcher/README.md b/examples/cubin_launcher/README.md
index 770f4257..7b0ac37e 100644
--- a/examples/cubin_launcher/README.md
+++ b/examples/cubin_launcher/README.md
@@ -23,35 +23,56 @@ Demonstrates loading and executing CUDA kernels from CUBIN files using TVM-FFI.

 ## Techniques

-The implementation uses CUDA Runtime API Library Management:
+The implementation supports both the CUDA Runtime API (CUDA >= 12.8) and the Driver API for Library Management.
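+
+For example, a consumer can pin the choice per translation unit (a minimal sketch; the
+macro name and header path are the ones documented here, the surrounding usage is
+illustrative — see **Customization** below):
+
+```cpp
+// Force the Driver API path; set to 0 to force the Runtime API,
+// or omit the define entirely to accept the CUDA-version-based default.
+#define TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API 1
+#include <tvm/ffi/extra/cuda/cubin_launcher.h>
+```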
+ +**Runtime API (CUDA >= 12.8):** - **`cudaLibraryLoadData()`** - Load CUBIN from memory buffer - **`cudaLibraryGetKernel()`** - Get kernel handle by name -- **`cudaKernelGetFunction()`** - Get function handle for current CUDA context - **`cudaLaunchKernel()`** - Launch kernel with grid/block dimensions +**Driver API:** + +- **`cuLibraryLoadData()`** - Load CUBIN from memory buffer +- **`cuLibraryGetKernel()`** - Get kernel handle by name +- **`cuLaunchKernel()`** - Launch kernel with grid/block dimensions + +**Customization:** + +By default, the implementation uses the Runtime API if compiled with CUDA >= 12.8, falling back to the Driver API for older versions. You can force the usage of the Driver API (or Runtime API) by defining the macro `TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API` (set to `1` for Driver API, `0` for Runtime API) before including the header. + Key features: - Multi-GPU support via CUDA primary contexts - RAII-based resource management (CubinModule, CubinKernel) -- CUBIN embedding at compile time (via `ld` + `objcopy`) +- CUBIN embedding at compile time + - Object linking (via `ld` + `objcopy`) + - Header inclusion (via `bin2c`) + - Modern C++ embedding (via `#embed`) - TVM-FFI integration for tensor argument passing -- **New:** `TVM_FFI_EMBED_CUBIN` and `TVM_FFI_EMBED_CUBIN_GET_KERNEL` macros for easy CUBIN embedding -- **New:** `embed_cubin` parameter in `tvm_ffi.cpp.load_inline` for seamless CUBIN integration -- **New:** `tvm_ffi.cpp.nvrtc` module for runtime CUDA compilation +- **Macros:** + - `TVM_FFI_EMBED_CUBIN`: Declare symbols for object-linked CUBIN + - `TVM_FFI_EMBED_CUBIN_FROM_BYTES`: Load CUBIN from byte array (for `#embed`/`bin2c`) + - `TVM_FFI_EMBED_CUBIN_GET_KERNEL`: Helper to retrieve kernels +- **Python Integration:** `embed_cubin` parameter in `tvm_ffi.cpp.load_inline` for seamless CUBIN integration +- **Runtime Compilation:** `tvm_ffi.cpp.nvrtc` module for runtime CUDA compilation ## Examples ### 1. Embedded CUBIN -Demonstrates embedding CUBIN data directly into the shared library at build time using the `tvm_ffi_embed_cubin` CMake utility. +The `embedded_cubin` directory contains three examples demonstrating different embedding techniques. -**Location:** `embedded_cubin/` +#### 1.1 Object Linking (Standard) + +Demonstrates embedding CUBIN data directly into the shared library at build time using the `tvm_ffi_embed_bin_into` CMake utility. This is the most robust method for CMake projects. + +**Location:** `embedded_cubin/embed_with_tvm_ffi/` **Build and run:** ```bash -cd examples/cubin_launcher/embedded_cubin +cd examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi mkdir build && cd build cmake .. make @@ -59,12 +80,39 @@ cd .. python main.py ``` -**Key features:** +#### 1.2 Header Inclusion (Portable) + +Demonstrates converting the CUBIN to a C header file using `bin2c` and including it in the C++ source. This is highly portable and works with any compiler. -- CUBIN is embedded at compile time using `ld` and `objcopy` -- No separate CUBIN file needed at runtime -- Symbols are localized to prevent conflicts -- `.note.GNU-stack` section automatically added for security +**Location:** `embedded_cubin/include_bin2c/` + +**Build and run:** + +```bash +cd examples/cubin_launcher/embedded_cubin/include_bin2c +mkdir build && cd build +cmake .. +make +cd .. +python main.py +``` + +#### 1.3 C++ Embedding (Modern) + +Demonstrates using C++23 `#embed` (or compiler extensions in GCC/Clang) to directly include binary data. 
This is the cleanest approach for modern toolchains. + +**Location:** `embedded_cubin/cpp_embed/` + +**Build and run:** + +```bash +cd examples/cubin_launcher/embedded_cubin/cpp_embed +mkdir build && cd build +cmake .. +make +cd .. +python main.py +``` ### 2. Dynamic CUBIN Loading @@ -152,24 +200,18 @@ mod = cpp.load_inline( - `include/tvm/ffi/extra/cuda/cubin_launcher.h` - Header-only C++ library with CUBIN utilities - `python/tvm_ffi/utils/embed_cubin.py` - Python utility for embedding CUBIN into object files - `python/tvm_ffi/cpp/nvrtc.py` - NVRTC compilation utilities -- `cmake/Utils/EmbedCubin.cmake` - CMake utilities (`tvm_ffi_generate_cubin`, `tvm_ffi_embed_cubin`) +- `cmake/Utils/EmbedCubin.cmake` - CMake utilities ### Example Directories -**`embedded_cubin/`** - CUBIN embedded at build time +**`embedded_cubin/`** - Different CUBIN embedding techniques: -- `CMakeLists.txt` - Build configuration using `tvm_ffi_embed_cubin` -- `main.py` - Python test script -- `src/lib_embedded.cc` - C++ code using `TVM_FFI_EMBED_CUBIN` macro -- `src/kernel.cu` - CUDA kernels (add_one, mul_two) +- `embed_with_tvm_ffi/` - Standard object linking +- `include_bin2c/` - Header inclusion +- `cpp_embed/` - Modern C++ `#embed` **`dynamic_cubin/`** - CUBIN loaded at runtime -- `CMakeLists.txt` - Build configuration using `tvm_ffi_generate_cubin` -- `main.py` - Python test script -- `src/lib_dynamic.cc` - C++ code using `CubinModule::GetKernel()` -- `src/kernel.cu` - CUDA kernels (add_one, mul_two) - **Additional Examples** (at root level) - `example_triton_cubin.py` - Triton kernel with embedded CUBIN diff --git a/examples/cubin_launcher/dynamic_cubin/CMakeLists.txt b/examples/cubin_launcher/dynamic_cubin/CMakeLists.txt index dd30d8ba..fdb62297 100644 --- a/examples/cubin_launcher/dynamic_cubin/CMakeLists.txt +++ b/examples/cubin_launcher/dynamic_cubin/CMakeLists.txt @@ -21,6 +21,11 @@ project(dynamic_cubin_example LANGUAGES CXX CUDA) # Prefer virtualenv when searching for python set(Python_FIND_VIRTUALENV FIRST) # cmake-lint: disable=C0103 +set(CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API + OFF + CACHE BOOL "Use driver API in cubin launcher" +) + # Find tvm-ffi package find_package( Python @@ -37,26 +42,43 @@ find_package(tvm_ffi CONFIG REQUIRED) # Find CUDA toolkit find_package(CUDAToolkit REQUIRED) -# Step 1: Compile kernel.cu to CUBIN using tvm_ffi_generate_cubin utility Use -arch=native to -# automatically detect the GPU architecture -tvm_ffi_generate_cubin( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin SOURCE src/kernel.cu ARCH native -) +# [cmake_example.begin] + +# Step 1: Compile kernel.cu to CUBIN using add_tvm_ffi_cubin utility or CUDA_CUBIN_COMPILATION. 
Use +# CMAKE_CUDA_ARCHITECTURES=native to automatically detect the GPU architecture +set(CMAKE_CUDA_ARCHITECTURES native) +if (CMAKE_VERSION VERSION_LESS "3.27.0") + add_tvm_ffi_cubin(kernel_cubin CUDA src/kernel.cu) +else () + add_library(kernel_cubin OBJECT src/kernel.cu) + set_property(TARGET kernel_cubin PROPERTY CUDA_CUBIN_COMPILATION ON) +endif () -# Create a target that depends on the CUBIN add_custom_target( - generate_cubin ALL - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin - COMMENT "Generating CUBIN file" + kernel.cubin + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" + "${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin" + DEPENDS kernel_cubin + COMMENT "Copy cubin to build dir" ) # Step 2: Build lib_dynamic shared library (loads CUBIN from file at runtime) add_library(lib_dynamic SHARED src/lib_dynamic.cc) -target_link_libraries(lib_dynamic PRIVATE tvm_ffi::header tvm_ffi::shared CUDA::cudart) -add_dependencies(lib_dynamic generate_cubin) +include_directories(${CUDAToolkit_INCLUDE_DIRS}) +target_link_libraries(lib_dynamic PRIVATE tvm_ffi::header tvm_ffi::shared) +add_dependencies(lib_dynamic kernel.cubin) set_target_properties( lib_dynamic PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/" PREFIX "" SUFFIX ".so" ) + +# Step 3: Link against CUDA Driver API or Runtime API based on config +if (CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API) + add_compile_definitions(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1) + target_link_libraries(lib_dynamic PRIVATE cuda) +else () + target_link_libraries(lib_dynamic PRIVATE CUDA::cudart) +endif () +# [cmake_example.end] diff --git a/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc b/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc index 712e1bd8..61f050db 100644 --- a/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc +++ b/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc @@ -80,14 +80,19 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { // Get CUDA stream DLDevice device = x.device(); - cudaStream_t stream = - static_cast(TVMFFIEnvGetStream(device.device_type, device.device_id)); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); // Launch kernel - cudaError_t result = g_add_one_kernel->Launch(args, grid, block, stream); + tvm::ffi::cuda_api::ResultType result = g_add_one_kernel->Launch(args, grid, block, stream); TVM_FFI_CHECK_CUDA_ERROR(result); } +} // namespace cubin_dynamic +// [example.end] + +namespace cubin_dynamic { + /*! * \brief Launch mul_two_cuda kernel on input tensor. 
* \param x Input tensor (float32, 1D) @@ -115,11 +120,11 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { // Get CUDA stream DLDevice device = x.device(); - cudaStream_t stream = - static_cast(TVMFFIEnvGetStream(device.device_type, device.device_id)); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); // Launch kernel - cudaError_t result = g_mul_two_kernel->Launch(args, grid, block, stream); + tvm::ffi::cuda_api::ResultType result = g_mul_two_kernel->Launch(args, grid, block, stream); TVM_FFI_CHECK_CUDA_ERROR(result); } @@ -129,4 +134,3 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_dynamic::AddOne); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_dynamic::MulTwo); } // namespace cubin_dynamic -// [example.end] diff --git a/examples/cubin_launcher/embedded_cubin/cpp_embed/CMakeLists.txt b/examples/cubin_launcher/embedded_cubin/cpp_embed/CMakeLists.txt new file mode 100644 index 00000000..084648e3 --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/cpp_embed/CMakeLists.txt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
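+
+# NOTE: this example requires CMake >= 3.27 (for `CUDA_FATBIN_COMPILATION`) and a
+# compiler with `#embed` support (clang >= 20 or gcc >= 14, per src/lib_embedded.cc).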
+ +cmake_minimum_required(VERSION 3.27) +project(embedded_cubin_example LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 26) + +set(CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API + OFF + CACHE BOOL "Use driver API in cubin launcher" +) + +# Prefer virtualenv when searching for python +set(Python_FIND_VIRTUALENV FIRST) # cmake-lint: disable=C0103 + +# Find tvm-ffi package +find_package( + Python + COMPONENTS Interpreter + REQUIRED +) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE tvm_ffi_ROOT +) +find_package(tvm_ffi CONFIG REQUIRED) + +# Find CUDA toolkit +find_package(CUDAToolkit REQUIRED) +include_directories(${CUDAToolkit_INCLUDE_DIRS}) + +# [cmake_example.begin] + +# Step 1: Compile kernel.cu to FATBIN using add_tvm_ffi_fatbin utility or `CUDA_FATBIN_COMPILATION` +set(CMAKE_CUDA_ARCHITECTURES 75;80;86;89;90;100;120) +add_library(kernel_fatbin OBJECT src/kernel.cu) +set_target_properties(kernel_fatbin PROPERTIES CUDA_FATBIN_COMPILATION ON) + +add_custom_target( + kernel_fatbin.fatbin + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" + "${CMAKE_CURRENT_SOURCE_DIR}/src/kernel_fatbin.fatbin" + DEPENDS kernel_fatbin + COMMENT "Copy fatbin to source dir" +) + +# Step 2: Build lib_embedded shared library +add_library(lib_embedded SHARED src/lib_embedded.cc) +add_dependencies(lib_embedded kernel_fatbin.fatbin) +target_link_libraries(lib_embedded PRIVATE tvm_ffi::header tvm_ffi::shared) +set_target_properties(lib_embedded PROPERTIES POSITION_INDEPENDENT_CODE ON) + +# Step 3: Link against CUDA Driver API or Runtime API based on config +if (CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API) + add_compile_definitions(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1) + target_link_libraries(lib_embedded PRIVATE cuda) +else () + target_link_libraries(lib_embedded PRIVATE CUDA::cudart) +endif () + +set_target_properties( + lib_embedded + PROPERTIES PREFIX "" + SUFFIX ".so" + LINKER_LANGUAGE CXX +) +# [cmake_example.end] diff --git a/examples/cubin_launcher/embedded_cubin/main.py b/examples/cubin_launcher/embedded_cubin/cpp_embed/main.py similarity index 100% rename from examples/cubin_launcher/embedded_cubin/main.py rename to examples/cubin_launcher/embedded_cubin/cpp_embed/main.py diff --git a/examples/cubin_launcher/embedded_cubin/src/kernel.cu b/examples/cubin_launcher/embedded_cubin/cpp_embed/src/kernel.cu similarity index 100% rename from examples/cubin_launcher/embedded_cubin/src/kernel.cu rename to examples/cubin_launcher/embedded_cubin/cpp_embed/src/kernel.cu diff --git a/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc new file mode 100644 index 00000000..357522e5 --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file examples/cubin_launcher/src/lib_embedded.cc + * \brief TVM-FFI library with embedded CUBIN kernels. + * + * This library exports TVM-FFI functions to launch CUDA kernels from + * embedded CUBIN data. + */ + +#include +#include +#include +#include +#include + +// [example.begin] +constexpr unsigned char image[]{ +// clang >= 20 or gcc >= 14 +#embed "kernel_fatbin.fatbin" +}; + +TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, image); +// [example.end] + +namespace cubin_embedded { + +/*! + * \brief Launch add_one_cuda kernel on input tensor. + * \param x Input tensor (float32, 1D) + * \param y Output tensor (float32, 1D, same shape as x) + */ +void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { + // Get kernel from embedded CUBIN (cached in static variable for efficiency) + static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_one_cuda"); + + TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor"; + TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor"; + TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size"; + + int64_t n = x.size(0); + void* x_ptr = x.data_ptr(); + void* y_ptr = y.data_ptr(); + + // Prepare kernel arguments + void* args[] = {reinterpret_cast(&x_ptr), reinterpret_cast(&y_ptr), + reinterpret_cast(&n)}; + + // Launch configuration + tvm::ffi::dim3 grid((n + 255) / 256); + tvm::ffi::dim3 block(256); + + // Get CUDA stream + DLDevice device = x.device(); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); + + // Launch kernel + tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream); + TVM_FFI_CHECK_CUDA_ERROR(result); +} + +} // namespace cubin_embedded + +namespace cubin_embedded { + +/*! + * \brief Launch mul_two_cuda kernel on input tensor. 
+ * \param x Input tensor (float32, 1D) + * \param y Output tensor (float32, 1D, same shape as x) + */ +void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { + // Get kernel from embedded CUBIN (cached in static variable for efficiency) + static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "mul_two_cuda"); + + TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor"; + TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor"; + TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size"; + + int64_t n = x.size(0); + void* x_ptr = x.data_ptr(); + void* y_ptr = y.data_ptr(); + + // Prepare kernel arguments + void* args[] = {reinterpret_cast(&x_ptr), reinterpret_cast(&y_ptr), + reinterpret_cast(&n)}; + + // Launch configuration + tvm::ffi::dim3 grid((n + 255) / 256); + tvm::ffi::dim3 block(256); + + // Get CUDA stream + DLDevice device = x.device(); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); + + // Launch kernel + tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream); + TVM_FFI_CHECK_CUDA_ERROR(result); +} + +// Export TVM-FFI functions +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_embedded::AddOne); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_embedded::MulTwo); + +} // namespace cubin_embedded diff --git a/examples/cubin_launcher/embedded_cubin/CMakeLists.txt b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/CMakeLists.txt similarity index 53% rename from examples/cubin_launcher/embedded_cubin/CMakeLists.txt rename to examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/CMakeLists.txt index 44804172..7bbdcefc 100644 --- a/examples/cubin_launcher/embedded_cubin/CMakeLists.txt +++ b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/CMakeLists.txt @@ -18,6 +18,11 @@ cmake_minimum_required(VERSION 3.20) project(embedded_cubin_example LANGUAGES CXX CUDA) +set(CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API + OFF + CACHE BOOL "Use driver API in cubin launcher" +) + # Prefer virtualenv when searching for python set(Python_FIND_VIRTUALENV FIRST) # cmake-lint: disable=C0103 @@ -36,33 +41,39 @@ find_package(tvm_ffi CONFIG REQUIRED) # Find CUDA toolkit find_package(CUDAToolkit REQUIRED) +include_directories(${CUDAToolkit_INCLUDE_DIRS}) -# [cmake_example.begin] Step 1: Compile kernel.cu to CUBIN using tvm_ffi_generate_cubin utility Use -# -arch=native to automatically detect the GPU architecture -tvm_ffi_generate_cubin( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin SOURCE src/kernel.cu ARCH native -) +# [cmake_example.begin] -# Step 2: Embed CUBIN into the object file using tvm_ffi_embed_cubin utility This creates symbols: -# __tvm_ffi__cubin_env, __tvm_ffi__cubin_env_end (local) -tvm_ffi_embed_cubin( - OUTPUT - ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o - SOURCE - src/lib_embedded.cc - CUBIN - ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin - NAME - env -) +# Step 1: Compile kernel.cu to FATBIN using add_tvm_ffi_fatbin utility or `CUDA_FATBIN_COMPILATION` +set(CMAKE_CUDA_ARCHITECTURES 75;80;86;89;90;100;120) +if (CMAKE_VERSION VERSION_LESS "3.27.0") + add_tvm_ffi_fatbin(kernel_fatbin CUDA src/kernel.cu) +else () + add_library(kernel_fatbin OBJECT src/kernel.cu) + set_target_properties(kernel_fatbin PROPERTIES CUDA_FATBIN_COMPILATION ON) +endif () + +# Step 2: Build lib_embedded shared library +add_library(lib_embedded SHARED src/lib_embedded.cc) +target_link_libraries(lib_embedded PRIVATE 
tvm_ffi::header tvm_ffi::shared) +set_target_properties(lib_embedded PROPERTIES POSITION_INDEPENDENT_CODE ON) + +# Step 3: Link against CUDA Driver API or Runtime API based on config +if (CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API) + add_compile_definitions(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1) + target_link_libraries(lib_embedded PRIVATE cuda) +else () + target_link_libraries(lib_embedded PRIVATE CUDA::cudart) +endif () + +# Step 4: Embed CUBIN into shared library just defined, using tvm_ffi_embed_cubin utility This +# creates symbols: __tvm_ffi__cubin_env (local) +tvm_ffi_embed_bin_into(lib_embedded SYMBOL env BIN "$") -# Step 3: Build lib_embedded shared library (with embedded CUBIN) -add_library(lib_embedded SHARED ${CMAKE_CURRENT_BINARY_DIR}/lib_embedded_with_cubin.o) -target_link_libraries(lib_embedded PRIVATE tvm_ffi::header tvm_ffi::shared CUDA::cudart) set_target_properties( lib_embedded - PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/" - PREFIX "" + PROPERTIES PREFIX "" SUFFIX ".so" LINKER_LANGUAGE CXX ) diff --git a/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/main.py b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/main.py new file mode 100644 index 00000000..006ba36c --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/main.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example script for embedded CUBIN library. + +This example demonstrates using lib_embedded.so which has CUBIN data +embedded directly in the shared library using objcopy. 
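+
+The library is loaded from build/lib_embedded.so next to this file; build it
+first with the CMake steps described in the README.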
+""" + +import sys +from pathlib import Path + +import torch +from tvm_ffi import load_module + + +def main() -> int: + """Test the lib_embedded.so library with embedded CUBIN.""" + print("=" * 60) + print("Example: Embedded CUBIN Library") + print("=" * 60) + + # Check CUDA availability + if not torch.cuda.is_available(): + print("[ERROR] CUDA is not available") + return 1 + + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + print(f"PyTorch version: {torch.__version__}\n") + + # Load the library + lib_path = Path(__file__).parent / "build" / "lib_embedded.so" + mod = load_module(str(lib_path)) + print(f"Loaded library: {lib_path}") + + # Get the functions + add_one = mod["add_one"] + mul_two = mod["mul_two"] + print("Loaded functions: add_one, mul_two") + + # Test add_one kernel + print("\n[Test 1] add_one kernel") + n = 1024 + x = torch.arange(n, dtype=torch.float32, device="cuda") + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Input shape: {x.shape}, device: {x.device}") + add_one(x, y) + + # Verify results + expected = x + 1 + if torch.allclose(y, expected): + print(f" [PASS] Verified {n} elements correctly") + else: + print(f" [FAIL] Verification failed, max error: {(y - expected).abs().max().item()}") + return 1 + + # Test mul_two kernel + print("\n[Test 2] mul_two kernel") + n = 512 + x = torch.arange(n, dtype=torch.float32, device="cuda") * 0.5 + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Input shape: {x.shape}, device: {x.device}") + mul_two(x, y) + + # Verify results + expected = x * 2 + if torch.allclose(y, expected): + print(f" [PASS] Verified {n} elements correctly") + else: + print(f" [FAIL] Verification failed, max error: {(y - expected).abs().max().item()}") + return 1 + + # Test chained execution + print("\n[Test 3] Chained execution: (x + 1) * 2") + n = 256 + x = torch.full((n,), 10.0, dtype=torch.float32, device="cuda") + temp = torch.empty(n, dtype=torch.float32, device="cuda") + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Initial value: {x[0].item()}") + add_one(x, temp) # temp = x + 1 = 11 + mul_two(temp, y) # y = temp * 2 = 22 + + expected = 22.0 + if torch.allclose(y, torch.tensor(expected, device="cuda")): + print(f" [PASS] Result: {y[0].item()}") + else: + print(f" [FAIL] Expected {expected}, got {y[0].item()}") + return 1 + + print("\n[PASS] All tests passed!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/kernel.cu b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/kernel.cu new file mode 100644 index 00000000..452957bf --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/kernel.cu @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file examples/cubin_launcher/src/kernel.cu + * \brief Simple CUDA kernel for testing cubin_launcher functionality. + */ + +#include + +// [kernels.begin] +/*! + * \brief CUDA kernel that adds 1 to each element of an array. + * + * \param x Input array pointer. + * \param y Output array pointer. + * \param n Number of elements. + */ +extern "C" __global__ void add_one_cuda(const float* x, float* y, int64_t n) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1.0f; + } +} + +/*! + * \brief CUDA kernel that multiplies each element by 2. + * + * \param x Input array pointer. + * \param y Output array pointer. + * \param n Number of elements. + */ +extern "C" __global__ void mul_two_cuda(const float* x, float* y, int64_t n) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] * 2.0f; + } +} +// [kernels.end] diff --git a/examples/cubin_launcher/embedded_cubin/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc similarity index 87% rename from examples/cubin_launcher/embedded_cubin/src/lib_embedded.cc rename to examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc index 70a54ec9..53401fc6 100644 --- a/examples/cubin_launcher/embedded_cubin/src/lib_embedded.cc +++ b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc @@ -24,16 +24,17 @@ * embedded CUBIN data. */ -// [example.begin] #include #include #include #include #include +// [example.begin] // Embed CUBIN module with name "env" // This creates the necessary symbols and singleton struct for accessing the embedded CUBIN TVM_FFI_EMBED_CUBIN(env); +// [example.end] namespace cubin_embedded { @@ -64,14 +65,18 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { // Get CUDA stream DLDevice device = x.device(); - cudaStream_t stream = - static_cast(TVMFFIEnvGetStream(device.device_type, device.device_id)); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); // Launch kernel - cudaError_t result = kernel.Launch(args, grid, block, stream); + tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream); TVM_FFI_CHECK_CUDA_ERROR(result); } +} // namespace cubin_embedded + +namespace cubin_embedded { + /*! * \brief Launch mul_two_cuda kernel on input tensor. 
* \param x Input tensor (float32, 1D) @@ -99,11 +104,11 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) { // Get CUDA stream DLDevice device = x.device(); - cudaStream_t stream = - static_cast(TVMFFIEnvGetStream(device.device_type, device.device_id)); + tvm::ffi::cuda_api::StreamHandle stream = static_cast( + TVMFFIEnvGetStream(device.device_type, device.device_id)); // Launch kernel - cudaError_t result = kernel.Launch(args, grid, block, stream); + tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream); TVM_FFI_CHECK_CUDA_ERROR(result); } @@ -112,4 +117,3 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_embedded::AddOne); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_embedded::MulTwo); } // namespace cubin_embedded -// [example.end] diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/CMakeLists.txt b/examples/cubin_launcher/embedded_cubin/include_bin2c/CMakeLists.txt new file mode 100644 index 00000000..f825878b --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/CMakeLists.txt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
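+
+# NOTE: this example converts the compiled FATBIN into a C header with `bin2c`
+# (shipped with the CUDA toolkit) and includes it from src/lib_embedded.cc.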
+ +cmake_minimum_required(VERSION 3.27) +project(embedded_cubin_example LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 17) + +set(CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API + OFF + CACHE BOOL "Use driver API in cubin launcher" +) + +# Prefer virtualenv when searching for python +set(Python_FIND_VIRTUALENV FIRST) # cmake-lint: disable=C0103 + +# Find tvm-ffi package +find_package( + Python + COMPONENTS Interpreter + REQUIRED +) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE tvm_ffi_ROOT +) +find_package(tvm_ffi CONFIG REQUIRED) + +# Find CUDA toolkit +find_package(CUDAToolkit REQUIRED) +include_directories(${CUDAToolkit_INCLUDE_DIRS}) + +# [cmake_example.begin] + +# Step 1: Compile kernel.cu to FATBIN using add_tvm_ffi_fatbin utility or `CUDA_FATBIN_COMPILATION` +set(CMAKE_CUDA_ARCHITECTURES 75;80;86;89;90;100;120) +add_library(kernel_fatbin OBJECT src/kernel.cu) +set_target_properties(kernel_fatbin PROPERTIES CUDA_FATBIN_COMPILATION ON) + +add_custom_target( + kernel_fatbin.h + COMMAND bin2c -c "$" > + "${CMAKE_CURRENT_SOURCE_DIR}/src/kernel_fatbin.h" + DEPENDS kernel_fatbin + COMMENT "Run bin2c for ${kernel_fatbin}" +) + +# Step 2: Build lib_embedded shared library +add_library(lib_embedded SHARED src/lib_embedded.cc) +add_dependencies(lib_embedded kernel_fatbin.h) +target_link_libraries(lib_embedded PRIVATE tvm_ffi::header tvm_ffi::shared) +set_target_properties(lib_embedded PROPERTIES POSITION_INDEPENDENT_CODE ON) + +# Step 3: Link against CUDA Driver API or Runtime API based on config +if (CMAKE_TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API) + add_compile_definitions(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API=1) + target_link_libraries(lib_embedded PRIVATE cuda) +else () + target_link_libraries(lib_embedded PRIVATE CUDA::cudart) +endif () + +set_target_properties( + lib_embedded + PROPERTIES PREFIX "" + SUFFIX ".so" + LINKER_LANGUAGE CXX +) +# [cmake_example.end] diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/main.py b/examples/cubin_launcher/embedded_cubin/include_bin2c/main.py new file mode 100644 index 00000000..006ba36c --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/main.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example script for embedded CUBIN library. + +This example demonstrates using lib_embedded.so which has CUBIN data +embedded directly in the shared library using objcopy. 
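+
+Note: in this variant the binary data actually comes from a bin2c-generated
+header (see src/kernel_fatbin.h and CMakeLists.txt) rather than objcopy.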
+""" + +import sys +from pathlib import Path + +import torch +from tvm_ffi import load_module + + +def main() -> int: + """Test the lib_embedded.so library with embedded CUBIN.""" + print("=" * 60) + print("Example: Embedded CUBIN Library") + print("=" * 60) + + # Check CUDA availability + if not torch.cuda.is_available(): + print("[ERROR] CUDA is not available") + return 1 + + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + print(f"PyTorch version: {torch.__version__}\n") + + # Load the library + lib_path = Path(__file__).parent / "build" / "lib_embedded.so" + mod = load_module(str(lib_path)) + print(f"Loaded library: {lib_path}") + + # Get the functions + add_one = mod["add_one"] + mul_two = mod["mul_two"] + print("Loaded functions: add_one, mul_two") + + # Test add_one kernel + print("\n[Test 1] add_one kernel") + n = 1024 + x = torch.arange(n, dtype=torch.float32, device="cuda") + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Input shape: {x.shape}, device: {x.device}") + add_one(x, y) + + # Verify results + expected = x + 1 + if torch.allclose(y, expected): + print(f" [PASS] Verified {n} elements correctly") + else: + print(f" [FAIL] Verification failed, max error: {(y - expected).abs().max().item()}") + return 1 + + # Test mul_two kernel + print("\n[Test 2] mul_two kernel") + n = 512 + x = torch.arange(n, dtype=torch.float32, device="cuda") * 0.5 + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Input shape: {x.shape}, device: {x.device}") + mul_two(x, y) + + # Verify results + expected = x * 2 + if torch.allclose(y, expected): + print(f" [PASS] Verified {n} elements correctly") + else: + print(f" [FAIL] Verification failed, max error: {(y - expected).abs().max().item()}") + return 1 + + # Test chained execution + print("\n[Test 3] Chained execution: (x + 1) * 2") + n = 256 + x = torch.full((n,), 10.0, dtype=torch.float32, device="cuda") + temp = torch.empty(n, dtype=torch.float32, device="cuda") + y = torch.empty(n, dtype=torch.float32, device="cuda") + + print(f" Initial value: {x[0].item()}") + add_one(x, temp) # temp = x + 1 = 11 + mul_two(temp, y) # y = temp * 2 = 22 + + expected = 22.0 + if torch.allclose(y, torch.tensor(expected, device="cuda")): + print(f" [PASS] Result: {y[0].item()}") + else: + print(f" [FAIL] Expected {expected}, got {y[0].item()}") + return 1 + + print("\n[PASS] All tests passed!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/src/.gitignore b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/.gitignore new file mode 100644 index 00000000..65644644 --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/.gitignore @@ -0,0 +1 @@ +kernel_fatbin.h diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/src/kernel.cu b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/kernel.cu new file mode 100644 index 00000000..452957bf --- /dev/null +++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/kernel.cu @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file examples/cubin_launcher/embedded_cubin/include_bin2c/src/kernel.cu
+ * \brief Simple CUDA kernels for testing cubin_launcher functionality.
+ */
+
+#include <cstdint>
+
+// [kernels.begin]
+/*!
+ * \brief CUDA kernel that adds 1 to each element of an array.
+ *
+ * \param x Input array pointer.
+ * \param y Output array pointer.
+ * \param n Number of elements.
+ */
+extern "C" __global__ void add_one_cuda(const float* x, float* y, int64_t n) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < n) {
+    y[idx] = x[idx] + 1.0f;
+  }
+}
+
+/*!
+ * \brief CUDA kernel that multiplies each element by 2.
+ *
+ * \param x Input array pointer.
+ * \param y Output array pointer.
+ * \param n Number of elements.
+ */
+extern "C" __global__ void mul_two_cuda(const float* x, float* y, int64_t n) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < n) {
+    y[idx] = x[idx] * 2.0f;
+  }
+}
+// [kernels.end]
diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
new file mode 100644
index 00000000..dd8d578f
--- /dev/null
+++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
+ * \brief TVM-FFI library with embedded CUBIN kernels.
+ *
+ * This library exports TVM-FFI functions to launch CUDA kernels from
+ * embedded CUBIN data.
+ */
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/cubin_launcher.h>
+#include <tvm/ffi/function.h>
+
+#include "kernel_fatbin.h"
+
+// [example.begin]
+TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, imageBytes);
+// [example.end]
+
+namespace cubin_embedded {
+
+/*!
+ * \brief Launch add_one_cuda kernel on input tensor.
+ * \param x Input tensor (float32, 1D)
+ * \param y Output tensor (float32, 1D, same shape as x)
+ */
+void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
+  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_one_cuda");
+
+  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
+  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
+  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
+
+  int64_t n = x.size(0);
+  void* x_ptr = x.data_ptr();
+  void* y_ptr = y.data_ptr();
+
+  // Prepare kernel arguments
+  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
+                  reinterpret_cast<void*>(&n)};
+
+  // Launch configuration
+  tvm::ffi::dim3 grid((n + 255) / 256);
+  tvm::ffi::dim3 block(256);
+
+  // Get CUDA stream
+  DLDevice device = x.device();
+  tvm::ffi::cuda_api::StreamHandle stream = static_cast<tvm::ffi::cuda_api::StreamHandle>(
+      TVMFFIEnvGetStream(device.device_type, device.device_id));
+
+  // Launch kernel
+  tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
+  TVM_FFI_CHECK_CUDA_ERROR(result);
+}
+
+}  // namespace cubin_embedded
+
+namespace cubin_embedded {
+
+/*!
+ * \brief Launch mul_two_cuda kernel on input tensor.
+ * \param x Input tensor (float32, 1D)
+ * \param y Output tensor (float32, 1D, same shape as x)
+ */
+void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+  // Get kernel from embedded CUBIN (cached in static variable for efficiency)
+  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "mul_two_cuda");
+
+  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
+  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
+  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Input and output must have same size";
+
+  int64_t n = x.size(0);
+  void* x_ptr = x.data_ptr();
+  void* y_ptr = y.data_ptr();
+
+  // Prepare kernel arguments
+  void* args[] = {reinterpret_cast<void*>(&x_ptr), reinterpret_cast<void*>(&y_ptr),
+                  reinterpret_cast<void*>(&n)};
+
+  // Launch configuration
+  tvm::ffi::dim3 grid((n + 255) / 256);
+  tvm::ffi::dim3 block(256);
+
+  // Get CUDA stream
+  DLDevice device = x.device();
+  tvm::ffi::cuda_api::StreamHandle stream = static_cast<tvm::ffi::cuda_api::StreamHandle>(
+      TVMFFIEnvGetStream(device.device_type, device.device_id));
+
+  // Launch kernel
+  tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
+  TVM_FFI_CHECK_CUDA_ERROR(result);
+}
+
+// Export TVM-FFI functions
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, cubin_embedded::AddOne);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(mul_two, cubin_embedded::MulTwo);
+
+}  // namespace cubin_embedded
diff --git a/include/tvm/ffi/extra/cuda/base.h b/include/tvm/ffi/extra/cuda/base.h
index 810fa064..d8ea486a 100644
--- a/include/tvm/ffi/extra/cuda/base.h
+++ b/include/tvm/ffi/extra/cuda/base.h
@@ -23,30 +23,36 @@
 #ifndef TVM_FFI_EXTRA_CUDA_BASE_H_
 #define TVM_FFI_EXTRA_CUDA_BASE_H_
 
-#include <cuda_runtime.h>
-#include <tvm/ffi/error.h>
-
 namespace tvm {
 namespace ffi {
 
 /*!
- * \brief Macro for checking CUDA runtime API errors.
- *
- * This macro checks the return value of CUDA runtime API calls and throws
- * a RuntimeError with detailed error information if the call fails.
+ * \brief A simple 3D dimension type for CUDA kernel launch configuration.
  *
- * \param stmt The CUDA runtime API call to check.
+ * This struct mimics the behavior of dim3 from CUDA Runtime API and provides
+ * a compatible interface for kernel launch configuration. It can be constructed
+ * from 1, 2, or 3 dimensions.
  */
-#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                              \
-  do {                                                                              \
-    cudaError_t __err = (stmt);                                                     \
-    if (__err != cudaSuccess) {                                                     \
-      const char* __err_name = cudaGetErrorName(__err);                             \
-      const char* __err_str = cudaGetErrorString(__err);                            \
-      TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " ("   \
-                                  << static_cast<int>(__err) << "): " << __err_str; \
-    }                                                                               \
-  } while (0)
+struct dim3 {
+  /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */
+  unsigned int x;
+  /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */
+  unsigned int y;
+  /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */
+  unsigned int z;
+
+  /*! \brief Default constructor initializes to (1, 1, 1) */
+  dim3() : x(1), y(1), z(1) {}
+
+  /*! \brief Construct with x dimension, y and z default to 1 */
+  explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {}
+
+  /*! \brief Construct with x and y dimensions, z defaults to 1 */
+  dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {}
+
+  /*! \brief Construct with all three dimensions */
+  dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {}
+};
 
 }  // namespace ffi
 }  // namespace tvm
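
The constructors above mirror CUDA's dim3 defaulting rules; a minimal usage sketch, directly from the struct definition:

    tvm::ffi::dim3 g0;            // (1, 1, 1)
    tvm::ffi::dim3 g1(64);        // (64, 1, 1)
    tvm::ffi::dim3 g2(64, 8);     // (64, 8, 1)
    tvm::ffi::dim3 g3(64, 8, 2);  // (64, 8, 2)
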
diff --git a/include/tvm/ffi/extra/cuda/cubin_launcher.h b/include/tvm/ffi/extra/cuda/cubin_launcher.h
index 72eadd2e..c910e894 100644
--- a/include/tvm/ffi/extra/cuda/cubin_launcher.h
+++ b/include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -29,10 +29,12 @@
 #ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
 #define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
 
+#include <tvm/ffi/extra/cuda/base.h>
 #include
 #include
 #include
 #include
+#include <tvm/ffi/extra/cuda/internal/unified_api.h>
 #include
 #include
 
@@ -41,34 +43,6 @@
 namespace tvm {
 namespace ffi {
 
-/*!
- * \brief A simple 3D dimension type for CUDA kernel launch configuration.
- *
- * This struct mimics the behavior of dim3 from CUDA Runtime API and provides
- * a compatible interface for kernel launch configuration. It can be constructed
- * from 1, 2, or 3 dimensions.
- */
-struct dim3 {
-  /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */
-  unsigned int x;
-  /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */
-  unsigned int y;
-  /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */
-  unsigned int z;
-
-  /*! \brief Default constructor initializes to (1, 1, 1) */
-  dim3() : x(1), y(1), z(1) {}
-
-  /*! \brief Construct with x dimension, y and z default to 1 */
-  explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {}
-
-  /*! \brief Construct with x and y dimensions, z defaults to 1 */
-  dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {}
-
-  /*! \brief Construct with all three dimensions */
-  dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {}
-};
-
 /*!
  * \brief Macro to embed a CUBIN module with static initialization.
  *
@@ -191,6 +165,36 @@ struct dim3 {
   };                                                        \
   } /* anonymous namespace */
 
+/*!
+ * \brief Macro to load a CUBIN module from a byte array.
+ *
+ * This macro creates a singleton struct to manage the CubinModule instance
+ * initialized from a byte array (e.g. from `#embed` or bin2c output).
+ *
+ * \par Usage Example
+ * \code{.cpp}
+ * constexpr unsigned char image[] = { ... };
+ * TVM_FFI_EMBED_CUBIN_FROM_BYTES(my_kernels, image);
+ *
+ * void MyFunc() {
+ *   static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "kernel_name");
+ * }
+ * \endcode
+ *
+ * \param name The identifier for this embedded CUBIN module.
+ * \param imageBytes The byte array containing the CUBIN/FATBIN data.
+ */
+#define TVM_FFI_EMBED_CUBIN_FROM_BYTES(name, imageBytes) \
+  namespace {                                            \
+  struct EmbedCubinModule_##name {                       \
+    tvm::ffi::CubinModule mod{imageBytes};               \
+    static EmbedCubinModule_##name* Global() {           \
+      static EmbedCubinModule_##name inst;               \
+      return &inst;                                      \
+    }                                                    \
+  };                                                     \
+  } /* anonymous namespace */
+
 /*!
  * \brief Macro to get a kernel from an embedded CUBIN module.
  *
@@ -291,8 +295,7 @@
    * \param bytes CUBIN binary data as a Bytes object.
    */
   explicit CubinModule(const Bytes& bytes) {
-    TVM_FFI_CHECK_CUDA_ERROR(
-        cudaLibraryLoadData(&library_, bytes.data(), nullptr, nullptr, 0, nullptr, nullptr, 0));
+    TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, bytes.data()));
   }
 
@@ -302,14 +305,23 @@
    * \param code Pointer to CUBIN binary data.
    * \note The `code` buffer points to an ELF image.
    */
   explicit CubinModule(const char* code) {
-    TVM_FFI_CHECK_CUDA_ERROR(
-        cudaLibraryLoadData(&library_, code, nullptr, nullptr, 0, nullptr, nullptr, 0));
+    TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
+  }
+
+  /*!
+   * \brief Load CUBIN module from raw memory buffer.
+   *
+   * \param code Pointer to CUBIN binary data.
+   * \note The `code` buffer points to an ELF image.
+   */
+  explicit CubinModule(const unsigned char* code) {
+    TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
   }
 
   /*! \brief Destructor unloads the library */
   ~CubinModule() {
     if (library_ != nullptr) {
-      cudaLibraryUnload(library_);
+      cuda_api::UnloadLibrary(library_);
     }
   }
 
@@ -343,7 +355,7 @@
   CubinKernel operator[](const char* name);
 
   /*! \brief Get the underlying cudaLibrary_t handle */
-  cudaLibrary_t GetHandle() const { return library_; }
+  cuda_api::LibraryHandle GetHandle() const { return library_; }
 
   // Non-copyable
   CubinModule(const CubinModule&) = delete;
@@ -370,7 +382,7 @@
   CubinModule& operator=(CubinModule&& other) noexcept {
     if (this != &other) {
       if (library_ != nullptr) {
-        cudaLibraryUnload(library_);
+        cuda_api::UnloadLibrary(library_);
       }
       library_ = other.library_;
       other.library_ = nullptr;
@@ -379,7 +391,7 @@
   }
 
  private:
-  cudaLibrary_t library_ = nullptr;
+  cuda_api::LibraryHandle library_ = nullptr;
 };
 
 /*!
@@ -421,8 +433,8 @@
    * \param library The cudaLibrary_t handle.
    * \param name Name of the kernel function.
    */
-  CubinKernel(cudaLibrary_t library, const char* name) {
-    TVM_FFI_CHECK_CUDA_ERROR(cudaLibraryGetKernel(&kernel_, library, name));
+  CubinKernel(cuda_api::LibraryHandle library, const char* name) {
+    TVM_FFI_CHECK_CUDA_ERROR(cuda_api::GetKernel(&kernel_, library, name));
   }
 
   /*! \brief Destructor (kernel handle doesn't need explicit cleanup) */
@@ -466,17 +478,13 @@
    * \note The kernel executes asynchronously. Use cudaStreamSynchronize() or
    *       cudaDeviceSynchronize() to wait for completion if needed.
    */
-  cudaError_t Launch(void** args, dim3 grid, dim3 block, cudaStream_t stream,
-                     uint32_t dyn_smem_bytes = 0) {
-    // Cast cudaKernel_t to const void* for use with cudaLaunchKernel
-    // The Runtime API accepts cudaKernel_t directly as a function pointer
-    auto kernel = reinterpret_cast<const void*>(kernel_);
-    return cudaLaunchKernel(kernel, {grid.x, grid.y, grid.z}, {block.x, block.y, block.z}, args,
-                            dyn_smem_bytes, stream);
+  cuda_api::ResultType Launch(void** args, dim3 grid, dim3 block, cuda_api::StreamHandle stream,
+                              uint32_t dyn_smem_bytes = 0) {
+    return cuda_api::LaunchKernel(kernel_, args, grid, block, stream, dyn_smem_bytes);
   }
 
   /*! \brief Get the underlying cudaKernel_t handle */
-  cudaKernel_t GetHandle() const { return kernel_; }
+  cuda_api::KernelHandle GetHandle() const { return kernel_; }
 
   // Non-copyable
   CubinKernel(const CubinKernel&) = delete;
@@ -527,35 +535,34 @@
    */
   void SetMaxDynamicSharedMemory(int64_t dynamic_smem_max = -1) {
     int device_count = 0;
-    cudaError_t err = cudaGetDeviceCount(&device_count);
-    if (err != cudaSuccess || device_count == 0) {
+    cuda_api::ResultType err = cuda_api::GetDeviceCount(&device_count);
+    if (err != cuda_api::kSuccess || device_count == 0) {
       return;  // No devices available, nothing to configure
     }
 
     bool any_success = false;
     for (int device_id = 0; device_id < device_count; ++device_id) {
+      auto device = cuda_api::GetDeviceHandle(device_id);
       // Query device's maximum shared memory per block
       int max_shared_mem = 0;
-      err = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, device_id);
-      if (err != cudaSuccess) {
+      err = cuda_api::GetDeviceAttribute(
+          &max_shared_mem,
+          /* CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK/cudaDevAttrMaxSharedMemoryPerBlock */
+          cuda_api::DeviceAttrType(8), device);
+      if (err != cuda_api::kSuccess) {
         continue;  // Skip this device if we can't get its attribute
      }
 
       int shared_mem_to_set;
       if (dynamic_smem_max == -1) {
-        // Query the kernel's static shared memory usage
-        cudaFuncAttributes func_attr;
-
-        // According to the documentation, we can use cudaFuncGetAttributes to get the attributes of
-        // cudaKernel_t returned by cudaLibraryGetKernel, just cast the kernel_ to const void*
-        err = cudaFuncGetAttributes(&func_attr, reinterpret_cast<const void*>(kernel_));
-        if (err != cudaSuccess) {
+        int static_shared;
+        err = cuda_api::GetKernelSharedMem(kernel_, static_shared, device);
+        if (err != cuda_api::kSuccess) {
          continue;  // Skip this device if we can't get kernel attributes
         }
 
         // Calculate available dynamic shared memory:
         // device max shared memory - static shared memory used by kernel
-        int64_t static_shared = static_cast<int64_t>(func_attr.sharedSizeBytes);
         int64_t max_shared = static_cast<int64_t>(max_shared_mem);
         int64_t available = max_shared - static_shared;
         shared_mem_to_set = (available > 0) ? static_cast<int>(available) : 0;
@@ -564,9 +571,8 @@
       }
 
       // Set the maximum dynamic shared memory size for this device
-      err = cudaKernelSetAttributeForDevice(kernel_, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                            shared_mem_to_set, device_id);
-      if (err == cudaSuccess) {
+      err = cuda_api::SetKernelMaxDynamicSharedMem(kernel_, shared_mem_to_set, device);
+      if (err == cuda_api::kSuccess) {
         any_success = true;
       }
       // Don't error out for individual device failures - user may only use some GPUs
@@ -578,7 +584,7 @@
     }
   }
 
-  cudaKernel_t kernel_ = nullptr;
+  cuda_api::KernelHandle kernel_ = nullptr;
 
   friend class CubinModule;
 };
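
As the Launch() note above says, the kernel runs asynchronously. A minimal host-side sketch (a Runtime API build is assumed here, and mod, args, n, and stream are placeholders), grounded in CubinModule::operator[] and CubinKernel::Launch from this header:

    // Hypothetical host-side use of CubinKernel::Launch.
    tvm::ffi::CubinKernel kernel = mod["add_one_cuda"];
    tvm::ffi::dim3 grid((n + 255) / 256), block(256);
    TVM_FFI_CHECK_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
    // Launch returns immediately; block only if the host needs the results now.
    TVM_FFI_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
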
diff --git a/include/tvm/ffi/extra/cuda/internal/unified_api.h b/include/tvm/ffi/extra/cuda/internal/unified_api.h
new file mode 100644
index 00000000..302d8ef1
--- /dev/null
+++ b/include/tvm/ffi/extra/cuda/internal/unified_api.h
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_FFI_EXTRA_CUDA_INTERNAL_UNIFIED_API_H_
+#define TVM_FFI_EXTRA_CUDA_INTERNAL_UNIFIED_API_H_
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <tvm/ffi/error.h>
+
+// ===========================================================================
+// Section 1: Configuration & Version Checks
+// ===========================================================================
+
+// We only use the unified API for the cubin launcher for now;
+// this name is intentional to avoid confusion with other API usages.
+#ifndef TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+#if CUDART_VERSION >= 12080
+// Use the Runtime API by default if possible (CUDA >= 12.8)
+#define TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API 0
+#else  // if CUDART_VERSION < 12080
+#define TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API 1
+#endif
+#else  // if defined(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API)
+// User explicitly defined the macro, check compatibility
+#if (!(TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API)) && (CUDART_VERSION < 12080)
+#define _STRINGIFY(x) #x
+#define STR(x) _STRINGIFY(x)
+static_assert(false, "Runtime API only supported for CUDA >= 12.8, got CUDA Runtime version: " STR(
+                         CUDART_VERSION));
+#undef STR
+#undef _STRINGIFY
+#endif
+#endif
+
+namespace tvm {
+namespace ffi {
+namespace cuda_api {
+
+// ===========================================================================
+// Section 2: Type Definitions & Macros
+// ===========================================================================
+
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+
+// Driver API Types
+using StreamHandle = CUstream;
+using DeviceHandle = CUdevice;
+using LibraryHandle = CUlibrary;
+using KernelHandle = CUkernel;
+using LaunchConfig = CUlaunchConfig;
+
+using ResultType = CUresult;
+using LaunchAttrType = CUlaunchAttribute;
+using DeviceAttrType = CUdevice_attribute;
+
+constexpr ResultType kSuccess = CUDA_SUCCESS;
+
+// Driver API Functions
+#define _TVM_FFI_CUDA_FUNC(name) cu##name
+
+#else
+
+using StreamHandle = cudaStream_t;
+using DeviceHandle = int;
+using LibraryHandle = cudaLibrary_t;
+using KernelHandle = cudaKernel_t;
+using LaunchConfig = cudaLaunchConfig_t;
+
+using ResultType = cudaError_t;
+using LaunchAttrType = cudaLaunchAttribute;
+using DeviceAttrType = cudaDeviceAttr;
+
+constexpr ResultType kSuccess = cudaSuccess;
+
+// Runtime API Functions
+#define _TVM_FFI_CUDA_FUNC(name) cuda##name
+
+#endif
+
+// ===========================================================================
+// Section 3: Error Handling
+// ===========================================================================
+
+// Helper to get the error name/string based on the active API
+inline void GetErrorString(ResultType err, const char** name, const char** str) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  cuGetErrorName(err, name);
+  cuGetErrorString(err, str);
+#else
+  *name = cudaGetErrorName(err);
+  *str = cudaGetErrorString(err);
+#endif
+}
+
+#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                        \
+  do {                                                                        \
+    ::tvm::ffi::cuda_api::ResultType __err = (stmt);                          \
+    if (__err != ::tvm::ffi::cuda_api::kSuccess) {                            \
+      const char *__err_name, *__err_str;                                     \
+      ::tvm::ffi::cuda_api::GetErrorString(__err, &__err_name, &__err_str);   \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Error: " << __err_name << " ("     \
+                                  << static_cast<int>(__err) << "): " << __err_str; \
+    }                                                                         \
+  } while (0)
+
+// ===========================================================================
+// Section 4: Unified API Wrappers
+// ===========================================================================
+
+inline ResultType LoadLibrary(LibraryHandle* library, const void* image) {
+  return _TVM_FFI_CUDA_FUNC(LibraryLoadData)(library, image, nullptr, nullptr, 0, nullptr, nullptr,
+                                             0);
+}
+
+inline ResultType UnloadLibrary(LibraryHandle library) {
+  return _TVM_FFI_CUDA_FUNC(LibraryUnload)(library);
+}
+
+inline ResultType GetKernel(KernelHandle* kernel, LibraryHandle library, const char* name) {
+  return _TVM_FFI_CUDA_FUNC(LibraryGetKernel)(kernel, library, name);
+}
+
+inline DeviceHandle GetDeviceHandle(int device_id) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  CUdevice dev;
+  // Note: We use CHECK here because this conversion usually shouldn't fail if the ID is valid,
+  // and we need to return a value.
+  TVM_FFI_CHECK_CUDA_ERROR(cuDeviceGet(&dev, device_id));
+  return dev;
+#else
+  return device_id;
+#endif
+}
+
+inline ResultType LaunchKernel(KernelHandle kernel, void** args, tvm::ffi::dim3 grid,
+                               tvm::ffi::dim3 block, StreamHandle stream,
+                               uint32_t dyn_smem_bytes = 0) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  return cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), grid.x, grid.y, grid.z, block.x,
+                        block.y, block.z, dyn_smem_bytes, stream, args, nullptr);
+#else
+  return cudaLaunchKernel(reinterpret_cast<const void*>(kernel), {grid.x, grid.y, grid.z},
+                          {block.x, block.y, block.z}, args, dyn_smem_bytes, stream);
+#endif
+}
+
+inline ResultType GetKernelSharedMem(KernelHandle kernel, int& out, DeviceHandle device) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  return cuKernelGetAttribute(&out, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, kernel, device);
+#else
+  cudaFuncAttributes func_attr;
+  cudaError_t err = cudaFuncGetAttributes(&func_attr, kernel);
+  if (err == cudaSuccess) {
+    out = static_cast<int>(func_attr.sharedSizeBytes);
+  }
+  return err;
+#endif
+}
+
+inline ResultType SetKernelMaxDynamicSharedMem(KernelHandle kernel, int shmem,
+                                               DeviceHandle device) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  return cuKernelSetAttribute(CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shmem, kernel,
+                              device);
+#else
+  return cudaKernelSetAttributeForDevice(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                         shmem, device);
+#endif
+}
+
+// Additional wrappers for device operations used in CubinLauncher
+inline ResultType GetDeviceCount(int* count) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  return cuDeviceGetCount(count);
+#else
+  return cudaGetDeviceCount(count);
+#endif
+}
+
+inline ResultType GetDeviceAttribute(int* value, DeviceAttrType attr, DeviceHandle device) {
+  return _TVM_FFI_CUDA_FUNC(DeviceGetAttribute)(value, attr, device);
+}
+
+}  // namespace cuda_api
+}  // namespace ffi
+}  // namespace tvm
+
+#endif  // TVM_FFI_EXTRA_CUDA_INTERNAL_UNIFIED_API_H_
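
Taken together, a downstream library built against this patch needs no direct cuda*/cu* calls of its own. A minimal sketch, assuming kernel_fatbin.h is the bin2c output from the CMake example above and that it defines the imageBytes array (args, n, and stream are placeholders):

    #include <tvm/ffi/extra/cuda/cubin_launcher.h>
    #include "kernel_fatbin.h"  // assumed bin2c output defining imageBytes

    TVM_FFI_EMBED_CUBIN_FROM_BYTES(env, imageBytes);

    void LaunchAddOne(void** args, int64_t n, tvm::ffi::cuda_api::StreamHandle stream) {
      // Resolved once; dispatches to the Driver or Runtime API depending on
      // TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API (driver for CUDA < 12.8 by default).
      static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_one_cuda");
      TVM_FFI_CHECK_CUDA_ERROR(
          kernel.Launch(args, tvm::ffi::dim3((n + 255) / 256), tvm::ffi::dim3(256), stream));
    }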