From 73b20a1df517aa1b6304d3387f9a4e4c713d1944 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 17 Sep 2025 11:26:52 -0700 Subject: [PATCH 1/3] chore: more modernization --- examples/inline_module/main.py | 1 - .../python/my_ffi_extension/__init__.py | 8 +- .../python/my_ffi_extension/_ffi_api.py | 1 + pyproject.toml | 77 ++++++++-------- python/tvm_ffi/__init__.py | 36 ++++---- python/tvm_ffi/_convert.py | 3 +- python/tvm_ffi/_dtype.py | 10 +- python/tvm_ffi/_optional_torch_c_dlpack.py | 7 +- python/tvm_ffi/_tensor.py | 8 +- python/tvm_ffi/access_path.py | 34 ++++--- python/tvm_ffi/config.py | 26 ++---- python/tvm_ffi/container.py | 21 +++-- python/tvm_ffi/cpp/load_inline.py | 59 +++++------- python/tvm_ffi/error.py | 19 ++-- python/tvm_ffi/libinfo.py | 15 +-- python/tvm_ffi/module.py | 27 ++++-- python/tvm_ffi/registry.py | 29 +++--- python/tvm_ffi/serialization.py | 10 +- python/tvm_ffi/stream.py | 9 +- python/tvm_ffi/testing.py | 12 +-- python/tvm_ffi/utils/lockfile.py | 27 ++---- tests/lint/check_asf_header.py | 29 +++--- tests/lint/check_file_type.py | 12 +-- tests/lint/git-clang-format.sh | 92 ------------------- tests/python/test_access_path.py | 27 ++---- tests/python/test_container.py | 3 +- tests/python/test_device.py | 5 +- tests/python/test_dtype.py | 1 - tests/python/test_error.py | 1 - tests/python/test_function.py | 3 +- tests/python/test_object.py | 5 +- tests/python/test_stream.py | 1 - tests/python/test_tensor.py | 1 - tests/scripts/benchmark_dlpack.py | 73 +++++---------- tests/scripts/task_lint.sh | 46 ---------- 35 files changed, 266 insertions(+), 472 deletions(-) delete mode 100755 tests/lint/git-clang-format.sh delete mode 100755 tests/scripts/task_lint.sh diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py index 98b939e4..2477cffc 100644 --- a/examples/inline_module/main.py +++ b/examples/inline_module/main.py @@ -16,7 +16,6 @@ # under the License. import torch - import tvm_ffi.cpp from tvm_ffi.module import Module diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py index 766a0990..d629d635 100644 --- a/examples/packaging/python/my_ffi_extension/__init__.py +++ b/examples/packaging/python/my_ffi_extension/__init__.py @@ -21,8 +21,7 @@ def add_one(x, y): - """ - Adds one to the input tensor. + """Adds one to the input tensor. Parameters ---------- @@ -30,13 +29,13 @@ def add_one(x, y): The input tensor. y : Tensor The output tensor. + """ return _LIB.add_one(x, y) def raise_error(msg): - """ - Raises an error with the given message. + """Raises an error with the given message. Parameters ---------- @@ -47,5 +46,6 @@ def raise_error(msg): ------ RuntimeError The error raised by the function. + """ return _ffi_api.raise_error(msg) diff --git a/examples/packaging/python/my_ffi_extension/_ffi_api.py b/examples/packaging/python/my_ffi_extension/_ffi_api.py index 5e034899..edc76775 100644 --- a/examples/packaging/python/my_ffi_extension/_ffi_api.py +++ b/examples/packaging/python/my_ffi_extension/_ffi_api.py @@ -17,6 +17,7 @@ import tvm_ffi # make sure lib is loaded first +from .base import _LIB # noqa: F401 # this is a short cut to register all the global functions # prefixed by `my_ffi_extension.` to this module diff --git a/pyproject.toml b/pyproject.toml index e047529a..94c63bbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ cmake.build-type = "Release" cmake.args = [ "-DTVM_FFI_ATTACH_DEBUG_SYMBOLS=ON", "-DTVM_FFI_BUILD_TESTS=OFF", - "-DTVM_FFI_BUILD_PYTHON_MODULE=ON" + "-DTVM_FFI_BUILD_PYTHON_MODULE=ON", ] # Logging @@ -106,36 +106,42 @@ sdist.include = [ "/tests/**/*", ] -sdist.exclude = ["**/.git", "**/.github", "**/__pycache__", "**/*.pyc", "build", "dist"] +sdist.exclude = [ + "**/.git", + "**/.github", + "**/__pycache__", + "**/*.pyc", + "build", + "dist", +] [tool.pytest.ini_options] testpaths = ["tests"] -[tool.black] -line-length = 100 -skip-magic-trailing-comma = true - -exclude = ''' -/( - \.venv - | build - | docs - | dist - | 3rdparty/* -)/ -''' - -[tool.isort] -profile = "black" -src_paths = ["python", "tests"] -extend_skip = ["3rdparty"] -line_length = 100 -skip_gitignore = true - [tool.ruff] include = ["python/**/*.py", "tests/**/*.py"] +line-length = 100 +indent-width = 4 +target-version = "py39" [tool.ruff.lint] +select = [ + "UP", # pyupgrade, https://docs.astral.sh/ruff/rules/#pyupgrade-up + "PL", # pylint, https://docs.astral.sh/ruff/rules/#pylint-pl + "I", # isort, https://docs.astral.sh/ruff/rules/#isort-i + "RUF", # ruff, https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf + "NPY", # numpy, https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy + "F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f + # "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann + # "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth + # "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d +] +ignore = [ + "PLR2004", # pylint: magic-value-comparison + "ANN401", # flake8-annotations: any-type +] +fixable = ["ALL"] +unfixable = [] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] @@ -144,26 +150,23 @@ include = ["python/**/*.py", "tests/**/*.py"] [tool.ruff.lint.pylint] max-args = 10 +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = false +docstring-code-line-length = "dynamic" + [tool.cibuildwheel] build-verbosity = 1 # only build up to cp312, cp312 # will be abi3 and can be used in future versions -build = [ - "cp39-*", - "cp310-*", - "cp311-*", - "cp312-*", -] -skip = [ - "*musllinux*" -] +build = ["cp39-*", "cp310-*", "cp311-*", "cp312-*"] +skip = ["*musllinux*"] # we only need to test on cp312 -test-skip = [ - "cp39-*", - "cp310-*", - "cp311-*", -] +test-skip = ["cp39-*", "cp310-*", "cp311-*"] # focus on testing abi3 wheel build-frontend = "build[uv]" test-command = "pytest {package}/tests/python -vvs" diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py index b3b070fb..23813631 100644 --- a/python/tvm_ffi/__init__.py +++ b/python/tvm_ffi/__init__.py @@ -47,31 +47,31 @@ from . import _optional_torch_c_dlpack __all__ = [ - "dtype", + "Array", + "DLDeviceType", "Device", + "Device", + "Function", + "Map", + "Module", "Object", - "register_object", - "register_global_func", - "get_global_func", - "remove_global_func", - "init_ffi_api", "Object", "ObjectConvertible", - "Function", + "Shape", + "Tensor", + "access_path", "convert", - "register_error", - "Device", "device", - "DLDeviceType", + "dtype", "from_dlpack", - "Tensor", - "Shape", - "Array", - "Map", - "testing", - "access_path", + "get_global_func", + "init_ffi_api", + "load_module", + "register_error", + "register_global_func", + "register_object", + "remove_global_func", "serialization", - "Module", "system_lib", - "load_module", + "testing", ] diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py index cf311b20..cdeccbc2 100644 --- a/python/tvm_ffi/_convert.py +++ b/python/tvm_ffi/_convert.py @@ -22,7 +22,7 @@ from . import container, core -def convert(value: Any) -> Any: +def convert(value: Any) -> Any: # noqa: PLR0911 """Convert a python object to ffi values. Parameters @@ -40,6 +40,7 @@ def convert(value: Any) -> Any: Function arguments to ffi function calls are automatically converted. So this function is mainly only used in internal or testing scenarios. + """ if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)): return value diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py index 1664d981..a76d1115 100644 --- a/python/tvm_ffi/_dtype.py +++ b/python/tvm_ffi/_dtype.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from enum import IntEnum +from typing import Any, ClassVar from . import core @@ -54,11 +55,12 @@ class dtype(str): ---- This class subclasses str so it can be directly passed into other array api's dtype arguments. + """ __slots__ = ["__tvm_ffi_dtype__"] - _NUMPY_DTYPE_TO_STR = {} + _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {} def __new__(cls, content): content = str(content) @@ -70,8 +72,7 @@ def __repr__(self): return f"dtype('{self}')" def with_lanes(self, lanes): - """ - Create a new dtype with the given number of lanes. + """Create a new dtype with the given number of lanes. Parameters ---------- @@ -82,6 +83,7 @@ def with_lanes(self, lanes): ------- dtype The new dtype with the given number of lanes. + """ cdtype = core._create_dtype_from_tuple( core.DataType, @@ -128,7 +130,7 @@ def lanes(self): dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" if hasattr(np, "float_"): - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" except ImportError: pass diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py index b96e9d06..b8b1f8fa 100644 --- a/python/tvm_ffi/_optional_torch_c_dlpack.py +++ b/python/tvm_ffi/_optional_torch_c_dlpack.py @@ -378,8 +378,8 @@ def load_torch_c_dlpack_extension(): """ try: # optionally import torch - import torch - from torch.utils import cpp_extension + import torch # noqa: PLC0415 + from torch.utils import cpp_extension # noqa: PLC0415 include_paths = libinfo.include_paths() extra_cflags = ["-O3"] @@ -408,8 +408,7 @@ def load_torch_c_dlpack_extension(): pass except Exception as e: warnings.warn( - f"Failed to load torch c dlpack extension: {e}," - "EnvTensorAllocator will not be enabled." + f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator will not be enabled." ) return None diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py index a9212b44..bea20a92 100644 --- a/python/tvm_ffi/_tensor.py +++ b/python/tvm_ffi/_tensor.py @@ -28,10 +28,11 @@ class Shape(tuple, core.PyNativeObject): """Shape tuple that represents `ffi::Shape` returned by a ffi call. - Note + Note: ---- This class subclasses `tuple` so it can be used in most places where tuple is used in python array apis. + """ def __new__(cls, content): @@ -51,7 +52,7 @@ def __from_tvm_ffi_object__(cls, obj): def device(device_type, index=None): - """Construct a TVM FFI device with given device type and index + """Construct a TVM FFI device with given device type and index. Parameters ---------- @@ -74,8 +75,9 @@ def device(device_type, index=None): assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0) assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0) + """ return core._CLASS_DEVICE(device_type, index) -__all__ = ["from_dlpack", "Tensor", "device", "Device", "DLDeviceType"] +__all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"] diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py index 91a426b2..8a453317 100644 --- a/python/tvm_ffi/access_path.py +++ b/python/tvm_ffi/access_path.py @@ -18,7 +18,7 @@ """Access path classes.""" from enum import IntEnum -from typing import Any, List +from typing import Any from . import core from .registry import register_object @@ -35,12 +35,12 @@ class AccessKind(IntEnum): @register_object("ffi.reflection.AccessStep") class AccessStep(core.Object): - """Access step container""" + """Access step container.""" @register_object("ffi.reflection.AccessPath") class AccessPath(core.Object): - """Access path container""" + """Access path container.""" def __init__(self) -> None: super().__init__() @@ -51,7 +51,7 @@ def __init__(self) -> None: @staticmethod def root() -> "AccessPath": - """Create a root access path""" + """Create a root access path.""" return AccessPath._root() def __eq__(self, other: Any) -> bool: @@ -65,7 +65,7 @@ def __ne__(self, other: Any) -> bool: return not self._path_equal(other) def is_prefix_of(self, other: "AccessPath") -> bool: - """Check if this access path is a prefix of another access path + """Check if this access path is a prefix of another access path. Parameters ---------- @@ -76,11 +76,12 @@ def is_prefix_of(self, other: "AccessPath") -> bool: ------- bool True if this access path is a prefix of the other access path, False otherwise + """ return self._is_prefix_of(other) def attr(self, attr_key: str) -> "AccessPath": - """Create an access path to the attribute of the current object + """Create an access path to the attribute of the current object. Parameters ---------- @@ -91,11 +92,12 @@ def attr(self, attr_key: str) -> "AccessPath": ------- AccessPath The extended access path + """ return self._attr(attr_key) def attr_missing(self, attr_key: str) -> "AccessPath": - """Create an access path that indicate an attribute is missing + """Create an access path that indicate an attribute is missing. Parameters ---------- @@ -106,11 +108,12 @@ def attr_missing(self, attr_key: str) -> "AccessPath": ------- AccessPath The extended access path + """ return self._attr_missing(attr_key) def array_item(self, index: int) -> "AccessPath": - """Create an access path to the item of the current array + """Create an access path to the item of the current array. Parameters ---------- @@ -121,11 +124,12 @@ def array_item(self, index: int) -> "AccessPath": ------- AccessPath The extended access path + """ return self._array_item(index) def array_item_missing(self, index: int) -> "AccessPath": - """Create an access path that indicate an array item is missing + """Create an access path that indicate an array item is missing. Parameters ---------- @@ -136,11 +140,12 @@ def array_item_missing(self, index: int) -> "AccessPath": ------- AccessPath The extended access path + """ return self._array_item_missing(index) def map_item(self, key: Any) -> "AccessPath": - """Create an access path to the item of the current map + """Create an access path to the item of the current map. Parameters ---------- @@ -151,11 +156,12 @@ def map_item(self, key: Any) -> "AccessPath": ------- AccessPath The extended access path + """ return self._map_item(key) def map_item_missing(self, key: Any) -> "AccessPath": - """Create an access path that indicate a map item is missing + """Create an access path that indicate a map item is missing. Parameters ---------- @@ -166,16 +172,18 @@ def map_item_missing(self, key: Any) -> "AccessPath": ------- AccessPath The extended access path + """ return self._map_item_missing(key) - def to_steps(self) -> List["AccessStep"]: - """Convert the access path to a list of access steps + def to_steps(self) -> list["AccessStep"]: + """Convert the access path to a list of access steps. Returns ------- List[AccessStep] The list of access steps + """ return self._to_steps() diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py index 7e036806..64a536b3 100644 --- a/python/tvm_ffi/config.py +++ b/python/tvm_ffi/config.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Config utilities for finding paths to lib and headers""" +"""Config utilities for finding paths to lib and headers.""" import argparse import os @@ -31,34 +31,24 @@ def find_windows_implib(): return implib -def __main__(): - """Main function""" +def __main__(): # noqa: PLR0912 + """Main function.""" parser = argparse.ArgumentParser( description="Get various configuration information needed to compile with tvm-ffi" ) - parser.add_argument( - "--includedir", action="store_true", help="Print include directory" - ) + parser.add_argument("--includedir", action="store_true", help="Print include directory") parser.add_argument( "--dlpack-includedir", action="store_true", help="Print dlpack include directory", ) - parser.add_argument( - "--cmakedir", action="store_true", help="Print library directory" - ) - parser.add_argument( - "--sourcedir", action="store_true", help="Print source directory" - ) - parser.add_argument( - "--libfiles", action="store_true", help="Fully qualified library filenames" - ) + parser.add_argument("--cmakedir", action="store_true", help="Print library directory") + parser.add_argument("--sourcedir", action="store_true", help="Print source directory") + parser.add_argument("--libfiles", action="store_true", help="Fully qualified library filenames") parser.add_argument("--libdir", action="store_true", help="Print library directory") parser.add_argument("--libs", action="store_true", help="Libraries to be linked") - parser.add_argument( - "--cython-lib-path", action="store_true", help="Print cython path" - ) + parser.add_argument("--cython-lib-path", action="store_true", help="Print cython path") parser.add_argument("--cxxflags", action="store_true", help="Print cxx flags") parser.add_argument("--cflags", action="store_true", help="Print c flags") parser.add_argument("--ldflags", action="store_true", help="Print ld flags") diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 8368cd41..c77af7f8 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -17,7 +17,8 @@ """Container classes.""" import collections.abc -from typing import Any, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any from . import _ffi_api, core from .registry import register_object @@ -46,6 +47,7 @@ def getitem_helper(obj, elem_getter, length, idx): ------- result : object The result of getitem + """ if isinstance(idx, slice): start = idx.start if idx.start is not None else 0 @@ -88,6 +90,7 @@ class Array(core.Object, collections.abc.Sequence): a = tvm_ffi.convert([1, 2, 3]) assert isinstance(a, tvm_ffi.Array) assert len(a) == 3 + """ def __init__(self, input_list: Sequence[Any]): @@ -107,7 +110,7 @@ def __repr__(self): class KeysView(collections.abc.KeysView): - """Helper class to return keys view""" + """Helper class to return keys view.""" def __init__(self, backend_map): self._backend_map = backend_map @@ -130,7 +133,7 @@ def __contains__(self, k): class ValuesView(collections.abc.ValuesView): - """Helper class to return values view""" + """Helper class to return values view.""" def __init__(self, backend_map): self._backend_map = backend_map @@ -150,7 +153,7 @@ def __iter__(self): class ItemsView(collections.abc.ItemsView): - """Helper class to return items view""" + """Helper class to return items view.""" def __init__(self, backend_map): self.backend_map = backend_map @@ -196,6 +199,7 @@ class Map(core.Object, collections.abc.Mapping): assert len(amap) == 2 assert amap["a"] == 1 assert amap["b"] == 2 + """ def __init__(self, input_dict: Mapping[Any, Any]): @@ -218,7 +222,7 @@ def values(self): return ValuesView(self) def items(self): - """Get the items from the map""" + """Get the items from the map.""" return ItemsView(self) def __len__(self): @@ -242,6 +246,7 @@ def get(self, key, default=None): ------- value: object The result value. + """ return self[key] if key in self else default @@ -249,8 +254,4 @@ def __repr__(self): # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" - return ( - "{" - + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) - + "}" - ) + return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}" diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py index 6ce3d11d..5f4128e1 100644 --- a/python/tvm_ffi/cpp/load_inline.py +++ b/python/tvm_ffi/cpp/load_inline.py @@ -22,7 +22,8 @@ import shutil import subprocess import sys -from typing import Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Optional from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi from tvm_ffi.module import Module, load_module @@ -65,7 +66,7 @@ def _hash_sources( def _maybe_write(path: str, content: str) -> None: """Write content to path if it does not already exist with the same content.""" if os.path.exists(path): - with open(path, "r") as f: + with open(path) as f: existing_content = f.read() if existing_content == content: return @@ -86,9 +87,7 @@ def _find_cuda_home() -> Optional[str]: else: # Guess #3 if IS_WINDOWS: - cuda_homes = glob.glob( - "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*" - ) + cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") if len(cuda_homes) == 0: cuda_home = "" else: @@ -97,8 +96,7 @@ def _find_cuda_home() -> Optional[str]: cuda_home = "/usr/local/cuda" if not os.path.exists(cuda_home): raise RuntimeError( - "Could not find CUDA installation. " - "Please set CUDA_HOME environment variable." + "Could not find CUDA installation. Please set CUDA_HOME environment variable." ) return cuda_home @@ -115,7 +113,6 @@ def _get_cuda_target() -> str: flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}") return " ".join(flags) else: - # try: status = subprocess.run( args=["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], @@ -164,9 +161,7 @@ def _run_command_in_dev_prompt(args, cwd, capture_output): raise FileNotFoundError("No Visual Studio installation found.") # Construct the path to the VsDevCmd.bat file - vsdevcmd_path = os.path.join( - vs_install_path, "Common7", "Tools", "VsDevCmd.bat" - ) + vsdevcmd_path = os.path.join(vs_install_path, "Common7", "Tools", "VsDevCmd.bat") if not os.path.exists(vsdevcmd_path): raise FileNotFoundError(f"VsDevCmd.bat not found at: {vsdevcmd_path}") @@ -180,7 +175,7 @@ def _run_command_in_dev_prompt(args, cwd, capture_output): # Execute the command in a new shell return subprocess.run( - cmd_command, cwd=cwd, capture_output=capture_output, shell=True + cmd_command, check=False, cwd=cwd, capture_output=capture_output, shell=True ) except (FileNotFoundError, subprocess.CalledProcessError) as e: @@ -191,7 +186,7 @@ def _run_command_in_dev_prompt(args, cwd, capture_output): ) from e -def _generate_ninja_build( +def _generate_ninja_build( # noqa: PLR0915 name: str, build_dir: str, with_cuda: bool, @@ -231,7 +226,7 @@ def _generate_ninja_build( else: default_cflags = ["-std=c++17", "-fPIC", "-O2"] default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"] - default_ldflags = ["-shared", "-L{}".format(tvm_ffi_lib_path), "-ltvm_ffi"] + default_ldflags = ["-shared", f"-L{tvm_ffi_lib_path}", "-ltvm_ffi"] if with_cuda: # determine the compute capability of the current GPU @@ -244,9 +239,7 @@ def _generate_ninja_build( cflags = default_cflags + [flag.strip() for flag in extra_cflags] cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags] ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags] - include_paths = default_include_paths + [ - os.path.abspath(path) for path in extra_include_paths - ] + include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths] # append include paths for path in include_paths: @@ -256,9 +249,7 @@ def _generate_ninja_build( # flags ninja = [] ninja.append("ninja_required_version = 1.3") - ninja.append( - "cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++")) - ) + ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))) ninja.append("cflags = {}".format(" ".join(cflags))) if with_cuda: ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin", "nvcc"))) @@ -307,13 +298,11 @@ def _generate_ninja_build( ) # Use appropriate extension based on platform ext = ".dll" if IS_WINDOWS else ".so" - ninja.append( - "build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda else "") - ) + ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda else "")) ninja.append("") # default target - ninja.append("default {}{}".format(name, ext)) + ninja.append(f"default {name}{ext}") ninja.append("") return "\n".join(ninja) @@ -325,18 +314,16 @@ def _build_ninja(build_dir: str) -> None: if num_workers is not None: command += ["-j", num_workers] if IS_WINDOWS: - status = _run_command_in_dev_prompt( - args=command, cwd=build_dir, capture_output=True - ) + status = _run_command_in_dev_prompt(args=command, cwd=build_dir, capture_output=True) else: - status = subprocess.run(args=command, cwd=build_dir, capture_output=True) + status = subprocess.run(check=False, args=command, cwd=build_dir, capture_output=True) if status.returncode != 0: - msg = ["ninja exited with status {}".format(status.returncode)] + msg = [f"ninja exited with status {status.returncode}"] encoding = "oem" if IS_WINDOWS else "utf-8" if status.stdout: - msg.append("stdout:\n{}".format(status.stdout.decode(encoding))) + msg.append(f"stdout:\n{status.stdout.decode(encoding)}") if status.stderr: - msg.append("stderr:\n{}".format(status.stderr.decode(encoding))) + msg.append(f"stderr:\n{status.stderr.decode(encoding)}") raise RuntimeError("\n".join(msg)) @@ -401,7 +388,6 @@ def load_inline( Parameters ---------- - name: str The name of the tvm ffi module. cpp_sources: Sequence[str] | str, optional @@ -483,6 +469,7 @@ def load_inline( y = torch.empty_like(x) mod.add_one_cpu(x, y) torch.testing.assert_close(x + 1, y) + """ if cpp_sources is None: cpp_sources = [] @@ -529,9 +516,7 @@ def load_inline( extra_ldflags, extra_include_paths, ) - build_dir: str = os.path.join( - build_directory, "{}_{}".format(name, source_hash) - ) + build_dir: str = os.path.join(build_directory, f"{name}_{source_hash}") else: build_dir = os.path.abspath(build_directory) os.makedirs(build_dir, exist_ok=True) @@ -559,6 +544,4 @@ def load_inline( # Use appropriate extension based on platform ext = ".dll" if IS_WINDOWS else ".so" - return load_module( - os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext))) - ) + return load_module(os.path.abspath(os.path.join(build_dir, f"{name}{ext}"))) diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py index cec6956e..28a1eadc 100644 --- a/python/tvm_ffi/error.py +++ b/python/tvm_ffi/error.py @@ -26,7 +26,7 @@ def _parse_traceback(traceback): - """Parse the traceback string into a list of (filename, lineno, func) + """Parse the traceback string into a list of (filename, lineno, func). Parameters ---------- @@ -37,6 +37,7 @@ def _parse_traceback(traceback): ------- result : List[Tuple[str, int, str]] The list of (filename, lineno, func) + """ pattern = r'File "(.+?)", line (\d+), in (.+)' result = [] @@ -54,9 +55,7 @@ def _parse_traceback(traceback): class TracebackManager: - """ - Helper to manage traceback generation - """ + """Helper to manage traceback generation.""" def __init__(self): self._code_cache = {} @@ -84,7 +83,7 @@ def _get_cached_code_object(self, filename, lineno, func): return code_object def _create_frame(self, filename, lineno, func): - """Create a frame object from the filename, lineno, and func""" + """Create a frame object from the filename, lineno, and func.""" code_object = self._get_cached_code_object(filename, lineno, func) # call into get frame, but changes the context so the code # points to the correct frame @@ -93,7 +92,7 @@ def _create_frame(self, filename, lineno, func): return eval(code_object, context, context) def append_traceback(self, tb, filename, lineno, func): - """Append a traceback to the given traceback + """Append a traceback to the given traceback. Parameters ---------- @@ -110,6 +109,7 @@ def append_traceback(self, tb, filename, lineno, func): ------- new_tb : types.TracebackType The new traceback with the appended frame. + """ frame = self._create_frame(filename, lineno, func) return types.TracebackType(tb, frame, frame.f_lasti, lineno) @@ -119,7 +119,7 @@ def append_traceback(self, tb, filename, lineno, func): def _with_append_traceback(py_error, traceback): - """Append the traceback to the py_error and return it""" + """Append the traceback to the py_error and return it.""" tb = py_error.__traceback__ for filename, lineno, func in reversed(_parse_traceback(traceback)): tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func) @@ -127,7 +127,7 @@ def _with_append_traceback(py_error, traceback): def _traceback_to_str(tb): - """Convert the traceback to a string""" + """Convert the traceback to a string.""" lines = [] while tb is not None: frame = tb.tb_frame @@ -169,13 +169,14 @@ class MyError(RuntimeError): err_inst = tvm.error.create_ffi_error("MyError: xyz") assert isinstance(err_inst, MyError) + """ if callable(name_or_cls): cls = name_or_cls name_or_cls = cls.__name__ def register(mycls): - """internal register function""" + """Internal register function.""" err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__ core.ERROR_NAME_TO_TYPE[err_name] = mycls core.ERROR_TYPE_TO_NAME[mycls] = err_name diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py index 8325c355..1e09ac6e 100644 --- a/python/tvm_ffi/libinfo.py +++ b/python/tvm_ffi/libinfo.py @@ -35,6 +35,7 @@ def split_env_var(env_var, split): ------- splits : list(string) If env_var exists, split env_var. Otherwise, empty list. + """ if os.environ.get(env_var, None): return [p.strip() for p in os.environ[env_var].split(split)] @@ -42,7 +43,7 @@ def split_env_var(env_var, split): def get_dll_directories(): - """Get the possible dll directories""" + """Get the possible dll directories.""" ffi_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) dll_path = [os.path.join(ffi_dir, "lib")] dll_path += [os.path.join(ffi_dir, "..", "..", "build", "lib")] @@ -75,9 +76,7 @@ def find_libtvm_ffi(): lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] if not lib_found: - raise RuntimeError( - f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}" - ) + raise RuntimeError(f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}") return lib_found[0] @@ -110,9 +109,7 @@ def find_include_path(): """Find header files for C compilation.""" candidates = [ os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), - os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "..", "include" - ), + os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"), ] for candidate in candidates: if os.path.isdir(candidate): @@ -134,9 +131,7 @@ def find_python_helper_include_path(): def find_dlpack_include_path(): """Find dlpack header files for C compilation.""" - install_include_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "include" - ) + install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include") if os.path.isdir(os.path.join(install_include_path, "dlpack")): return install_include_path diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py index fbfb35de..a93cc8af 100644 --- a/python/tvm_ffi/module.py +++ b/python/tvm_ffi/module.py @@ -22,7 +22,7 @@ from . import _ffi_api, core from .registry import register_object -__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"] +__all__ = ["Module", "ModulePropertyMask", "load_module", "system_lib"] class ModulePropertyMask(IntEnum): @@ -37,7 +37,7 @@ class ModulePropertyMask(IntEnum): class Module(core.Object): """Module container for dynamically loaded Module. - Example + Example: ------- .. code-block:: python @@ -48,9 +48,10 @@ class Module(core.Object): # you can use mod.func_name to call the exported function mod.func_name(*args) - See Also + See Also: -------- :py:func:`tvm_ffi.load_module` + """ # constant for entry function name @@ -63,12 +64,13 @@ def kind(self): @property def imports(self): - """Get imported modules + """Get imported modules. Returns - ---------- + ------- modules : list of Module The module + """ return self.imports_ @@ -91,6 +93,7 @@ def implements_function(self, name, query_imports=False): ------- b : Bool True if module (or one of its imports) has a definition for name. + """ return _ffi_api.ModuleImplementsFunction(self, name, query_imports) @@ -118,6 +121,7 @@ def get_function(self, name, query_imports=False): ------- f : tvm_ffi.Function The result function. + """ func = _ffi_api.ModuleGetFunction(self, name, query_imports) if func is None: @@ -131,6 +135,7 @@ def import_module(self, module): ---------- module : tvm.runtime.Module The other module. + """ _ffi_api.ModuleImportModule(self, module) @@ -155,6 +160,7 @@ def inspect_source(self, fmt=""): ------- source : str The result source code. + """ return _ffi_api.ModuleInspectSource(self, fmt) @@ -169,6 +175,7 @@ def get_property_mask(self): ------- mask : int Bitmask of runtime module property + """ return _ffi_api.ModuleGetPropertyMask(self) @@ -179,6 +186,7 @@ def is_binary_serializable(self): ------- b : Bool True if the module is binary serializable. + """ return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 @@ -189,6 +197,7 @@ def is_runnable(self): ------- b : Bool True if the module is runnable. + """ return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 @@ -199,10 +208,9 @@ def is_compilation_exportable(self): ------- b : Bool True if the module is compilation exportable. + """ - return ( - self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE - ) != 0 + return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 def clear_imports(self): """Remove all imports of the module.""" @@ -221,6 +229,7 @@ def write_to_file(self, file_name, fmt=""): See Also -------- runtime.Module.export_library : export the module to shared library. + """ _ffi_api.ModuleWriteToFile(self, file_name, fmt) @@ -245,6 +254,7 @@ def system_lib(symbol_prefix=""): ------- module : runtime.Module The system-wide library module. + """ return _ffi_api.SystemLib(symbol_prefix) @@ -272,5 +282,6 @@ def load_module(path): See Also -------- :py:class:`tvm_ffi.Module` + """ return _ffi_api.ModuleLoadFromFile(path) diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index f31dea34..81960473 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -25,7 +25,7 @@ def register_object(type_key=None): - """register object type. + """Register object type. Parameters ---------- @@ -42,16 +42,17 @@ def register_object(type_key=None): @tvm_ffi.register_object("test.MyObject") class MyObject(Object): pass + """ object_name = type_key if isinstance(type_key, str) else type_key.__name__ def register(cls): - """internal register function""" + """Internal register function.""" type_index = core._object_type_key_to_index(object_name) if type_index is None: if _SKIP_UNKNOWN_OBJECTS: return cls - raise ValueError("Cannot find object type index for %s" % object_name) + raise ValueError(f"Cannot find object type index for {object_name}") core._add_class_attrs_by_reflection(type_index, cls) core._register_object_by_index(type_index, cls) return cls @@ -63,7 +64,7 @@ def register(cls): def register_global_func(func_name, f=None, override=False): - """Register global function + """Register global function. Parameters ---------- @@ -104,6 +105,7 @@ def echo(x): -------- :py:func:`tvm_ffi.get_global_func` :py:func:`tvm_ffi.remove_global_func` + """ if callable(func_name): f = func_name @@ -113,7 +115,7 @@ def echo(x): raise ValueError("expect string function name") def register(myf): - """internal register function""" + """Internal register function.""" return core._register_global_func(func_name, myf, override) if f: @@ -122,7 +124,7 @@ def register(myf): def get_global_func(name, allow_missing=False): - """Get a global function by name + """Get a global function by name. Parameters ---------- @@ -140,6 +142,7 @@ def get_global_func(name, allow_missing=False): See Also -------- :py:func:`tvm_ffi.register_global_func` + """ return core._get_global_func(name, allow_missing) @@ -151,6 +154,7 @@ def list_global_func_names(): ------- names : list List of global functions names. + """ name_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")() num_names = name_functor(-1) @@ -158,18 +162,19 @@ def list_global_func_names(): def remove_global_func(name): - """Remove a global function by name + """Remove a global function by name. Parameters ---------- name : str The name of the global function + """ get_global_func("ffi.FunctionRemoveGlobal")(name) def init_ffi_api(namespace, target_module_name=None): - """Initialize register ffi api functions into a given module + """Initialize register ffi api functions into a given module. Parameters ---------- @@ -181,7 +186,6 @@ def init_ffi_api(namespace, target_module_name=None): Examples -------- - A typical usage pattern is to create a _ffi_api.py file to register the functions under a given module. The following code populates all registered global functions @@ -195,6 +199,7 @@ def init_ffi_api(namespace, target_module_name=None): import tvm_ffi tvm_ffi.init_ffi_api("mypackage", __name__) + """ target_module_name = target_module_name if target_module_name else namespace @@ -219,10 +224,10 @@ def init_ffi_api(namespace, target_module_name=None): __all__ = [ - "register_object", - "register_global_func", "get_global_func", + "init_ffi_api", "list_global_func_names", + "register_global_func", + "register_object", "remove_global_func", - "init_ffi_api", ] diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py index e5367d9f..803d533b 100644 --- a/python/tvm_ffi/serialization.py +++ b/python/tvm_ffi/serialization.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Serialization related utilities to enable some object can be pickled""" +"""Serialization related utilities to enable some object can be pickled.""" from typing import Any, Optional @@ -22,8 +22,7 @@ def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): - """ - Dump an object to a JSON graph string. + """Dump an object to a JSON graph string. The JSON graph string is a string representation of of the object graph includes the reference information of same objects, which can @@ -41,13 +40,13 @@ def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): ------- json_str : str The JSON graph string. + """ return _ffi_api.ToJSONGraphString(obj, metadata) def from_json_graph_str(json_str: str): - """ - Load an object from a JSON graph string. + """Load an object from a JSON graph string. The JSON graph string is a string representation of of the object graph that also includes the reference information. @@ -61,6 +60,7 @@ def from_json_graph_str(json_str: str): ------- obj : Any The loaded object. + """ return _ffi_api.FromJSONGraphString(json_str) diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py index 084dca83..e4ae19bd 100644 --- a/python/tvm_ffi/stream.py +++ b/python/tvm_ffi/stream.py @@ -42,6 +42,7 @@ class StreamContext: See Also -------- :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream` + """ def __init__(self, device: core.Device, stream: Union[int, c_void_p]): @@ -82,8 +83,7 @@ def __exit__(self, *args): self.ffi_context.__exit__(*args) def use_torch_stream(context: Optional[Any] = None): - """ - Create a ffi stream context with given torch stream, + """Create a ffi stream context with given torch stream, cuda graph or current stream if `None` provided. Parameters @@ -111,6 +111,7 @@ def use_torch_stream(context: Optional[Any] = None): Note ---- When working with raw cudaStream_t handle, using :py:func:`tvm_ffi.use_raw_stream` instead. + """ return TorchStreamContext(context) @@ -121,8 +122,7 @@ def use_torch_stream(context: Optional[Any] = None): def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]): - """ - Create a ffi stream context with given device and stream handle. + """Create a ffi stream context with given device and stream handle. Parameters ---------- @@ -140,6 +140,7 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]): Note ---- When working with torch stram or cuda graph, using :py:func:`tvm_ffi.use_torch_stream` instead. + """ if not isinstance(stream, (int, c_void_p)): raise ValueError( diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py index 843a10c8..3c173dc9 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing.py @@ -23,21 +23,16 @@ @register_object("testing.TestObjectBase") class TestObjectBase(Object): - """ - Test object base class. - """ + """Test object base class.""" @register_object("testing.TestObjectDerived") class TestObjectDerived(TestObjectBase): - """ - Test object derived class. - """ + """Test object derived class.""" def create_object(type_key: str, **kwargs) -> Object: - """ - Make an object by reflection. + """Make an object by reflection. Parameters ---------- @@ -55,6 +50,7 @@ def create_object(type_key: str, **kwargs) -> Object: ---- This function is only used for testing purposes and should not be used in other cases. + """ args = [type_key] for k, v in kwargs.items(): diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py index 55ab41f3..581ea829 100644 --- a/python/tvm_ffi/utils/lockfile.py +++ b/python/tvm_ffi/utils/lockfile.py @@ -27,8 +27,7 @@ class FileLock: - """ - A cross-platform file locking mechanism using Python's standard library. + """A cross-platform file locking mechanism using Python's standard library. This class implements an advisory lock, which must be respected by all cooperating processes. """ @@ -38,23 +37,19 @@ def __init__(self, lock_file_path): self._file_descriptor = None def __enter__(self): - """ - Context manager protocol: acquire the lock upon entering the 'with' block. + """Context manager protocol: acquire the lock upon entering the 'with' block. This method will block indefinitely until the lock is acquired. """ self.blocking_acquire() return self def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context manager protocol: release the lock upon exiting the 'with' block. - """ + """Context manager protocol: release the lock upon exiting the 'with' block.""" self.release() return False # Propagate exceptions, if any def acquire(self): - """ - Acquires an exclusive, non-blocking lock on the file. + """Acquires an exclusive, non-blocking lock on the file. Returns True if the lock was acquired, False otherwise. """ try: @@ -64,12 +59,10 @@ def acquire(self): ) msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1) else: # Unix-like systems - self._file_descriptor = os.open( - self.lock_file_path, os.O_WRONLY | os.O_CREAT - ) + self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT) fcntl.flock(self._file_descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB) return True - except (IOError, BlockingIOError): + except (OSError, BlockingIOError): if self._file_descriptor is not None: os.close(self._file_descriptor) self._file_descriptor = None @@ -81,13 +74,13 @@ def acquire(self): raise RuntimeError(f"An unexpected error occurred: {e}") def blocking_acquire(self, timeout=None, poll_interval=0.1): - """ - Waits until an exclusive lock can be acquired, with an optional timeout. + """Waits until an exclusive lock can be acquired, with an optional timeout. Args: timeout (float): The maximum time to wait for the lock in seconds. A value of None means wait indefinitely. poll_interval (float): The time to wait between lock attempts in seconds. + """ start_time = time.time() while True: @@ -103,9 +96,7 @@ def blocking_acquire(self, timeout=None, poll_interval=0.1): time.sleep(poll_interval) def release(self): - """ - Releases the lock and closes the file descriptor. - """ + """Releases the lock and closes the file descriptor.""" if self._file_descriptor is not None: if sys.platform == "win32": msvcrt.locking(self._file_descriptor, msvcrt.LK_UNLCK, 1) diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py index 48df9541..5021212d 100644 --- a/tests/lint/check_asf_header.py +++ b/tests/lint/check_asf_header.py @@ -172,7 +172,7 @@ def should_skip_file(filepath): - """Check if file should be skipped based on SKIP_LIST""" + """Check if file should be skipped based on SKIP_LIST.""" for pattern in SKIP_LIST: if fnmatch.fnmatch(filepath, pattern): return True @@ -180,17 +180,15 @@ def should_skip_file(filepath): def get_git_files(): - """Get list of files tracked by git""" + """Get list of files tracked by git.""" try: result = subprocess.run( - ["git", "ls-files"], capture_output=True, text=True, cwd=os.getcwd() + ["git", "ls-files"], check=False, capture_output=True, text=True, cwd=os.getcwd() ) if result.returncode == 0: return [line.strip() for line in result.stdout.split("\n") if line.strip()] else: - print( - "Error: Could not get git files. Make sure you're in a git repository." - ) + print("Error: Could not get git files. Make sure you're in a git repository.") print("Git command failed:", result.stderr.strip()) return None except FileNotFoundError: @@ -211,7 +209,7 @@ def copyright_line(line): def check_header(fname, header): - """Check header status of file without modifying it""" + """Check header status of file without modifying it.""" if not os.path.exists(fname): print(f"ERROR: Cannot find {fname}") return False @@ -243,7 +241,7 @@ def check_header(fname, header): def collect_files(): - """Collect all files that need header checking from git""" + """Collect all files that need header checking from git.""" files = [] # Get files from git (required) @@ -266,18 +264,17 @@ def collect_files(): if ( suffix in FMT_MAP or basename == "gradle.properties" - or suffix == "" - and basename in ["CMakeLists", "Makefile"] + or (suffix == "" and basename in ["CMakeLists", "Makefile"]) ): files.append(git_file) return files -def add_header(fname, header): - """Add header to file""" +def add_header(fname, header): # noqa: PLR0912 + """Add header to file.""" if not os.path.exists(fname): - print("Cannot find %s ..." % fname) + print(f"Cannot find {fname} ...") return lines = open(fname).readlines() @@ -318,12 +315,12 @@ def add_header(fname, header): outfile.write(header + "\n\n") outfile.write("".join(lines)) if not has_asf_header: - print("Add header to %s" % fname) + print(f"Add header to {fname}") if has_copyright: - print("Removed copyright line from %s" % fname) + print(f"Removed copyright line from {fname}") -def main(): +def main(): # noqa: PLR0911, PLR0912 parser = argparse.ArgumentParser( description="Check and fix ASF headers in source files tracked by git", formatter_class=argparse.RawDescriptionHelpFormatter, diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index d6664703..7db2bc1e 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -128,6 +128,7 @@ def filename_allowed(name): ------- allowed : bool Whether the filename is allowed. + """ arr = name.rsplit(".", 1) if arr[-1] in ALLOW_EXTENSION: @@ -193,7 +194,7 @@ def main(): if error_list: report = "------File type check report----\n" report += "\n".join(error_list) - report += "\nFound %d files that are not allowed\n" % len(error_list) + report += f"\nFound {len(error_list)} files that are not allowed\n" report += ( "We do not check in binary files into the repo.\n" "If necessary, please discuss with committers and" @@ -212,14 +213,9 @@ def main(): if asf_copyright_list: report = "------File type check report----\n" report += "\n".join(asf_copyright_list) + "\n" - report += ( - "------Found %d files that has ASF header with copyright message----\n" - % len(asf_copyright_list) - ) + report += f"------Found {len(asf_copyright_list)} files that has ASF header with copyright message----\n" report += "--- Files with ASF header do not need Copyright lines.\n" - report += ( - "--- Contributors retain copyright to their contribution by default.\n" - ) + report += "--- Contributors retain copyright to their contribution by default.\n" report += "--- If a file comes with a different license, consider put it under the 3rdparty folder instead.\n" report += "---\n" report += "--- You can use the following steps to remove the copyright lines\n" diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh deleted file mode 100755 index fee48039..00000000 --- a/tests/lint/git-clang-format.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -set -e -set -u -set -o pipefail - -INPLACE_FORMAT=${INPLACE_FORMAT:=false} -LINT_ALL_FILES=true -REVISION=$(git rev-list --max-parents=0 HEAD) - -while (($#)); do - case "$1" in - -i) - INPLACE_FORMAT=true - shift 1 - ;; - --rev) - LINT_ALL_FILES=false - REVISION=$2 - shift 2 - ;; - *) - echo "Usage: tests/lint/git-clang-format.sh [-i] [--rev ]" - echo "" - echo "Run clang-format on files that changed since or on all files in the repo" - echo "Examples:" - echo "- Compare last one commit: tests/lint/git-clang-format.sh --rev HEAD~1" - echo "- Compare against upstream/main: tests/lint/git-clang-format.sh --rev upstream/main" - echo "The -i will use black to format files in-place instead of checking them." - exit 1 - ;; - esac -done - -cleanup() { - if [ -f /tmp/$$.clang-format.txt ]; then - echo "" - echo "---------clang-format log----------" - cat /tmp/$$.clang-format.txt - fi - rm -rf /tmp/$$.clang-format.txt -} -trap cleanup 0 - -CLANG_FORMAT=clang-format-15 - -if [ -x "$(command -v clang-format-15)" ]; then - CLANG_FORMAT=clang-format-15 -elif [ -x "$(command -v clang-format)" ]; then - echo "clang-format might be different from clang-format-15, expect potential difference." - CLANG_FORMAT=clang-format -else - echo "Cannot find clang-format-15" - exit 1 -fi - -# Print out specific version -${CLANG_FORMAT} --version - -if [[ "$INPLACE_FORMAT" == "true" ]]; then - echo "Running inplace git-clang-format against $REVISION" - git-${CLANG_FORMAT} --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" - exit 0 -fi - -if [[ "$LINT_ALL_FILES" == "true" ]]; then - echo "Running git-clang-format against all C++ files" - git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt -else - echo "Running git-clang-format against $REVISION" - git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc,cu --binary=${CLANG_FORMAT} "$REVISION" 1>/tmp/$$.clang-format.txt -fi - -if grep --quiet -E "diff" @@ -345,9 +332,7 @@ def load_torch_get_current_cuda_stream(): def bench_torch_get_current_stream(repeat, name, func): - """ - Measures overhead of running torch.cuda.current_stream - """ + """Measures overhead of running torch.cuda.current_stream.""" x = torch.arange(1, device="cuda") # noqa: F841 func(0) start = time.time() @@ -360,14 +345,12 @@ def bench_torch_get_current_stream(repeat, name, func): def populate_object_table(num_classes): nop = tvm_ffi.get_global_func("testing.nop") - dummy_instances = [ - type(f"DummyClass{i}", (object,), {})() for i in range(num_classes) - ] + dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in range(num_classes)] for instance in dummy_instances: nop(instance) -def main(): +def main(): # noqa: PLR0915 repeat = 10000 # measures impact of object dispatch table size # takeaway so far is that there is no impact on the performance @@ -401,12 +384,8 @@ def main(): print("---------------------------------------------------") print("Benchmark x.__dlpack__(max_version=(1,1)) overhead") print("---------------------------------------------------") - bench_to_dlpack_versioned( - torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat - ) - bench_to_dlpack_versioned( - np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat - ) + bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat) + bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat) bench_to_dlpack_versioned( tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", @@ -415,9 +394,7 @@ def main(): print("---------------------------------------------------") print("Benchmark torch.get_cuda_stream[default stream]") print("---------------------------------------------------") - bench_torch_get_current_stream( - repeat, "cpp-extension", load_torch_get_current_cuda_stream() - ) + bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream()) bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) print("---------------------------------------------------") print("Benchmark torch.get_cuda_stream[non-default stream]") diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh deleted file mode 100755 index 5b17cf81..00000000 --- a/tests/scripts/task_lint.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -euxo pipefail - -cleanup() { - rm -rf /tmp/$$.* -} -trap cleanup 0 - -function run_lint { - echo "Checking file types..." - python tests/lint/check_file_type.py - - echo "Checking ASF headers..." - python tests/lint/check_asf_header.py --check - - echo "isort check..." - isort --check --diff . - - echo "black check..." - black --check --diff . - - echo "ruff check..." - ruff check --diff . - - echo "clang-format check..." - tests/lint/git-clang-format.sh -} - -run_lint From 63cc7cf5bf3c0be880a9f6b7c770b6f2a26e0b3e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 17 Sep 2025 15:01:15 -0700 Subject: [PATCH 2/3] enable pathlib checks --- docs/conf.py | 4 +- .../packaging/python/my_ffi_extension/base.py | 8 +-- pyproject.toml | 2 +- python/tvm_ffi/config.py | 14 ++-- python/tvm_ffi/cpp/load_inline.py | 68 +++++++++--------- python/tvm_ffi/libinfo.py | 69 +++++++++---------- tests/lint/check_asf_header.py | 18 ++--- tests/lint/check_file_type.py | 8 +-- 8 files changed, 95 insertions(+), 96 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e6218781..fdb5ca17 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,6 +16,7 @@ # under the License. # -*- coding: utf-8 -*- import os +from pathlib import Path import tomli @@ -25,9 +26,8 @@ # -- General configuration ------------------------------------------------ - # Load version from pyproject.toml -with open("../pyproject.toml", "rb") as f: +with Path("../pyproject.toml").open("rb") as f: pyproject_data = tomli.load(f) __version__ = pyproject_data["project"]["version"] diff --git a/examples/packaging/python/my_ffi_extension/base.py b/examples/packaging/python/my_ffi_extension/base.py index fa172526..2e2d09de 100644 --- a/examples/packaging/python/my_ffi_extension/base.py +++ b/examples/packaging/python/my_ffi_extension/base.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations. # Base logic to load library for extension package -import os import sys +from pathlib import Path import tvm_ffi def _load_lib(): # first look at the directory of the current file - file_dir = os.path.dirname(os.path.realpath(__file__)) + file_dir = Path(__file__).resolve().parent if sys.platform.startswith("win32"): lib_dll_name = "my_ffi_extension.dll" @@ -31,8 +31,8 @@ def _load_lib(): else: lib_dll_name = "my_ffi_extension.so" - lib_path = os.path.join(file_dir, lib_dll_name) - return tvm_ffi.load_module(lib_path) + lib_path = file_dir / lib_dll_name + return tvm_ffi.load_module(str(lib_path)) _LIB = _load_lib() diff --git a/pyproject.toml b/pyproject.toml index 94c63bbc..dff8c028 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ select = [ "NPY", # numpy, https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy "F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f # "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann - # "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth + "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth # "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d ] ignore = [ diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py index 64a536b3..32ad9a51 100644 --- a/python/tvm_ffi/config.py +++ b/python/tvm_ffi/config.py @@ -17,18 +17,18 @@ """Config utilities for finding paths to lib and headers.""" import argparse -import os import sys +from pathlib import Path from . import libinfo def find_windows_implib(): - libdir = os.path.dirname(libinfo.find_libtvm_ffi()) - implib = os.path.join(libdir, "tvm_ffi.lib") - if not os.path.isfile(implib): + libdir = Path(libinfo.find_libtvm_ffi()).parent + implib = libdir / "tvm_ffi.lib" + if not implib.is_file(): raise RuntimeError(f"Cannot find imp lib {implib}") - return implib + return str(implib) def __main__(): # noqa: PLR0912 @@ -67,7 +67,7 @@ def __main__(): # noqa: PLR0912 if args.cmakedir: print(libinfo.find_cmake_path()) if args.libdir: - print(os.path.dirname(libinfo.find_libtvm_ffi())) + print(Path(libinfo.find_libtvm_ffi()).parent) if args.libfiles: if sys.platform.startswith("win32"): print(find_windows_implib()) @@ -92,7 +92,7 @@ def __main__(): # noqa: PLR0912 print("-ltvm_ffi") if args.ldflags: if not sys.platform.startswith("win32"): - print(f"-L{os.path.dirname(libinfo.find_libtvm_ffi())}") + print(f"-L{Path(libinfo.find_libtvm_ffi()).parent}") if __name__ == "__main__": diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py index 5f4128e1..0dd8b97a 100644 --- a/python/tvm_ffi/cpp/load_inline.py +++ b/python/tvm_ffi/cpp/load_inline.py @@ -16,13 +16,13 @@ # under the License. import functools -import glob import hashlib import os import shutil import subprocess import sys from collections.abc import Mapping, Sequence +from pathlib import Path from typing import Optional from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi @@ -65,12 +65,13 @@ def _hash_sources( def _maybe_write(path: str, content: str) -> None: """Write content to path if it does not already exist with the same content.""" - if os.path.exists(path): - with open(path) as f: + p = Path(path) + if p.exists(): + with p.open() as f: existing_content = f.read() if existing_content == content: return - with open(path, "w") as f: + with p.open("w") as f: f.write(content) @@ -83,18 +84,19 @@ def _find_cuda_home() -> Optional[str]: # Guess #2 nvcc_path = shutil.which("nvcc") if nvcc_path is not None: - cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + cuda_home = str(Path(nvcc_path).parent.parent) else: # Guess #3 if IS_WINDOWS: - cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") + cuda_root = Path("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA") + cuda_homes = list(cuda_root.glob("v*.*")) if len(cuda_homes) == 0: cuda_home = "" else: - cuda_home = cuda_homes[0] + cuda_home = str(cuda_homes[0]) else: cuda_home = "/usr/local/cuda" - if not os.path.exists(cuda_home): + if not Path(cuda_home).exists(): raise RuntimeError( "Could not find CUDA installation. Please set CUDA_HOME environment variable." ) @@ -131,14 +133,14 @@ def _run_command_in_dev_prompt(args, cwd, capture_output): """Locates the Developer Command Prompt and runs a command within its environment.""" try: # Path to vswhere.exe - vswhere_path = os.path.join( - os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"), - "Microsoft Visual Studio", - "Installer", - "vswhere.exe", + vswhere_path = str( + Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")) + / "Microsoft Visual Studio" + / "Installer" + / "vswhere.exe" ) - if not os.path.exists(vswhere_path): + if not Path(vswhere_path).exists(): raise FileNotFoundError("vswhere.exe not found.") # Find the Visual Studio installation path @@ -161,9 +163,9 @@ def _run_command_in_dev_prompt(args, cwd, capture_output): raise FileNotFoundError("No Visual Studio installation found.") # Construct the path to the VsDevCmd.bat file - vsdevcmd_path = os.path.join(vs_install_path, "Common7", "Tools", "VsDevCmd.bat") + vsdevcmd_path = str(Path(vs_install_path) / "Common7" / "Tools" / "VsDevCmd.bat") - if not os.path.exists(vsdevcmd_path): + if not Path(vsdevcmd_path).exists(): raise FileNotFoundError(f"VsDevCmd.bat not found at: {vsdevcmd_path}") # Use cmd.exe to run the batch file and then your command. @@ -199,8 +201,8 @@ def _generate_ninja_build( # noqa: PLR0915 default_include_paths = [find_include_path(), find_dlpack_include_path()] tvm_ffi_lib = find_libtvm_ffi() - tvm_ffi_lib_path = os.path.dirname(tvm_ffi_lib) - tvm_ffi_lib_name = os.path.splitext(os.path.basename(tvm_ffi_lib))[0] + tvm_ffi_lib_path = str(Path(tvm_ffi_lib).parent) + tvm_ffi_lib_name = Path(tvm_ffi_lib).stem if IS_WINDOWS: default_cflags = [ "/std:c++17", @@ -232,14 +234,16 @@ def _generate_ninja_build( # noqa: PLR0915 # determine the compute capability of the current GPU default_cuda_cflags += [_get_cuda_target()] default_ldflags += [ - "-L{}".format(os.path.join(_find_cuda_home(), "lib64")), + "-L{}".format(str(Path(_find_cuda_home()) / "lib64")), "-lcudart", ] cflags = default_cflags + [flag.strip() for flag in extra_cflags] cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags] ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags] - include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths] + include_paths = default_include_paths + [ + str(Path(path).resolve()) for path in extra_include_paths + ] # append include paths for path in include_paths: @@ -252,7 +256,7 @@ def _generate_ninja_build( # noqa: PLR0915 ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))) ninja.append("cflags = {}".format(" ".join(cflags))) if with_cuda: - ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin", "nvcc"))) + ninja.append("nvcc = {}".format(str(Path(_find_cuda_home()) / "bin" / "nvcc"))) ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags))) ninja.append("ldflags = {}".format(" ".join(ldflags))) @@ -287,13 +291,13 @@ def _generate_ninja_build( # noqa: PLR0915 # build targets ninja.append( "build main.o: compile {}".format( - os.path.abspath(os.path.join(build_dir, "main.cpp")).replace(":", "$:") + str((Path(build_dir) / "main.cpp").resolve()).replace(":", "$:") ) ) if with_cuda: ninja.append( "build cuda.o: compile_cuda {}".format( - os.path.abspath(os.path.join(build_dir, "cuda.cu")).replace(":", "$:") + str((Path(build_dir) / "cuda.cu").resolve()).replace(":", "$:") ) ) # Use appropriate extension based on platform @@ -505,7 +509,7 @@ def load_inline( # determine the cache dir for the built module if build_directory is None: build_directory = os.environ.get( - "TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi") + "TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser()) ) source_hash: str = _hash_sources( cpp_source, @@ -516,10 +520,10 @@ def load_inline( extra_ldflags, extra_include_paths, ) - build_dir: str = os.path.join(build_directory, f"{name}_{source_hash}") + build_dir: str = str(Path(build_directory) / f"{name}_{source_hash}") else: - build_dir = os.path.abspath(build_directory) - os.makedirs(build_dir, exist_ok=True) + build_dir = str(Path(build_directory).resolve()) + Path(build_dir).mkdir(parents=True, exist_ok=True) # generate build.ninja ninja_source = _generate_ninja_build( @@ -532,16 +536,16 @@ def load_inline( extra_include_paths=extra_include_paths, ) - with FileLock(os.path.join(build_dir, "lock")): + with FileLock(str(Path(build_dir) / "lock")): # write source files and build.ninja if they do not already exist - _maybe_write(os.path.join(build_dir, "main.cpp"), cpp_source) + _maybe_write(str(Path(build_dir) / "main.cpp"), cpp_source) if with_cuda: - _maybe_write(os.path.join(build_dir, "cuda.cu"), cuda_source) - _maybe_write(os.path.join(build_dir, "build.ninja"), ninja_source) + _maybe_write(str(Path(build_dir) / "cuda.cu"), cuda_source) + _maybe_write(str(Path(build_dir) / "build.ninja"), ninja_source) # build the module _build_ninja(build_dir) # Use appropriate extension based on platform ext = ".dll" if IS_WINDOWS else ".so" - return load_module(os.path.abspath(os.path.join(build_dir, f"{name}{ext}"))) + return load_module(str((Path(build_dir) / f"{name}{ext}").resolve())) diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py index 1e09ac6e..3e419474 100644 --- a/python/tvm_ffi/libinfo.py +++ b/python/tvm_ffi/libinfo.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. -import glob import os import sys +from pathlib import Path def split_env_var(env_var, split): @@ -44,11 +44,11 @@ def split_env_var(env_var, split): def get_dll_directories(): """Get the possible dll directories.""" - ffi_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - dll_path = [os.path.join(ffi_dir, "lib")] - dll_path += [os.path.join(ffi_dir, "..", "..", "build", "lib")] + ffi_dir = Path(__file__).expanduser().resolve().parent + dll_path = [ffi_dir / "lib"] + dll_path += [ffi_dir / ".." / ".." / "build" / "lib"] # in source build from parent if needed - dll_path += [os.path.join(ffi_dir, "..", "..", "..", "build", "lib")] + dll_path += [ffi_dir / ".." / ".." / ".." / "build" / "lib"] if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":")) @@ -58,7 +58,7 @@ def get_dll_directories(): dll_path.extend(split_env_var("PATH", ":")) elif sys.platform.startswith("win32"): dll_path.extend(split_env_var("PATH", ";")) - return [os.path.abspath(x) for x in dll_path if os.path.isdir(x)] + return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()] def find_libtvm_ffi(): @@ -72,8 +72,8 @@ def find_libtvm_ffi(): lib_dll_names = ["libtvm_ffi.so"] name = lib_dll_names - lib_dll_path = [os.path.join(p, name) for name in lib_dll_names for p in dll_path] - lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] + lib_dll_path = [str(Path(p) / name) for name in lib_dll_names for p in dll_path] + lib_found = [p for p in lib_dll_path if Path(p).exists() and Path(p).is_file()] if not lib_found: raise RuntimeError(f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}") @@ -84,11 +84,11 @@ def find_libtvm_ffi(): def find_source_path(): """Find packaged source home path.""" candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__))), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", ".."), + str(Path(__file__).resolve().parent), + str(Path(__file__).resolve().parent / ".." / ".."), ] for candidate in candidates: - if os.path.isdir(os.path.join(candidate, "cmake")): + if Path(candidate, "cmake").is_dir(): return candidate raise RuntimeError("Cannot find home path.") @@ -96,11 +96,11 @@ def find_source_path(): def find_cmake_path(): """Find the preferred cmake path.""" candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "cmake"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "cmake"), + str(Path(__file__).resolve().parent / "cmake"), + str(Path(__file__).resolve().parent / ".." / ".." / "cmake"), ] for candidate in candidates: - if os.path.isdir(candidate): + if Path(candidate).is_dir(): return candidate raise RuntimeError("Cannot find cmake path.") @@ -108,11 +108,11 @@ def find_cmake_path(): def find_include_path(): """Find header files for C compilation.""" candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"), + str(Path(__file__).resolve().parent / "include"), + str(Path(__file__).resolve().parent / ".." / ".." / "include"), ] for candidate in candidates: - if os.path.isdir(candidate): + if Path(candidate).is_dir(): return candidate raise RuntimeError("Cannot find include path.") @@ -120,31 +120,26 @@ def find_include_path(): def find_python_helper_include_path(): """Find header files for C compilation.""" candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "cython"), + str(Path(__file__).resolve().parent / "include"), + str(Path(__file__).resolve().parent / "cython"), ] for candidate in candidates: - if os.path.isfile(os.path.join(candidate, "tvm_ffi_python_helpers.h")): + if Path(candidate, "tvm_ffi_python_helpers.h").is_file(): return candidate raise RuntimeError("Cannot find python helper include path.") def find_dlpack_include_path(): """Find dlpack header files for C compilation.""" - install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include") - if os.path.isdir(os.path.join(install_include_path, "dlpack")): - return install_include_path - - source_include_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "3rdparty", - "dlpack", - "include", + install_include_path = Path(__file__).resolve().parent / "include" + if (install_include_path / "dlpack").is_dir(): + return str(install_include_path) + + source_include_path = ( + Path(__file__).resolve().parent / ".." / ".." / "3rdparty" / "dlpack" / "include" ) - if os.path.isdir(source_include_path): - return source_include_path + if source_include_path.is_dir(): + return str(source_include_path) raise RuntimeError("Cannot find include path.") @@ -152,13 +147,13 @@ def find_dlpack_include_path(): def find_cython_lib(): """Find the path to tvm cython.""" path_candidates = [ - os.path.dirname(os.path.realpath(__file__)), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "build"), + Path(__file__).resolve().parent, + Path(__file__).resolve().parent / ".." / ".." / "build", ] suffixes = "pyd" if sys.platform.startswith("win32") else "so" for candidate in path_candidates: - for path in glob.glob(os.path.join(candidate, f"core*.{suffixes}")): - return os.path.abspath(path) + for path in Path(candidate).glob(f"core*.{suffixes}"): + return str(Path(path).resolve()) raise RuntimeError("Cannot find tvm cython path.") diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py index 5021212d..9fcce8b8 100644 --- a/tests/lint/check_asf_header.py +++ b/tests/lint/check_asf_header.py @@ -18,9 +18,9 @@ import argparse import fnmatch -import os import subprocess import sys +from pathlib import Path header_cstyle = """ /* @@ -183,7 +183,7 @@ def get_git_files(): """Get list of files tracked by git.""" try: result = subprocess.run( - ["git", "ls-files"], check=False, capture_output=True, text=True, cwd=os.getcwd() + ["git", "ls-files"], check=False, capture_output=True, text=True, cwd=Path.cwd() ) if result.returncode == 0: return [line.strip() for line in result.stdout.split("\n") if line.strip()] @@ -210,11 +210,11 @@ def copyright_line(line): def check_header(fname, header): """Check header status of file without modifying it.""" - if not os.path.exists(fname): + if not Path(fname).exists(): print(f"ERROR: Cannot find {fname}") return False - lines = open(fname).readlines() + lines = Path(fname).open().readlines() has_asf_header = False has_copyright = False @@ -259,7 +259,7 @@ def collect_files(): # Check if this file type is supported suffix = git_file.split(".")[-1] if "." in git_file else "" - basename = os.path.basename(git_file) + basename = Path(git_file).name if ( suffix in FMT_MAP @@ -273,11 +273,11 @@ def collect_files(): def add_header(fname, header): # noqa: PLR0912 """Add header to file.""" - if not os.path.exists(fname): + if not Path(fname).exists(): print(f"Cannot find {fname} ...") return - lines = open(fname).readlines() + lines = Path(fname).open().readlines() has_asf_header = False has_copyright = False @@ -292,7 +292,7 @@ def add_header(fname, header): # noqa: PLR0912 if has_asf_header and not has_copyright: return - with open(fname, "w") as outfile: + with Path(fname).open("w") as outfile: skipline = False if not lines: skipline = False # File is enpty @@ -397,7 +397,7 @@ def main(): # noqa: PLR0911, PLR0912 for fname in files: processed_count += 1 suffix = fname.split(".")[-1] if "." in fname else "" - basename = os.path.basename(fname) + basename = Path(fname).name # Determine header type if suffix in FMT_MAP: diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 7db2bc1e..9517168c 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -16,9 +16,9 @@ # under the License. """Helper tool to check file types that are allowed to checkin.""" -import os import subprocess import sys +from pathlib import Path # List of file types we allow ALLOW_EXTENSION = { @@ -134,7 +134,7 @@ def filename_allowed(name): if arr[-1] in ALLOW_EXTENSION: return True - if os.path.basename(name) in ALLOW_FILE_NAME: + if Path(name).name in ALLOW_FILE_NAME: return True if name.startswith("3rdparty"): @@ -161,12 +161,12 @@ def copyright_line(line): def check_asf_copyright(fname): if fname.endswith(".png"): return True - if not os.path.isfile(fname): + if not Path(fname).is_file(): return True has_asf_header = False has_copyright = False try: - for line in open(fname): + for line in Path(fname).open(): if line.find("Licensed to the Apache Software Foundation") != -1: has_asf_header = True if copyright_line(line): From 906ded1d3a6e0dd8c539f0e9b2272ddbca5a0a84 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 17 Sep 2025 15:14:26 -0700 Subject: [PATCH 3/3] enable pydocstring --- docs/conf.py | 3 +++ examples/inline_module/main.py | 2 ++ .../python/my_ffi_extension/__init__.py | 5 +++-- .../packaging/python/my_ffi_extension/base.py | 2 ++ examples/packaging/run_example.py | 4 ++++ examples/quick_start/run_example.py | 4 +++- pyproject.toml | 15 ++++++++++++--- python/tvm_ffi/access_path.py | 5 +++++ python/tvm_ffi/config.py | 3 ++- python/tvm_ffi/container.py | 14 +++++++++++++- python/tvm_ffi/cpp/__init__.py | 1 + python/tvm_ffi/cpp/load_inline.py | 1 + python/tvm_ffi/error.py | 3 ++- python/tvm_ffi/libinfo.py | 3 ++- python/tvm_ffi/module.py | 16 ++++++++++++---- python/tvm_ffi/registry.py | 4 ++-- python/tvm_ffi/stream.py | 15 +++++++++++++-- python/tvm_ffi/utils/__init__.py | 1 + python/tvm_ffi/utils/lockfile.py | 15 ++++++++++----- tests/scripts/benchmark_dlpack.py | 16 +++++----------- 20 files changed, 98 insertions(+), 34 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index fdb5ca17..28307114 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Sphinx configuration for the tvm-ffi documentation site.""" + # -*- coding: utf-8 -*- import os from pathlib import Path @@ -181,6 +183,7 @@ def footer_html(): + """Generate HTML for the documentation footer.""" # Create footer HTML with two-line layout # Generate dropdown menu items dropdown_items = "" diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py index 2477cffc..8afa9b51 100644 --- a/examples/inline_module/main.py +++ b/examples/inline_module/main.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Example: Build and run an inline C++/CUDA tvm-ffi module.""" import torch import tvm_ffi.cpp @@ -21,6 +22,7 @@ def main(): + """Build, load, and run inline CPU/CUDA functions.""" mod: Module = tvm_ffi.cpp.load_inline( name="hello", cpp_sources=r""" diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py index d629d635..583945b3 100644 --- a/examples/packaging/python/my_ffi_extension/__init__.py +++ b/examples/packaging/python/my_ffi_extension/__init__.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations. # order matters here so we need to skip isort here # isort: skip_file +"""Public Python API for the example tvm-ffi extension package.""" from .base import _LIB from . import _ffi_api def add_one(x, y): - """Adds one to the input tensor. + """Add one to the input tensor. Parameters ---------- @@ -35,7 +36,7 @@ def add_one(x, y): def raise_error(msg): - """Raises an error with the given message. + """Raise an error with the given message. Parameters ---------- diff --git a/examples/packaging/python/my_ffi_extension/base.py b/examples/packaging/python/my_ffi_extension/base.py index 2e2d09de..5b1546fc 100644 --- a/examples/packaging/python/my_ffi_extension/base.py +++ b/examples/packaging/python/my_ffi_extension/base.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations. # Base logic to load library for extension package +"""Utilities to locate and load the example extension shared library.""" + import sys from pathlib import Path diff --git a/examples/packaging/run_example.py b/examples/packaging/run_example.py index 5304409a..04650ec4 100644 --- a/examples/packaging/run_example.py +++ b/examples/packaging/run_example.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations. # Base logic to load library for extension package +"""Run functions from the example packaged tvm-ffi extension.""" + import sys import my_ffi_extension @@ -21,6 +23,7 @@ def run_add_one(): + """Invoke add_one from the extension and print the result.""" x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) y = torch.empty_like(x) my_ffi_extension.add_one(x, y) @@ -28,6 +31,7 @@ def run_add_one(): def run_raise_error(): + """Invoke raise_error from the extension to demonstrate error handling.""" my_ffi_extension.raise_error("This is an error") diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py index 698bc2af..830c3c71 100644 --- a/examples/quick_start/run_example.py +++ b/examples/quick_start/run_example.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Quick start script to run tvm-ffi examples from prebuilt libraries.""" + import tvm_ffi try: @@ -93,7 +95,7 @@ def run_add_one_cuda(): def main(): - """Main function to run the example.""" + """Run the quick start example.""" run_add_one_cpu() run_add_one_c() run_add_one_cuda() diff --git a/pyproject.toml b/pyproject.toml index dff8c028..7783c75e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,18 +134,27 @@ select = [ "F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f # "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth - # "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d + "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d ] ignore = [ "PLR2004", # pylint: magic-value-comparison "ANN401", # flake8-annotations: any-type + "D203", # pydocstyle: incorrect-blank-line-before-class + "D213", # pydocstyle: multi-line-summary-second-line ] fixable = ["ALL"] unfixable = [] [tool.ruff.lint.per-file-ignores] -"__init__.py" = ["F401"] -"tests/*" = ["E741"] +"__init__.py" = ["F401"] # pyflakes: unused-import +"tests/*" = [ + "E741", # pycodestyle: ambiguous-variable-name + "D100", # pydocstyle: undocumented-public-module + "D101", # pydocstyle: undocumented-public-class + "D103", # pydocstyle: undocumented-public-function + "D107", # pydocstyle: undocumented-public-init + "D205", # pydocstyle: missing-blank-line-after-summary +] [tool.ruff.lint.pylint] max-args = 10 diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py index 8a453317..e8aec104 100644 --- a/python/tvm_ffi/access_path.py +++ b/python/tvm_ffi/access_path.py @@ -25,6 +25,8 @@ class AccessKind(IntEnum): + """Kinds of access steps in an access path.""" + ATTR = 0 ARRAY_ITEM = 1 MAP_ITEM = 2 @@ -43,6 +45,7 @@ class AccessPath(core.Object): """Access path container.""" def __init__(self) -> None: + """Disallow direct construction; use `AccessPath.root()` instead.""" super().__init__() raise ValueError( "AccessPath can't be initialized directly. " @@ -55,11 +58,13 @@ def root() -> "AccessPath": return AccessPath._root() def __eq__(self, other: Any) -> bool: + """Return whether two access paths are equal.""" if not isinstance(other, AccessPath): return False return self._path_equal(other) def __ne__(self, other: Any) -> bool: + """Return whether two access paths are not equal.""" if not isinstance(other, AccessPath): return True return not self._path_equal(other) diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py index 32ad9a51..bc31de43 100644 --- a/python/tvm_ffi/config.py +++ b/python/tvm_ffi/config.py @@ -24,6 +24,7 @@ def find_windows_implib(): + """Find and return the Windows import library path for tvm_ffi.lib.""" libdir = Path(libinfo.find_libtvm_ffi()).parent implib = libdir / "tvm_ffi.lib" if not implib.is_file(): @@ -32,7 +33,7 @@ def find_windows_implib(): def __main__(): # noqa: PLR0912 - """Main function.""" + """Parse CLI args and print build and include configuration paths.""" parser = argparse.ArgumentParser( description="Get various configuration information needed to compile with tvm-ffi" ) diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index c77af7f8..9bb9f97a 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -27,7 +27,7 @@ def getitem_helper(obj, elem_getter, length, idx): - """Helper function to implement a pythonic getitem function. + """Implement a pythonic __getitem__ helper. Parameters ---------- @@ -94,15 +94,19 @@ class Array(core.Object, collections.abc.Sequence): """ def __init__(self, input_list: Sequence[Any]): + """Construct an Array from a Python sequence.""" self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) def __getitem__(self, idx): + """Return one element or a Python list for a slice.""" return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx) def __len__(self): + """Return the number of elements in the array.""" return _ffi_api.ArraySize(self) def __repr__(self): + """Return a string representation of the array.""" # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" @@ -203,6 +207,7 @@ class Map(core.Object, collections.abc.Mapping): """ def __init__(self, input_dict: Mapping[Any, Any]): + """Construct a Map from a Python mapping.""" list_kvs = [] for k, v in input_dict.items(): list_kvs.append(k) @@ -210,15 +215,19 @@ def __init__(self, input_dict: Mapping[Any, Any]): self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs) def __getitem__(self, k): + """Return the value for key `k` or raise KeyError.""" return _ffi_api.MapGetItem(self, k) def __contains__(self, k): + """Return True if the map contains key `k`.""" return _ffi_api.MapCount(self, k) != 0 def keys(self): + """Return a dynamic view of the map's keys.""" return KeysView(self) def values(self): + """Return a dynamic view of the map's values.""" return ValuesView(self) def items(self): @@ -226,9 +235,11 @@ def items(self): return ItemsView(self) def __len__(self): + """Return the number of items in the map.""" return _ffi_api.MapSize(self) def __iter__(self): + """Iterate over the map's keys.""" return iter(self.keys()) def get(self, key, default=None): @@ -251,6 +262,7 @@ def get(self, key, default=None): return self[key] if key in self else default def __repr__(self): + """Return a string representation of the map.""" # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py index 632698f4..ede2b544 100644 --- a/python/tvm_ffi/cpp/__init__.py +++ b/python/tvm_ffi/cpp/__init__.py @@ -14,5 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""C++ integration helpers for building and loading inline modules.""" from .load_inline import load_inline diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py index 0dd8b97a..264a7bb7 100644 --- a/python/tvm_ffi/cpp/load_inline.py +++ b/python/tvm_ffi/cpp/load_inline.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja.""" import functools import hashlib diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py index 28a1eadc..28788eff 100644 --- a/python/tvm_ffi/error.py +++ b/python/tvm_ffi/error.py @@ -58,6 +58,7 @@ class TracebackManager: """Helper to manage traceback generation.""" def __init__(self): + """Initialize the traceback manager and its cache.""" self._code_cache = {} def _get_cached_code_object(self, filename, lineno, func): @@ -176,7 +177,7 @@ class MyError(RuntimeError): name_or_cls = cls.__name__ def register(mycls): - """Internal register function.""" + """Register the error class name with the FFI core.""" err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__ core.ERROR_NAME_TO_TYPE[err_name] = mycls core.ERROR_TYPE_TO_NAME[mycls] = err_name diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py index 3e419474..b707f2bf 100644 --- a/python/tvm_ffi/libinfo.py +++ b/python/tvm_ffi/libinfo.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Utilities to locate tvm_ffi libraries, headers, and helper include paths.""" import os import sys @@ -21,7 +22,7 @@ def split_env_var(env_var, split): - """Splits environment variable string. + """Split an environment variable string. Parameters ---------- diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py index a93cc8af..335e262d 100644 --- a/python/tvm_ffi/module.py +++ b/python/tvm_ffi/module.py @@ -75,7 +75,11 @@ def imports(self): return self.imports_ def implements_function(self, name, query_imports=False): - """Returns True if the module has a definition for the global function with name. Note + """Return True if the module defines a global function. + + Note + ---- + that has_function(name) does not imply get_function(name) is non-null since the module that has_function(name) does not imply get_function(name) is non-null since the module may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function without further compilation. However, get_function(name) non null should always imply @@ -140,11 +144,13 @@ def import_module(self, module): _ffi_api.ModuleImportModule(self, module) def __getitem__(self, name): + """Return function by name using item access (module["func"]).""" if not isinstance(name, str): raise ValueError("Can only take string as function name") return self.get_function(name) def __call__(self, *args): + """Call the module's entry function (`main`).""" # pylint: disable=not-callable return self.main(*args) @@ -180,7 +186,7 @@ def get_property_mask(self): return _ffi_api.ModuleGetPropertyMask(self) def is_binary_serializable(self): - """Module 'binary serializable', save_to_bytes is supported. + """Return whether the module is binary serializable (supports save_to_bytes). Returns ------- @@ -191,7 +197,7 @@ def is_binary_serializable(self): return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 def is_runnable(self): - """Module 'runnable', get_function is supported. + """Return whether the module is runnable (supports get_function). Returns ------- @@ -202,7 +208,9 @@ def is_runnable(self): return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 def is_compilation_exportable(self): - """Module 'compilation exportable', write_to_file is supported for object or source. + """Return whether the module is compilation exportable. + + write_to_file is supported for object or source. Returns ------- diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 81960473..60c8dedd 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -47,7 +47,7 @@ class MyObject(Object): object_name = type_key if isinstance(type_key, str) else type_key.__name__ def register(cls): - """Internal register function.""" + """Register the object type with the FFI core.""" type_index = core._object_type_key_to_index(object_name) if type_index is None: if _SKIP_UNKNOWN_OBJECTS: @@ -115,7 +115,7 @@ def echo(x): raise ValueError("expect string function name") def register(myf): - """Internal register function.""" + """Register the global function with the FFI core.""" return core._register_global_func(func_name, myf, override) if f: diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py index e4ae19bd..81cbabed 100644 --- a/python/tvm_ffi/stream.py +++ b/python/tvm_ffi/stream.py @@ -25,7 +25,8 @@ class StreamContext: - """StreamContext represents a stream context in the ffi system. + """Represent a stream context in the FFI system. + StreamContext helps setup ffi environment stream by python `with` statement. When entering `with` scope, it caches the current environment stream and setup the given new stream. @@ -46,16 +47,19 @@ class StreamContext: """ def __init__(self, device: core.Device, stream: Union[int, c_void_p]): + """Initialize a stream context with a device and stream handle.""" self.device_type = device.dlpack_device_type() self.device_id = device.index self.stream = stream def __enter__(self): + """Enter the context and set the current stream.""" self.prev_stream = core._env_set_current_stream( self.device_type, self.device_id, self.stream ) def __exit__(self, *args): + """Exit the context and restore the previous stream.""" self.prev_stream = core._env_set_current_stream( self.device_type, self.device_id, self.prev_stream ) @@ -65,10 +69,14 @@ def __exit__(self, *args): import torch class TorchStreamContext: + """Context manager that syncs Torch and FFI stream contexts.""" + def __init__(self, context: Optional[Any]): + """Initialize with an optional Torch stream/graph context wrapper.""" self.torch_context = context def __enter__(self): + """Enter both Torch and FFI stream contexts.""" if self.torch_context: self.torch_context.__enter__() current_stream = torch.cuda.current_stream() @@ -78,12 +86,14 @@ def __enter__(self): self.ffi_context.__enter__() def __exit__(self, *args): + """Exit both Torch and FFI stream contexts.""" if self.torch_context: self.torch_context.__exit__(*args) self.ffi_context.__exit__(*args) def use_torch_stream(context: Optional[Any] = None): - """Create a ffi stream context with given torch stream, + """Create an FFI stream context with a Torch stream or graph. + cuda graph or current stream if `None` provided. Parameters @@ -118,6 +128,7 @@ def use_torch_stream(context: Optional[Any] = None): except ImportError: def use_torch_stream(context: Optional[Any] = None): + """Raise an informative error when Torch is unavailable.""" raise ImportError("Cannot import torch") diff --git a/python/tvm_ffi/utils/__init__.py b/python/tvm_ffi/utils/__init__.py index 543bd0f8..896001ec 100644 --- a/python/tvm_ffi/utils/__init__.py +++ b/python/tvm_ffi/utils/__init__.py @@ -14,5 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Utilities used by the tvm_ffi Python package.""" from .lockfile import FileLock diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py index 581ea829..b317f049 100644 --- a/python/tvm_ffi/utils/lockfile.py +++ b/python/tvm_ffi/utils/lockfile.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Simple cross-platform advisory file lock utilities.""" import os import sys @@ -27,18 +28,21 @@ class FileLock: - """A cross-platform file locking mechanism using Python's standard library. + """Provide a cross-platform file locking mechanism using Python's stdlib. + This class implements an advisory lock, which must be respected by all cooperating processes. """ def __init__(self, lock_file_path): + """Initialize a file lock using the given lock file path.""" self.lock_file_path = lock_file_path self._file_descriptor = None def __enter__(self): - """Context manager protocol: acquire the lock upon entering the 'with' block. - This method will block indefinitely until the lock is acquired. + """Acquire the lock upon entering the context. + + This method blocks until the lock is acquired. """ self.blocking_acquire() return self @@ -49,7 +53,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # Propagate exceptions, if any def acquire(self): - """Acquires an exclusive, non-blocking lock on the file. + """Acquire an exclusive, non-blocking lock on the file. + Returns True if the lock was acquired, False otherwise. """ try: @@ -74,7 +79,7 @@ def acquire(self): raise RuntimeError(f"An unexpected error occurred: {e}") def blocking_acquire(self, timeout=None, poll_interval=0.1): - """Waits until an exclusive lock can be acquired, with an optional timeout. + """Wait until an exclusive lock can be acquired, with an optional timeout. Args: timeout (float): The maximum time to wait for the lock in seconds. diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py index cfd9986f..bef52843 100644 --- a/tests/scripts/benchmark_dlpack.py +++ b/tests/scripts/benchmark_dlpack.py @@ -14,12 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This script is used to benchmark the API overhead of different -python FFI API calling overhead, through DLPack API. +"""Benchmark API overhead of different python FFI API calling overhead through DLPack API. -Specifically, we would like to understand the overall overhead -python/C++ API calls. The general goal is to understand the overall -space and get a sense of what are the possible operations. +Specifically, we would like to understand the overall overhead python/C++ API calls. +The general goal is to understand the overall space and get a sense of what are the possible operations. We pick function f(x, y, z) where x, y, z are length 1 tensors. The benchmark is running in eager mode so we can see what is possible. @@ -28,12 +26,8 @@ of what is possible under eager mode. Summary of some takeaways: -- numpy.add roughly takes 0.36 us per call, which gives roughly what can - be done in python env. -- torch.add on gpu takes about 3.7us per call, giving us an idea of what - roughly we need to get to in eager mode. -- - +- numpy.add roughly takes 0.36 us per call, which gives roughly what can be done in python env. +- torch.add on gpu takes about 3.7us per call, giving us an idea of what roughly we need to get to in eager mode. """ import time