diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml
index 1bed6513..641d3a8c 100644
--- a/.github/workflows/ci_test.yml
+++ b/.github/workflows/ci_test.yml
@@ -47,30 +47,13 @@ jobs:
lint:
needs: [prepare]
- if: >
- needs.prepare.outputs.should_skip_ci_commit != 'true' &&
- needs.prepare.outputs.should_skip_ci_docs_only != 'true'
- name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- with:
- submodules: recursive
- - uses: actions/setup-python@v6
- with:
- python-version: '3.13'
- - name: Install dependencies
- run: |
- pip install black pylint ruff
- sudo apt-get install -y clang-format-15
-
- - name: Lint
- run: |
- tests/scripts/task_lint.sh
-
+ - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
test:
- needs: [prepare]
+ needs: [lint, prepare]
if: >
needs.prepare.outputs.should_skip_ci_commit != 'true' &&
needs.prepare.outputs.should_skip_ci_docs_only != 'true'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dc24845b..f31f7f3f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,8 +15,29 @@
# specific language governing permissions and limitations
# under the License.
+# TODO(@junrushao): adding a few extra hooks:
+# - Python type checking via mypy or ty
+# - CMake linters
+# - Conventional commits
+default_install_hook_types:
+ - pre-commit
repos:
- # Standard hooks
+ - repo: local
+ hooks:
+ - id: check-asf-header
+ name: check ASF Header
+ entry: python tests/lint/check_asf_header.py --check
+ language: system
+ pass_filenames: false
+ verbose: false
+ - repo: local
+ hooks:
+ - id: check-file-type
+ name: check file types
+ entry: python tests/lint/check_file_type.py
+ language: system
+ pass_filenames: false
+ verbose: false
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
@@ -28,3 +49,31 @@ repos:
- id: mixed-line-ending
- id: requirements-txt-fixer
- id: trailing-whitespace
+ - id: check-yaml
+ - id: check-toml
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.12.3
+ hooks:
+ - id: ruff-check
+ types_or: [python, pyi, jupyter]
+ args: [--fix]
+ - id: ruff-format
+ types_or: [python, pyi, jupyter]
+ - repo: https://github.com/pre-commit/mirrors-clang-format
+ rev: "v20.1.8"
+ hooks:
+ - id: clang-format
+ - repo: https://github.com/MarcoGorelli/cython-lint
+ rev: v0.16.7
+ hooks:
+ - id: cython-lint
+ args: [--max-line-length=120]
+ - id: double-quote-cython-strings
+ - repo: https://github.com/scop/pre-commit-shfmt
+ rev: v3.12.0-2
+ hooks:
+ - id: shfmt
+ - repo: https://github.com/shellcheck-py/shellcheck-py
+ rev: v0.10.0.1
+ hooks:
+ - id: shellcheck
diff --git a/README.md b/README.md
index 525518e3..88083e99 100644
--- a/README.md
+++ b/README.md
@@ -18,5 +18,3 @@
# tvm ffi
[](https://github.com/apache/tvm-ffi/actions/workflows/ci_test.yml)
-
-
diff --git a/docs/_static/custom.css b/docs/_static/custom.css
index 6277c6d4..088d4992 100644
--- a/docs/_static/custom.css
+++ b/docs/_static/custom.css
@@ -2,4 +2,4 @@
See: https://github.com/executablebooks/sphinx-book-theme/issues/732 */
#rtd-footer-container {
margin: 0px !important;
-}
\ No newline at end of file
+}
diff --git a/docs/conf.py b/docs/conf.py
index 8a1c5b36..e6218781 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -16,7 +16,6 @@
# under the License.
# -*- coding: utf-8 -*-
import os
-import sys
import tomli
@@ -198,7 +197,7 @@ def footer_html():
@@ -226,6 +225,6 @@ def footer_html():
"conf_py_path": "/docs/",
}
-html_static_path = ['_static']
+html_static_path = ["_static"]
-html_css_files = ['custom.css']
+html_css_files = ["custom.css"]
diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index 449c1db4..043ea289 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -279,4 +279,3 @@ The main takeaway points are:
- **ffi::Tensor** is a universal tensor structure that enables zero-copy exchange of array data
- **Module loading** is provided by tvm ffi APIs in multiple languages.
- **C ABI** is provided for easy low-level integration
-
diff --git a/docs/guides/compiler_integration.md b/docs/guides/compiler_integration.md
index a1355aff..7338dbf2 100644
--- a/docs/guides/compiler_integration.md
+++ b/docs/guides/compiler_integration.md
@@ -112,14 +112,14 @@ with various kernel DSLs and libraries.
## Runtime and State Management for Compilers
-While TVM FFI provides a standard ABI for compiler-generated kernels, many compilers and domain-specific languages
-(DSLs) require their own **runtime** to manage states like dynamic shapes, workspace memory, or other
-application-specific data. This runtime can be a separate shared library accessible to all kernels from a specific
+While TVM FFI provides a standard ABI for compiler-generated kernels, many compilers and domain-specific languages
+(DSLs) require their own **runtime** to manage states like dynamic shapes, workspace memory, or other
+application-specific data. This runtime can be a separate shared library accessible to all kernels from a specific
compiler.
### Recommended Approach for State Management
-The recommended approach for managing compiler-specific state is to define the state within a **separate shared library**.
+The recommended approach for managing compiler-specific state is to define the state within a **separate shared library**.
This library exposes its functionality by registering functions as global `tvm::ffi::Function`s.
Here's a breakdown of the process:
@@ -144,21 +144,21 @@ Here's a breakdown of the process:
This method allows both C++ and Python to access the runtime state through a consistent API.
3. **Access State from Kernels**: Within your compiler-generated kernels, you can use
`GetGlobalRequired("mylang.get_global_state")` in C++ or the C equivalent
- `TVMFFIGetGlobalFunction("mylang.get_global_state", ...)` to get the function and then call it to retrieve the state
+ `TVMFFIGetGlobalFunction("mylang.get_global_state", ...)` to get the function and then call it to retrieve the state
pointer.
### Distributing the Runtime
-For a user to use a kernel from your compiler, they must have access to your runtime library. The preferred method is to
-package the runtime shared library (e.g., `libmylang_runtime.so`) as part of a Python or C++ package. Users must install
-and import this package before loading any kernels compiled by your system.
+For a user to use a kernel from your compiler, they must have access to your runtime library. The preferred method is to
+package the runtime shared library (e.g., `libmylang_runtime.so`) as part of a Python or C++ package. Users must install
+and import this package before loading any kernels compiled by your system.
This approach ensures the state is shared among different kernels.
### Common vs. Custom State
-It's important to distinguish between compiler-specific state and **common state** managed by TVM FFI. TVM FFI handles
-common states like **streams** and **memory allocators** through environment functions (e.g., `TVMFFIEnvGetStream`),
-allowing kernels to access these without managing their own. However, for any unique state required by your compiler,
+It's important to distinguish between compiler-specific state and **common state** managed by TVM FFI. TVM FFI handles
+common states like **streams** and **memory allocators** through environment functions (e.g., `TVMFFIEnvGetStream`),
+allowing kernels to access these without managing their own. However, for any unique state required by your compiler,
the global function registration approach is the most suitable method.
## Advanced: Custom Modules
@@ -196,4 +196,3 @@ the overall import relations from `` and return the final composed
As long as the compiler generates the `__tvm_ffi__library_bin` in the above format, {py:func}`tvm_ffi.load_module` will correctly
handle the loading and recover the original module. Note that we will need the custom module class definition to be available
during loading, either by importing another runtime DLL, or embedding it in the generated library.
-
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index fdf03a50..cd997af1 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -178,7 +178,7 @@ torch.testing.assert_close(x + 1, y)
```
The above code defines a C++ function `add_one_cpu` in Python script, compiles it on the fly and then loads the compiled
-{py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You can then call the function `add_one_cpu`
+{py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You can then call the function `add_one_cpu`
from the module as usual.
## Error Handling
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 74784b51..55a85655 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
autodocsumm
-exhale
breathe
+exhale
linkify-it-py
matplotlib
myst-parser
diff --git a/examples/packaging/python/my_ffi_extension/_ffi_api.py b/examples/packaging/python/my_ffi_extension/_ffi_api.py
index 616b1ee8..5e034899 100644
--- a/examples/packaging/python/my_ffi_extension/_ffi_api.py
+++ b/examples/packaging/python/my_ffi_extension/_ffi_api.py
@@ -17,7 +17,6 @@
import tvm_ffi
# make sure lib is loaded first
-from .base import _LIB
# this is a short cut to register all the global functions
# prefixed by `my_ffi_extension.` to this module
diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py
index e126af14..698bc2af 100644
--- a/examples/quick_start/run_example.py
+++ b/examples/quick_start/run_example.py
@@ -21,7 +21,6 @@
except ImportError:
torch = None
-import ctypes
import numpy
diff --git a/examples/quick_start/run_example.sh b/examples/quick_start/run_example.sh
index 0602b85f..09d8daa8 100755
--- a/examples/quick_start/run_example.sh
+++ b/examples/quick_start/run_example.sh
@@ -1,3 +1,4 @@
+#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#!/bin/bash
set -ex
cmake -B build -S .
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 3dcdf4f8..0ab4d08d 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -272,16 +272,16 @@ typedef struct {
*/
uint32_t small_str_len;
};
- union { // 8 bytes
- int64_t v_int64; // integers
- double v_float64; // floating-point numbers
- void* v_ptr; // typeless pointers
- const char* v_c_str; // raw C-string
- TVMFFIObject* v_obj; // ref counted objects
- DLDataType v_dtype; // data type
- DLDevice v_device; // device
- char v_bytes[8]; // small string
- uint64_t v_uint64; // uint64 repr mainly used for hashing
+ union { // 8 bytes
+ int64_t v_int64; // integers
+ double v_float64; // floating-point numbers
+ void* v_ptr; // typeless pointers
+ const char* v_c_str; // raw C-string
+ TVMFFIObject* v_obj; // ref counted objects
+ DLDataType v_dtype; // data type
+ DLDevice v_device; // device
+ char v_bytes[8]; // small string
+ uint64_t v_uint64; // uint64 repr mainly used for hashing
};
} TVMFFIAny;
diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h
index 6a1a9b55..c0d984f1 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -113,7 +113,7 @@ class AttachFieldFlag : public FieldInfoTrait {
* \returns The byteoffset
*/
template
-TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
+TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
int64_t field_offset_to_class =
reinterpret_cast(&(static_cast(nullptr)->*field_ptr));
return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass();
@@ -350,7 +350,7 @@ class ObjectDef : public ReflectionDefBase {
* \return The reflection definition.
*/
template
- TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
+ TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
RegisterField(name, field_ptr, false, std::forward(extra)...);
return *this;
}
@@ -369,7 +369,7 @@ class ObjectDef : public ReflectionDefBase {
* \return The reflection definition.
*/
template
- TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
+ TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, Extra&&... extra) {
static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
RegisterField(name, field_ptr, true, std::forward(extra)...);
return *this;
@@ -430,7 +430,7 @@ class ObjectDef : public ReflectionDefBase {
}
template
- void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable,
+ void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
ExtraArgs&&... extra_args) {
static_assert(std::is_base_of_v, "BaseClass must be a base class of Class");
TVMFFIFieldInfo info;
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 9bafe2b7..b3b070fb 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TVM FFI Python package."""
+
# order matters here so we need to skip isort here
# isort: skip_file
# base always go first to load the libtvm_ffi
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 7c7b515e..cf311b20 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Conversion utilities to bring python objects into ffi values."""
+
from numbers import Number
from typing import Any
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 30409e41..1664d981 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""dtype class."""
+
# pylint: disable=invalid-name
from enum import IntEnum
@@ -83,7 +84,10 @@ def with_lanes(self, lanes):
The new dtype with the given number of lanes.
"""
cdtype = core._create_dtype_from_tuple(
- core.DataType, self.__tvm_ffi_dtype__.type_code, self.__tvm_ffi_dtype__.bits, lanes
+ core.DataType,
+ self.__tvm_ffi_dtype__.type_code,
+ self.__tvm_ffi_dtype__.bits,
+ lanes,
)
val = str.__new__(dtype, str(cdtype))
val.__tvm_ffi_dtype__ = cdtype
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 1c2326c0..f9314dd7 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""FFI API."""
+
from . import registry
registry.init_ffi_api("ffi", __name__)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 6ff77cc4..b96e9d06 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -29,6 +29,7 @@
This module will load slowly at first time due to JITing,
subsequent calls will be much faster.
"""
+
import warnings
from . import libinfo
diff --git a/python/tvm_ffi/base.py b/python/tvm_ffi/base.py
index fbcc01c3..8099955f 100644
--- a/python/tvm_ffi/base.py
+++ b/python/tvm_ffi/base.py
@@ -16,6 +16,7 @@
# under the License.
# coding: utf-8
"""Base library for TVM FFI."""
+
import ctypes
import logging
import os
diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py
index 4e87caaa..7e036806 100644
--- a/python/tvm_ffi/config.py
+++ b/python/tvm_ffi/config.py
@@ -37,16 +37,28 @@ def __main__():
description="Get various configuration information needed to compile with tvm-ffi"
)
- parser.add_argument("--includedir", action="store_true", help="Print include directory")
parser.add_argument(
- "--dlpack-includedir", action="store_true", help="Print dlpack include directory"
+ "--includedir", action="store_true", help="Print include directory"
+ )
+ parser.add_argument(
+ "--dlpack-includedir",
+ action="store_true",
+ help="Print dlpack include directory",
+ )
+ parser.add_argument(
+ "--cmakedir", action="store_true", help="Print library directory"
+ )
+ parser.add_argument(
+ "--sourcedir", action="store_true", help="Print source directory"
+ )
+ parser.add_argument(
+ "--libfiles", action="store_true", help="Fully qualified library filenames"
)
- parser.add_argument("--cmakedir", action="store_true", help="Print library directory")
- parser.add_argument("--sourcedir", action="store_true", help="Print source directory")
- parser.add_argument("--libfiles", action="store_true", help="Fully qualified library filenames")
parser.add_argument("--libdir", action="store_true", help="Print library directory")
parser.add_argument("--libs", action="store_true", help="Libraries to be linked")
- parser.add_argument("--cython-lib-path", action="store_true", help="Print cython path")
+ parser.add_argument(
+ "--cython-lib-path", action="store_true", help="Print cython path"
+ )
parser.add_argument("--cxxflags", action="store_true", help="Print cxx flags")
parser.add_argument("--cflags", action="store_true", help="Print c flags")
parser.add_argument("--ldflags", action="store_true", help="Print ld flags")
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index f64028f5..8368cd41 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Container classes."""
+
import collections.abc
from typing import Any, Mapping, Sequence
@@ -248,4 +249,8 @@ def __repr__(self):
# exception safety handling for chandle=None
if self.__chandle__() == 0:
return type(self).__name__ + "(chandle=None)"
- return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}"
+ return (
+ "{"
+ + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()])
+ + "}"
+ )
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index ced97058..6ce3d11d 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -86,7 +86,9 @@ def _find_cuda_home() -> Optional[str]:
else:
# Guess #3
if IS_WINDOWS:
- cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
+ cuda_homes = glob.glob(
+ "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
+ )
if len(cuda_homes) == 0:
cuda_home = ""
else:
@@ -162,7 +164,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
raise FileNotFoundError("No Visual Studio installation found.")
# Construct the path to the VsDevCmd.bat file
- vsdevcmd_path = os.path.join(vs_install_path, "Common7", "Tools", "VsDevCmd.bat")
+ vsdevcmd_path = os.path.join(
+ vs_install_path, "Common7", "Tools", "VsDevCmd.bat"
+ )
if not os.path.exists(vsdevcmd_path):
raise FileNotFoundError(f"VsDevCmd.bat not found at: {vsdevcmd_path}")
@@ -175,7 +179,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output):
)
# Execute the command in a new shell
- return subprocess.run(cmd_command, cwd=cwd, capture_output=capture_output, shell=True)
+ return subprocess.run(
+ cmd_command, cwd=cwd, capture_output=capture_output, shell=True
+ )
except (FileNotFoundError, subprocess.CalledProcessError) as e:
raise RuntimeError(
@@ -217,7 +223,11 @@ def _generate_ninja_build(
"/EHsc",
]
default_cuda_cflags = ["-Xcompiler", "/std:c++17", "/O2"]
- default_ldflags = ["/DLL", f"/LIBPATH:{tvm_ffi_lib_path}", f"{tvm_ffi_lib_name}.lib"]
+ default_ldflags = [
+ "/DLL",
+ f"/LIBPATH:{tvm_ffi_lib_path}",
+ f"{tvm_ffi_lib_name}.lib",
+ ]
else:
default_cflags = ["-std=c++17", "-fPIC", "-O2"]
default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"]
@@ -226,12 +236,17 @@ def _generate_ninja_build(
if with_cuda:
# determine the compute capability of the current GPU
default_cuda_cflags += [_get_cuda_target()]
- default_ldflags += ["-L{}".format(os.path.join(_find_cuda_home(), "lib64")), "-lcudart"]
+ default_ldflags += [
+ "-L{}".format(os.path.join(_find_cuda_home(), "lib64")),
+ "-lcudart",
+ ]
cflags = default_cflags + [flag.strip() for flag in extra_cflags]
cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags]
ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
- include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths]
+ include_paths = default_include_paths + [
+ os.path.abspath(path) for path in extra_include_paths
+ ]
# append include paths
for path in include_paths:
@@ -241,7 +256,9 @@ def _generate_ninja_build(
# flags
ninja = []
ninja.append("ninja_required_version = 1.3")
- ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++")))
+ ninja.append(
+ "cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))
+ )
ninja.append("cflags = {}".format(" ".join(cflags)))
if with_cuda:
ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin", "nvcc")))
@@ -290,7 +307,9 @@ def _generate_ninja_build(
)
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda else ""))
+ ninja.append(
+ "build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda else "")
+ )
ninja.append("")
# default target
@@ -306,7 +325,9 @@ def _build_ninja(build_dir: str) -> None:
if num_workers is not None:
command += ["-j", num_workers]
if IS_WINDOWS:
- status = _run_command_in_dev_prompt(args=command, cwd=build_dir, capture_output=True)
+ status = _run_command_in_dev_prompt(
+ args=command, cwd=build_dir, capture_output=True
+ )
else:
status = subprocess.run(args=command, cwd=build_dir, capture_output=True)
if status.returncode != 0:
@@ -508,7 +529,9 @@ def load_inline(
extra_ldflags,
extra_include_paths,
)
- build_dir: str = os.path.join(build_directory, "{}_{}".format(name, source_hash))
+ build_dir: str = os.path.join(
+ build_directory, "{}_{}".format(name, source_hash)
+ )
else:
build_dir = os.path.abspath(build_directory)
os.makedirs(build_dir, exist_ok=True)
@@ -536,4 +559,6 @@ def load_inline(
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- return load_module(os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext))))
+ return load_module(
+ os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext)))
+ )
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 77c9c7e8..6fe10fd6 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -108,7 +108,6 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74
-
ctypedef void* TVMFFIObjectHandle
ctypedef struct TVMFFIObject:
@@ -153,9 +152,9 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
- ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept;
- ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept;
- ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept;
+ ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
+ ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept
+ ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept
ctypedef struct TVMFFIFieldInfo:
TVMFFIByteArray name
@@ -202,7 +201,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result) nogil
int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call,
- void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
+ void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) nogil
int TVMFFIFunctionSetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) nogil
int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil
@@ -216,17 +215,17 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil
- const TVMFFIByteArray* TVMFFITraceback(
- const char* filename, int lineno, const char* func, int cross_ffi_boundary) nogil;
+ const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno,
+ const char* func, int cross_ffi_boundary) nogil
int TVMFFITensorFromDLPack(DLManagedTensor* src, int32_t require_alignment,
- int32_t require_contiguous, TVMFFIObjectHandle* out) nogil
+ int32_t require_contiguous, TVMFFIObjectHandle* out) nogil
int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* src,
int32_t require_alignment,
int32_t require_contiguous,
TVMFFIObjectHandle* out) nogil
int TVMFFITensorToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil
int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle src,
- DLManagedTensorVersioned** out) nogil
+ DLManagedTensorVersioned** out) nogil
const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil
TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
@@ -241,9 +240,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil
- int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
- TVMFFIStreamHandle stream,
- TVMFFIStreamHandle* opt_out_original_stream) nogil
+ int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+ TVMFFIStreamHandle* opt_out_original_stream) nogil
+
def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
cdef TVMFFIStreamHandle prev_stream = NULL
@@ -256,8 +255,7 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
cdef extern from "tvm_ffi_python_helpers.h":
- # no need to expose fields of the call context
- # setter data structure
+ # no need to expose fields of the call context setter data structure
ctypedef int (*DLPackFromPyObject)(
void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
) except -1
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 85740a06..84c047f4 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -20,6 +20,7 @@ from enum import IntEnum
_CLASS_DEVICE = None
+
def _set_class_device(cls):
global _CLASS_DEVICE
_CLASS_DEVICE = cls
@@ -162,7 +163,6 @@ cdef class Device:
def __hash__(self):
return hash((self.cdevice.device_type, self.cdevice.device_id))
-
def __device_type_name__(self):
return self._DEVICE_TYPE_TO_NAME[self.cdevice.device_type]
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 4656b04e..0036c80b 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -18,6 +18,7 @@
_CLASS_DTYPE = None
+
def _set_class_dtype(cls):
global _CLASS_DTYPE
_CLASS_DTYPE = cls
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index c80e238d..c4662ca1 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -88,9 +88,9 @@ cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject c_dlpack_to_pyobj
raise ValueError("Unhandled type index %d" % type_index)
-##----------------------------------------------------------------------------
-## Helper to simplify calling constructor
-##----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+# Helper to simplify calling constructor
+# ----------------------------------------------------------------------------
cdef inline int ConstructorCall(void* constructor_handle,
PyObject* py_arg_tuple,
void** handle,
@@ -109,9 +109,9 @@ cdef inline int ConstructorCall(void* constructor_handle,
handle[0] = result.v_ptr
return 0
-##----------------------------------------------------------------------------
-## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_
-##----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+# Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_
+# ----------------------------------------------------------------------------
cdef int TVMFFIPyArgSetterTensor_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
@@ -219,8 +219,7 @@ cdef int TVMFFIPyArgSetterDLPack_(
out.v_ptr = temp_chandle
# record the stream from the source framework context when possible
temp_dltensor = TVMFFITensorGetDLTensorPtr(temp_chandle)
- if (temp_dltensor.device.device_type != kDLCPU and
- ctx.device_type != -1):
+ if (temp_dltensor.device.device_type != kDLCPU and ctx.device_type != -1):
# __tvm_ffi_env_stream__ returns the expected stream that should be set
# through TVMFFIEnvSetStream when calling a TVM FFI function
if hasattr(arg, "__tvm_ffi_env_stream__"):
@@ -571,9 +570,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
out.func = TVMFFIPyArgSetterFallback_
return 0
-#---------------------------------------------------------------------------------------------
-## Implementation of function calling
-#---------------------------------------------------------------------------------------------
+# ---------------------------------------------------------------------------------------------
+# Implementation of function calling
+# ---------------------------------------------------------------------------------------------
cdef class Function(Object):
"""Python class that wraps a function with tvm-ffi ABI.
@@ -591,6 +590,7 @@ cdef class Function(Object):
property release_gil:
def __get__(self):
return self.c_release_gil != 0
+
def __set__(self, value):
self.c_release_gil = value
@@ -747,7 +747,7 @@ def _get_global_func(name, allow_missing):
return ret
if allow_missing:
- return None
+ return None
raise ValueError("Cannot find global function %s" % name)
@@ -835,7 +835,7 @@ def _convert_to_opaque_object(object pyobject):
def _print_debug_info():
"""Get the size of the dispatch map"""
- cdef size_t size = TVMFFIPyGetDispatchMapSize()
+ cdef size_t size = TVMFFIPyGetDispatchMapSize()
print(f"TVMFFIPyGetDispatchMapSize: {size}")
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 90821e34..f47018ca 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -250,6 +250,7 @@ def _object_type_key_to_index(str type_key):
return tidx
return None
+
cdef inline str _type_index_to_key(int32_t tindex):
"""get the type key of object class"""
cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex)
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index 4119e7b8..0f9d11ba 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -28,7 +28,6 @@ cdef inline bytes _bytes_obj_get_py_bytes(obj):
return bytearray_to_bytes(bytes)
-
class String(str, PyNativeObject):
__slots__ = ["__tvm_ffi_object__"]
"""String object that is possibly returned by FFI call.
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 1255f0b0..dc8b75ee 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -106,7 +106,7 @@ cdef inline int _from_dlpack_universal(
# move to false as most frameworks get upgraded.
cdef int favor_legacy_dlpack = True
- if hasattr(ext_tensor, '__dlpack__'):
+ if hasattr(ext_tensor, "__dlpack__"):
if favor_legacy_dlpack:
_from_dlpack(
ext_tensor.__dlpack__(),
@@ -305,6 +305,7 @@ cdef class DLTensorTestWrapper:
cdef Tensor tensor
cdef dict __dict__
+
def __init__(self, tensor):
self.tensor = tensor
diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py
index b09200d6..cec6956e 100644
--- a/python/tvm_ffi/error.py
+++ b/python/tvm_ffi/error.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Error handling."""
+
import ast
import re
import sys
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index 55c51406..8325c355 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -75,7 +75,9 @@ def find_libtvm_ffi():
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_found:
- raise RuntimeError(f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}")
+ raise RuntimeError(
+ f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}"
+ )
return lib_found[0]
@@ -108,7 +110,9 @@ def find_include_path():
"""Find header files for C compilation."""
candidates = [
os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"),
- os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"),
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"
+ ),
]
for candidate in candidates:
if os.path.isdir(candidate):
@@ -130,12 +134,19 @@ def find_python_helper_include_path():
def find_dlpack_include_path():
"""Find dlpack header files for C compilation."""
- install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include")
+ install_include_path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "include"
+ )
if os.path.isdir(os.path.join(install_include_path, "dlpack")):
return install_include_path
source_include_path = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "..", "..", "3rdparty", "dlpack", "include"
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "3rdparty",
+ "dlpack",
+ "include",
)
if os.path.isdir(source_include_path):
return source_include_path
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 103956e6..fbfb35de 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -200,7 +200,9 @@ def is_compilation_exportable(self):
b : Bool
True if the module is compilation exportable.
"""
- return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0
+ return (
+ self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE
+ ) != 0
def clear_imports(self):
"""Remove all imports of the module."""
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 2cd1ba14..f31dea34 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""FFI registry to register function and objects."""
+
import sys
from . import core
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 598afcac..084dca83 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Stream context."""
+
from ctypes import c_void_p
from typing import Any, Optional, Union
diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py
index 3b3197e2..55ab41f3 100644
--- a/python/tvm_ffi/utils/lockfile.py
+++ b/python/tvm_ffi/utils/lockfile.py
@@ -64,7 +64,9 @@ def acquire(self):
)
msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1)
else: # Unix-like systems
- self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT)
+ self._file_descriptor = os.open(
+ self.lock_file_path, os.O_WRONLY | os.O_CREAT
+ )
fcntl.flock(self._file_descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB)
return True
except (IOError, BlockingIOError):
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index d5d9bf5b..48df9541 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Helper tool to add ASF header to files that cannot be handled by Rat."""
+
import argparse
import fnmatch
import os
@@ -187,7 +188,9 @@ def get_git_files():
if result.returncode == 0:
return [line.strip() for line in result.stdout.split("\n") if line.strip()]
else:
- print("Error: Could not get git files. Make sure you're in a git repository.")
+ print(
+ "Error: Could not get git files. Make sure you're in a git repository."
+ )
print("Git command failed:", result.stderr.strip())
return None
except FileNotFoundError:
@@ -343,7 +346,9 @@ def main():
)
parser.add_argument(
- "--check", action="store_true", help="Check mode: report errors without modifying files"
+ "--check",
+ action="store_true",
+ help="Check mode: report errors without modifying files",
)
parser.add_argument(
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index b44d5f1a..d6664703 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Helper tool to check file types that are allowed to checkin."""
+
import os
import subprocess
import sys
@@ -180,7 +181,7 @@ def main():
cmd = ["git", "ls-files"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- assert proc.returncode == 0, f'{" ".join(cmd)} errored: {out}'
+ assert proc.returncode == 0, f"{' '.join(cmd)} errored: {out}"
res = out.decode("utf-8")
flist = res.split()
error_list = []
@@ -211,11 +212,14 @@ def main():
if asf_copyright_list:
report = "------File type check report----\n"
report += "\n".join(asf_copyright_list) + "\n"
- report += "------Found %d files that has ASF header with copyright message----\n" % len(
- asf_copyright_list
+ report += (
+ "------Found %d files that has ASF header with copyright message----\n"
+ % len(asf_copyright_list)
)
report += "--- Files with ASF header do not need Copyright lines.\n"
- report += "--- Contributors retain copyright to their contribution by default.\n"
+ report += (
+ "--- Contributors retain copyright to their contribution by default.\n"
+ )
report += "--- If a file comes with a different license, consider put it under the 3rdparty folder instead.\n"
report += "---\n"
report += "--- You can use the following steps to remove the copyright lines\n"
diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh
index 70b3c5b4..fee48039 100755
--- a/tests/lint/git-clang-format.sh
+++ b/tests/lint/git-clang-format.sh
@@ -19,77 +19,74 @@ set -e
set -u
set -o pipefail
-
INPLACE_FORMAT=${INPLACE_FORMAT:=false}
LINT_ALL_FILES=true
REVISION=$(git rev-list --max-parents=0 HEAD)
-while (( $# )); do
- case "$1" in
- -i)
- INPLACE_FORMAT=true
- shift 1
- ;;
- --rev)
- LINT_ALL_FILES=false
- REVISION=$2
- shift 2
- ;;
- *)
- echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev ]"
- echo ""
- echo "Run clang-format on files that changed since or on all files in the repo"
- echo "Examples:"
- echo "- Compare last one commit: tests/lint/git-clang-format.sh --rev HEAD~1"
- echo "- Compare against upstream/main: tests/lint/git-clang-format.sh --rev upstream/main"
- echo "The -i will use black to format files in-place instead of checking them."
- exit 1
- ;;
- esac
+while (($#)); do
+ case "$1" in
+ -i)
+ INPLACE_FORMAT=true
+ shift 1
+ ;;
+ --rev)
+ LINT_ALL_FILES=false
+ REVISION=$2
+ shift 2
+ ;;
+ *)
+ echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev ]"
+ echo ""
+ echo "Run clang-format on files that changed since or on all files in the repo"
+ echo "Examples:"
+ echo "- Compare last one commit: tests/lint/git-clang-format.sh --rev HEAD~1"
+ echo "- Compare against upstream/main: tests/lint/git-clang-format.sh --rev upstream/main"
+ echo "The -i will use black to format files in-place instead of checking them."
+ exit 1
+ ;;
+ esac
done
-
-cleanup()
-{
- if [ -f /tmp/$$.clang-format.txt ]; then
- echo ""
- echo "---------clang-format log----------"
- cat /tmp/$$.clang-format.txt
- fi
- rm -rf /tmp/$$.clang-format.txt
+cleanup() {
+ if [ -f /tmp/$$.clang-format.txt ]; then
+ echo ""
+ echo "---------clang-format log----------"
+ cat /tmp/$$.clang-format.txt
+ fi
+ rm -rf /tmp/$$.clang-format.txt
}
trap cleanup 0
CLANG_FORMAT=clang-format-15
if [ -x "$(command -v clang-format-15)" ]; then
- CLANG_FORMAT=clang-format-15
+ CLANG_FORMAT=clang-format-15
elif [ -x "$(command -v clang-format)" ]; then
- echo "clang-format might be different from clang-format-15, expect potential difference."
- CLANG_FORMAT=clang-format
+ echo "clang-format might be different from clang-format-15, expect potential difference."
+ CLANG_FORMAT=clang-format
else
- echo "Cannot find clang-format-15"
- exit 1
+ echo "Cannot find clang-format-15"
+ exit 1
fi
# Print out specific version
${CLANG_FORMAT} --version
if [[ "$INPLACE_FORMAT" == "true" ]]; then
- echo "Running inplace git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION"
- exit 0
+ echo "Running inplace git-clang-format against $REVISION"
+ git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION"
+ exit 0
fi
if [[ "$LINT_ALL_FILES" == "true" ]]; then
- echo "Running git-clang-format against all C++ files"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt
+ echo "Running git-clang-format against all C++ files"
+ git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
else
- echo "Running git-clang-format against $REVISION"
- git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1> /tmp/$$.clang-format.txt
+ echo "Running git-clang-format against $REVISION"
+ git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt
fi
-if grep --quiet -E "diff" < /tmp/$$.clang-format.txt; then
- echo "clang-format lint error found. Consider running clang-format-15 on these files to fix them."
- exit 1
+if grep --quiet -E "diff"