diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f762dcf..329c00c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -83,3 +83,30 @@ repos: rev: v0.10.0.1 hooks: - id: shellcheck + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.17.0" + hooks: + - id: mypy + name: mypy for `python/` and `tests/` + additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"] + args: [--show-error-codes, --python-version=3.9] + exclude: ^.*/_ffi_api\.py$ + files: ^(python/|tests/).*\.py$ + - id: mypy + name: mypy for `examples/inline_module` + additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"] + args: [--show-error-codes, --python-version=3.9] + exclude: ^.*/_ffi_api\.py$ + files: ^examples/inline_module/.*\.py$ + - id: mypy + name: mypy for `examples/packaging` + additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"] + args: [--show-error-codes, --python-version=3.9] + exclude: ^.*/_ffi_api\.py$ + files: ^examples/packaging/.*\.py$ + - id: mypy + name: mypy for `examples/quick_start` + additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"] + args: [--show-error-codes, --python-version=3.9] + exclude: ^.*/_ffi_api\.py$ + files: ^examples/quick_start/.*\.py$ diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py index ae4abfda..0c2b0fd0 100644 --- a/examples/packaging/python/my_ffi_extension/__init__.py +++ b/examples/packaging/python/my_ffi_extension/__init__.py @@ -51,4 +51,4 @@ def raise_error(msg: str) -> None: The error raised by the function. """ - return _ffi_api.raise_error(msg) + return _ffi_api.raise_error(msg) # type: ignore[attr-defined] diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py index 87c9507c..2e2b7f3c 100644 --- a/examples/quick_start/run_example.py +++ b/examples/quick_start/run_example.py @@ -16,15 +16,9 @@ # under the License. """Quick start script to run tvm-ffi examples from prebuilt libraries.""" -import tvm_ffi - -try: - import torch -except ImportError: - torch = None - - import numpy +import torch +import tvm_ffi def run_add_one_cpu() -> None: @@ -40,9 +34,6 @@ def run_add_one_cpu() -> None: print("numpy.result after add_one(x, y)") print(x) - if torch is None: - return - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) y = torch.empty_like(x) # tvm-ffi automatically handles DLPack compatible tensors @@ -63,9 +54,6 @@ def run_add_one_c() -> None: print("numpy.result after add_one_c(x, y)") print(x) - if torch is None: - return - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) y = torch.empty_like(x) mod.add_one_c(x, y) @@ -75,7 +63,7 @@ def run_add_one_c() -> None: def run_add_one_cuda() -> None: """Load the add_one_cuda module and call the add_one_cuda function.""" - if torch is None or not torch.cuda.is_available(): + if not torch.cuda.is_available(): return mod = tvm_ffi.load_module("build/add_one_cuda.so") diff --git a/pyproject.toml b/pyproject.toml index d2fd8978..58d394ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,3 +213,7 @@ environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" } [tool.cibuildwheel.windows] archs = ["AMD64"] + +[tool.mypy] +allow_redefinition = true +ignore_missing_imports = true diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py index 6313c8e0..05e99da8 100644 --- a/python/tvm_ffi/_convert.py +++ b/python/tvm_ffi/_convert.py @@ -16,20 +16,25 @@ # under the License. """Conversion utilities to bring python objects into ffi values.""" +from __future__ import annotations + from numbers import Number +from types import ModuleType from typing import Any from . import container, core +torch: ModuleType | None = None try: - import torch + import torch # type: ignore[no-redef] except ImportError: - torch = None + pass +numpy: ModuleType | None = None try: import numpy except ImportError: - numpy = None + pass def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912 diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py index ba1735ff..83065935 100644 --- a/python/tvm_ffi/_dtype.py +++ b/python/tvm_ffi/_dtype.py @@ -59,6 +59,7 @@ class dtype(str): """ __slots__ = ["__tvm_ffi_dtype__"] + __tvm_ffi_dtype__: core.DataType _NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {} diff --git a/python/tvm_ffi/_ffi_api.pyi b/python/tvm_ffi/_ffi_api.pyi new file mode 100644 index 00000000..95059e5b --- /dev/null +++ b/python/tvm_ffi/_ffi_api.pyi @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API.""" + +from typing import Any + +def ModuleGetKind(*args: Any) -> Any: ... +def ModuleImplementsFunction(*args: Any) -> Any: ... +def ModuleGetFunction(*args: Any) -> Any: ... +def ModuleImportModule(*args: Any) -> Any: ... +def ModuleInspectSource(*args: Any) -> Any: ... +def ModuleGetWriteFormats(*args: Any) -> Any: ... +def ModuleGetPropertyMask(*args: Any) -> Any: ... +def ModuleClearImports(*args: Any) -> Any: ... +def ModuleWriteToFile(*args: Any) -> Any: ... +def ModuleLoadFromFile(*args: Any) -> Any: ... +def SystemLib(*args: Any) -> Any: ... +def Array(*args: Any) -> Any: ... +def ArrayGetItem(*args: Any) -> Any: ... +def ArraySize(*args: Any) -> Any: ... +def MapForwardIterFunctor(*args: Any) -> Any: ... +def Map(*args: Any) -> Any: ... +def MapGetItem(*args: Any) -> Any: ... +def MapCount(*args: Any) -> Any: ... +def MapSize(*args: Any) -> Any: ... +def MakeObjectFromPackedArgs(*args: Any) -> Any: ... +def ToJSONGraphString(*args: Any) -> Any: ... +def FromJSONGraphString(*args: Any) -> Any: ... +def Shape(*args: Any) -> Any: ... diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py index dd820a99..5be7211e 100644 --- a/python/tvm_ffi/_optional_torch_c_dlpack.py +++ b/python/tvm_ffi/_optional_torch_c_dlpack.py @@ -31,12 +31,12 @@ """ import warnings -from typing import Any, Optional +from typing import Any from . import libinfo -def load_torch_c_dlpack_extension() -> Optional[Any]: +def load_torch_c_dlpack_extension() -> Any: """Load the torch c dlpack extension.""" cpp_source = """ #include @@ -556,9 +556,9 @@ def load_torch_c_dlpack_extension() -> Optional[Any]: extra_include_paths=include_paths, ) # set the dlpack related flags - torch.Tensor.__c_dlpack_from_pyobject__ = mod.TorchDLPackFromPyObjectPtr() - torch.Tensor.__c_dlpack_to_pyobject__ = mod.TorchDLPackToPyObjectPtr() - torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr() + setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr()) + setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr()) + setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr()) return mod except ImportError: pass @@ -566,7 +566,7 @@ def load_torch_c_dlpack_extension() -> Optional[Any]: warnings.warn( f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator will not be enabled." ) - return None + return None # keep alive diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py index 0cc09f13..0d44994f 100644 --- a/python/tvm_ffi/_tensor.py +++ b/python/tvm_ffi/_tensor.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. """Tensor related objects and functions.""" + +from __future__ import annotations + # we name it as _tensor.py to avoid potential future case # if we also want to expose a tensor function in the root namespace - from numbers import Integral -from typing import Any, Optional, Union +from typing import Any from . import _ffi_api, core, registry from .core import ( @@ -43,23 +45,25 @@ class Shape(tuple, PyNativeObject): """ - def __new__(cls, content: tuple[int, ...]) -> "Shape": + __tvm_ffi_object__: Any + + def __new__(cls, content: tuple[int, ...]) -> Shape: if any(not isinstance(x, Integral) for x in content): raise ValueError("Shape must be a tuple of integers") - val = tuple.__new__(cls, content) + val: Shape = tuple.__new__(cls, content) val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content) return val # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape": + def __from_tvm_ffi_object__(cls, obj: Any) -> Shape: """Construct from a given tvm object.""" content = _shape_obj_get_py_tuple(obj) - val = tuple.__new__(cls, content) - val.__tvm_ffi_object__ = obj + val: Shape = tuple.__new__(cls, content) # type: ignore[arg-type] + val.__tvm_ffi_object__ = obj # type: ignore[attr-defined] return val -def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = None) -> Device: +def device(device_type: str | int | DLDeviceType, index: int | None = None) -> Device: """Construct a TVM FFI device with given device type and index. Parameters diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py index e8aec104..aa52d58f 100644 --- a/python/tvm_ffi/access_path.py +++ b/python/tvm_ffi/access_path.py @@ -39,11 +39,16 @@ class AccessKind(IntEnum): class AccessStep(core.Object): """Access step container.""" + kind: AccessKind + key: Any + @register_object("ffi.reflection.AccessPath") class AccessPath(core.Object): """Access path container.""" + parent: "AccessPath" + def __init__(self) -> None: """Disallow direct construction; use `AccessPath.root()` instead.""" super().__init__() @@ -55,19 +60,19 @@ def __init__(self) -> None: @staticmethod def root() -> "AccessPath": """Create a root access path.""" - return AccessPath._root() + return AccessPath._root() # type: ignore[attr-defined] def __eq__(self, other: Any) -> bool: """Return whether two access paths are equal.""" if not isinstance(other, AccessPath): return False - return self._path_equal(other) + return self._path_equal(other) # type: ignore[attr-defined] def __ne__(self, other: Any) -> bool: """Return whether two access paths are not equal.""" if not isinstance(other, AccessPath): return True - return not self._path_equal(other) + return not self._path_equal(other) # type: ignore[attr-defined] def is_prefix_of(self, other: "AccessPath") -> bool: """Check if this access path is a prefix of another access path. @@ -83,7 +88,7 @@ def is_prefix_of(self, other: "AccessPath") -> bool: True if this access path is a prefix of the other access path, False otherwise """ - return self._is_prefix_of(other) + return self._is_prefix_of(other) # type: ignore[attr-defined] def attr(self, attr_key: str) -> "AccessPath": """Create an access path to the attribute of the current object. @@ -99,7 +104,7 @@ def attr(self, attr_key: str) -> "AccessPath": The extended access path """ - return self._attr(attr_key) + return self._attr(attr_key) # type: ignore[attr-defined] def attr_missing(self, attr_key: str) -> "AccessPath": """Create an access path that indicate an attribute is missing. @@ -115,7 +120,7 @@ def attr_missing(self, attr_key: str) -> "AccessPath": The extended access path """ - return self._attr_missing(attr_key) + return self._attr_missing(attr_key) # type: ignore[attr-defined] def array_item(self, index: int) -> "AccessPath": """Create an access path to the item of the current array. @@ -131,7 +136,7 @@ def array_item(self, index: int) -> "AccessPath": The extended access path """ - return self._array_item(index) + return self._array_item(index) # type: ignore[attr-defined] def array_item_missing(self, index: int) -> "AccessPath": """Create an access path that indicate an array item is missing. @@ -147,7 +152,7 @@ def array_item_missing(self, index: int) -> "AccessPath": The extended access path """ - return self._array_item_missing(index) + return self._array_item_missing(index) # type: ignore[attr-defined] def map_item(self, key: Any) -> "AccessPath": """Create an access path to the item of the current map. @@ -163,7 +168,7 @@ def map_item(self, key: Any) -> "AccessPath": The extended access path """ - return self._map_item(key) + return self._map_item(key) # type: ignore[attr-defined] def map_item_missing(self, key: Any) -> "AccessPath": """Create an access path that indicate a map item is missing. @@ -179,7 +184,7 @@ def map_item_missing(self, key: Any) -> "AccessPath": The extended access path """ - return self._map_item_missing(key) + return self._map_item_missing(key) # type: ignore[attr-defined] def to_steps(self) -> list["AccessStep"]: """Convert the access path to a list of access steps. @@ -190,6 +195,6 @@ def to_steps(self) -> list["AccessStep"]: The list of access steps """ - return self._to_steps() + return self._to_steps() # type: ignore[attr-defined] __hash__ = core.Object.__hash__ diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 008dda93..6f29dfdf 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -16,6 +16,8 @@ # under the License. """Container classes.""" +from __future__ import annotations + import collections.abc from collections.abc import Iterator, Mapping, Sequence from typing import Any, Callable @@ -121,7 +123,7 @@ def __repr__(self) -> str: class KeysView(collections.abc.KeysView): """Helper class to return keys view.""" - def __init__(self, backend_map: "Map") -> None: + def __init__(self, backend_map: Map) -> None: self._backend_map = backend_map def __len__(self) -> int: @@ -144,7 +146,7 @@ def __contains__(self, k: Any) -> bool: class ValuesView(collections.abc.ValuesView): """Helper class to return values view.""" - def __init__(self, backend_map: "Map") -> None: + def __init__(self, backend_map: Map) -> None: self._backend_map = backend_map def __len__(self) -> int: @@ -164,7 +166,7 @@ def __iter__(self) -> Iterator[Any]: class ItemsView(collections.abc.ItemsView): """Helper class to return items view.""" - def __init__(self, backend_map: "Map") -> None: + def __init__(self, backend_map: Map) -> None: self.backend_map = backend_map def __len__(self) -> int: @@ -231,7 +233,7 @@ def keys(self) -> KeysView: """Return a dynamic view of the map's keys.""" return KeysView(self) - def values(self) -> "ValuesView": + def values(self) -> ValuesView: """Return a dynamic view of the map's values.""" return ValuesView(self) diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index cfb3ea9d..7c623afd 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -19,6 +19,7 @@ from __future__ import annotations import types +from ctypes import c_void_p from enum import IntEnum from typing import Any, Callable @@ -28,7 +29,6 @@ ERROR_TYPE_TO_NAME: dict[type, str] _WITH_APPEND_BACKTRACE: Callable[[BaseException, str], BaseException] | None _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType | None], str] | None - # DLPack protocol version (defined in tensor.pxi) __dlpack_version__: tuple[int, int] @@ -44,7 +44,7 @@ class Object: def __eq__(self, other: Any) -> bool: ... def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... - def __init_handle_by_constructor__(self, fconstructor: Function, *args: Any) -> None: ... + def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ... def __ffi_init__(self, *args: Any) -> None: """Initialize the instance using the ` __init__` method registered on C++ side. @@ -78,9 +78,7 @@ class PyNativeObject: """Base class of all TVM objects that also subclass python's builtin types.""" __slots__: list[str] - def __init_tvm_ffi_object_by_constructor__( - self, fconstructor: Function, *args: Any - ) -> None: ... + def __init_tvm_ffi_object_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ... def _set_class_object(cls: type) -> None: ... def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ... @@ -101,7 +99,9 @@ class Error(Object): def backtrace(self) -> str: ... def _convert_to_ffi_error(error: BaseException) -> Error: ... -def _env_set_current_stream(device_type: int, device_id: int, stream: int) -> int: ... +def _env_set_current_stream( + device_type: int, device_id: int, stream: int | c_void_p +) -> int | c_void_p: ... class DataType: """DataType wrapper around DLDataType.""" @@ -121,6 +121,8 @@ class DataType: def __str__(self) -> str: ... def _set_class_dtype(cls: type) -> None: ... +def _convert_torch_dtype_to_ffi_dtype(torch_dtype: Any) -> DataType: ... +def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) -> DataType: ... def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ... class DLDeviceType(IntEnum): @@ -185,6 +187,9 @@ class Tensor(Object): copy: bool | None = None, ) -> Any: ... +_CLASS_TENSOR: type[Tensor] = Tensor + +def _set_class_tensor(cls: type[Tensor]) -> None: ... def from_dlpack( ext_tensor: Any, *, require_alignment: int = ..., require_contiguous: bool = ... ) -> Tensor: ... diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py index 2c1caf5a..d7a5c143 100644 --- a/python/tvm_ffi/cpp/load_inline.py +++ b/python/tvm_ffi/cpp/load_inline.py @@ -16,6 +16,8 @@ # under the License. """Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja.""" +from __future__ import annotations + import functools import hashlib import os @@ -24,7 +26,6 @@ import sys from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Optional from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi from tvm_ffi.module import Module, load_module @@ -77,7 +78,7 @@ def _maybe_write(path: str, content: str) -> None: @functools.lru_cache -def _find_cuda_home() -> Optional[str]: +def _find_cuda_home() -> str: """Find the CUDA install path.""" # Guess #1 cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") @@ -92,9 +93,10 @@ def _find_cuda_home() -> Optional[str]: cuda_root = Path("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA") cuda_homes = list(cuda_root.glob("v*.*")) if len(cuda_homes) == 0: - cuda_home = "" - else: - cuda_home = str(cuda_homes[0]) + raise RuntimeError( + "Could not find CUDA installation. Please set CUDA_HOME environment variable." + ) + cuda_home = str(cuda_homes[0]) else: cuda_home = "/usr/local/cuda" if not Path(cuda_home).exists(): @@ -358,17 +360,17 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str: return "\n".join(sources) -def load_inline( +def load_inline( # noqa: PLR0912, PLR0915 name: str, *, - cpp_sources: str | None = None, - cuda_sources: str | None = None, - functions: Sequence[str] | None = None, + cpp_sources: Sequence[str] | str | None = None, + cuda_sources: Sequence[str] | str | None = None, + functions: Mapping[str, str] | Sequence[str] | str | None = None, extra_cflags: Sequence[str] | None = None, extra_cuda_cflags: Sequence[str] | None = None, extra_ldflags: Sequence[str] | None = None, extra_include_paths: Sequence[str] | None = None, - build_directory: Optional[str] = None, + build_directory: str | None = None, ) -> Module: """Compile and load a C++/CUDA module from inline source code. @@ -481,76 +483,83 @@ def load_inline( """ if cpp_sources is None: - cpp_sources = [] + cpp_source_list: list[str] = [] elif isinstance(cpp_sources, str): - cpp_sources = [cpp_sources] - cpp_source = "\n".join(cpp_sources) + cpp_source_list = [cpp_sources] + else: + cpp_source_list = list(cpp_sources) + cpp_source = "\n".join(cpp_source_list) + with_cpp = bool(cpp_source_list) + del cpp_source_list + if cuda_sources is None: - cuda_sources = [] + cuda_source_list: list[str] = [] elif isinstance(cuda_sources, str): - cuda_sources = [cuda_sources] - cuda_source = "\n".join(cuda_sources) - with_cpp = len(cpp_sources) > 0 - with_cuda = len(cuda_sources) > 0 + cuda_source_list = [cuda_sources] + else: + cuda_source_list = list(cuda_sources) + cuda_source = "\n".join(cuda_source_list) + with_cuda = bool(cuda_source_list) + del cuda_source_list - extra_ldflags = extra_ldflags or [] - extra_cflags = extra_cflags or [] - extra_cuda_cflags = extra_cuda_cflags or [] - extra_include_paths = extra_include_paths or [] + extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else [] + extra_cflags_list = list(extra_cflags) if extra_cflags is not None else [] + extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is not None else [] + extra_include_paths_list = list(extra_include_paths) if extra_include_paths is not None else [] # add function registration code to sources - if isinstance(functions, str): - functions = {functions: ""} - elif isinstance(functions, Sequence): - functions = {name: "" for name in functions} + if functions is None: + function_map: dict[str, str] = {} + elif isinstance(functions, str): + function_map = {functions: ""} + elif isinstance(functions, Mapping): + function_map = dict(functions) + else: + function_map = {name: "" for name in functions} if with_cpp: - cpp_source = _decorate_with_tvm_ffi(cpp_source, functions) + cpp_source = _decorate_with_tvm_ffi(cpp_source, function_map) cuda_source = _decorate_with_tvm_ffi(cuda_source, {}) else: cpp_source = _decorate_with_tvm_ffi(cpp_source, {}) - cuda_source = _decorate_with_tvm_ffi(cuda_source, functions) + cuda_source = _decorate_with_tvm_ffi(cuda_source, function_map) # determine the cache dir for the built module + build_dir: Path if build_directory is None: - build_directory = os.environ.get( - "TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser()) - ) + cache_dir = os.environ.get("TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser())) source_hash: str = _hash_sources( cpp_source, cuda_source, - functions, - extra_cflags, - extra_cuda_cflags, - extra_ldflags, - extra_include_paths, + function_map, + extra_cflags_list, + extra_cuda_cflags_list, + extra_ldflags_list, + extra_include_paths_list, ) - build_dir: str = str(Path(build_directory) / f"{name}_{source_hash}") + build_dir = Path(cache_dir).expanduser() / f"{name}_{source_hash}" else: - build_dir = str(Path(build_directory).resolve()) - Path(build_dir).mkdir(parents=True, exist_ok=True) + build_dir = Path(build_directory).resolve() + build_dir.mkdir(parents=True, exist_ok=True) # generate build.ninja ninja_source = _generate_ninja_build( name=name, - build_dir=build_dir, + build_dir=str(build_dir), with_cuda=with_cuda, - extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, - extra_ldflags=extra_ldflags, - extra_include_paths=extra_include_paths, + extra_cflags=extra_cflags_list, + extra_cuda_cflags=extra_cuda_cflags_list, + extra_ldflags=extra_ldflags_list, + extra_include_paths=extra_include_paths_list, ) - - with FileLock(str(Path(build_dir) / "lock")): + with FileLock(str(build_dir / "lock")): # write source files and build.ninja if they do not already exist - _maybe_write(str(Path(build_dir) / "main.cpp"), cpp_source) + _maybe_write(str(build_dir / "main.cpp"), cpp_source) if with_cuda: - _maybe_write(str(Path(build_dir) / "cuda.cu"), cuda_source) - _maybe_write(str(Path(build_dir) / "build.ninja"), ninja_source) - + _maybe_write(str(build_dir / "cuda.cu"), cuda_source) + _maybe_write(str(build_dir / "build.ninja"), ninja_source) # build the module - _build_ninja(build_dir) - + _build_ninja(str(build_dir)) # Use appropriate extension based on platform ext = ".dll" if IS_WINDOWS else ".so" - return load_module(str((Path(build_dir) / f"{name}{ext}").resolve())) + return load_module(str((build_dir / f"{name}{ext}").resolve())) diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi index 4ab9f15a..fcd443b6 100644 --- a/python/tvm_ffi/cython/type_info.pxi +++ b/python/tvm_ffi/cython/type_info.pxi @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import dataclasses +from typing import Optional, Any cdef class FieldGetter: @@ -62,13 +63,13 @@ class TypeField: """Description of a single reflected field on an FFI-backed type.""" name: str - doc: str | None + doc: Optional[str] size: int offset: int frozen: bool getter: FieldGetter setter: FieldSetter - dataclass_field: object | None = None + dataclass_field: Any = None def __post_init__(self): assert self.setter is not None @@ -96,7 +97,7 @@ class TypeMethod: """Description of a single reflected method on an FFI-backed type.""" name: str - doc: str | None + doc: Optional[str] func: object is_static: bool @@ -105,9 +106,9 @@ class TypeMethod: class TypeInfo: """Aggregated type information required to build a proxy class.""" - type_cls: type | None + type_cls: Optional[type] type_index: int type_key: str fields: list[TypeField] methods: list[TypeMethod] - parent_type_info: TypeInfo | None + parent_type_info: Optional[TypeInfo] diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py index ef7c7e4c..7a28e01a 100644 --- a/python/tvm_ffi/dataclasses/_utils.py +++ b/python/tvm_ffi/dataclasses/_utils.py @@ -21,7 +21,7 @@ import functools import inspect from dataclasses import MISSING -from typing import Any, Callable, NamedTuple, TypeVar +from typing import Any, Callable, NamedTuple, TypeVar, cast from ..core import ( Object, @@ -68,7 +68,7 @@ def type_info_to_cls( attrs[field.name] = field.as_property(cls) # Step 3. Add methods - def _add_method(name: str, func: Callable) -> None: + def _add_method(name: str, func: Callable[..., Any]) -> None: if name == "__ffi_init__": name = "__c_ffi_init__" if name in attrs: # already defined @@ -80,9 +80,9 @@ def _add_method(name: str, func: Callable) -> None: attrs[name] = func setattr(cls, name, func) - for name, method in methods.items(): - if method is not None: - _add_method(name, method) + for name, method_impl in methods.items(): + if method_impl is not None: + _add_method(name, method_impl) for method in type_info.methods: _add_method(method.name, method.func) @@ -90,7 +90,7 @@ def _add_method(name: str, func: Callable) -> None: new_cls = type(cls.__name__, cls_bases, attrs) new_cls.__module__ = cls.__module__ new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore - return new_cls + return cast(type[_InputClsType], new_cls) def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None: @@ -123,15 +123,12 @@ class DefaultFactory(NamedTuple): fn: Callable[[], Any] - fields: list[TypeInfo] = [] - cur_type_info = type_info - while True: + fields: list[TypeField] = [] + cur_type_info: TypeInfo | None = type_info + while cur_type_info is not None: fields.extend(reversed(cur_type_info.fields)) cur_type_info = cur_type_info.parent_type_info - if cur_type_info is None: - break fields.reverse() - del cur_type_info annotations: dict[str, Any] = {"return": None} # Step 1. Split the parameters into two groups to ensure that @@ -187,7 +184,7 @@ def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]: else: raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`") - def __init__(self: type, *args: Any, **kwargs: Any) -> None: + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: e = None try: args = bind_args(*args, **kwargs) diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py index 7507b76c..700f2872 100644 --- a/python/tvm_ffi/dataclasses/c_class.py +++ b/python/tvm_ffi/dataclasses/c_class.py @@ -28,18 +28,21 @@ from typing import ClassVar, TypeVar, get_origin, get_type_hints from ..core import TypeField, TypeInfo -from . import _utils, field +from . import _utils +from .field import Field, field try: - from typing import dataclass_transform + from typing_extensions import dataclass_transform # type: ignore[attr-defined] except ImportError: - from typing_extensions import dataclass_transform + from typing import dataclass_transform # type: ignore[no-redef,attr-defined] +except ImportError: + pass _InputClsType = TypeVar("_InputClsType") -@dataclass_transform(field_specifiers=(field.field, field.Field)) +@dataclass_transform(field_specifiers=(field, Field)) def c_class( type_key: str, init: bool = True ) -> Callable[[type[_InputClsType]], type[_InputClsType]]: @@ -157,7 +160,7 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie for field_name, _field_ty_py in type_hints_py.items(): if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip continue - type_field: TypeField = type_fields_cxx.pop(field_name, None) + type_field = type_fields_cxx.pop(field_name, None) if type_field is None: raise ValueError( f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C++" diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py index 00170e5e..f1a582ea 100644 --- a/python/tvm_ffi/dataclasses/field.py +++ b/python/tvm_ffi/dataclasses/field.py @@ -18,11 +18,10 @@ from __future__ import annotations -from dataclasses import MISSING, dataclass +from dataclasses import MISSING from typing import Any, Callable -@dataclass(kw_only=True) class Field: """(Experimental) Descriptor placeholder returned by :func:`tvm_ffi.dataclasses.field`. @@ -36,8 +35,12 @@ class Field: way the decorator understands. """ - name: str | None = None - default_factory: Callable[[], Any] + __slots__ = ("default_factory", "name") + + def __init__(self, *, name: str | None = None, default_factory: Callable[[], Any]) -> None: + """Do not call directly; use :func:`field` instead.""" + self.name = name + self.default_factory = default_factory def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field: diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py index f29fc909..fd0bf2bc 100644 --- a/python/tvm_ffi/error.py +++ b/python/tvm_ffi/error.py @@ -17,11 +17,13 @@ # pylint: disable=invalid-name """Error handling.""" +from __future__ import annotations + import ast import re import sys import types -from typing import Any, Optional +from typing import Any from . import core @@ -60,7 +62,7 @@ class TracebackManager: def __init__(self) -> None: """Initialize the traceback manager and its cache.""" - self._code_cache = {} + self._code_cache: dict[tuple[str, int, str], types.CodeType] = {} def _get_cached_code_object(self, filename: str, lineno: int, func: str) -> types.CodeType: # Hack to create a code object that points to the correct @@ -95,7 +97,7 @@ def _create_frame(self, filename: str, lineno: int, func: str) -> types.FrameTyp def append_traceback( self, - tb: Optional[types.TracebackType], + tb: types.TracebackType | None, filename: str, lineno: int, func: str, @@ -134,7 +136,7 @@ def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseExcep return py_error.with_traceback(tb) -def _traceback_to_backtrace_str(tb: Optional[types.TracebackType]) -> str: +def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str: """Convert the traceback to a string.""" lines = [] while tb is not None: @@ -155,7 +157,7 @@ def _traceback_to_backtrace_str(tb: Optional[types.TracebackType]) -> str: def register_error( name_or_cls: str | type | None = None, - cls: Optional[type] = None, + cls: type | None = None, ) -> Any: """Register an error class so it can be recognized by the ffi error handler. diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py index 382690b1..8d92df3d 100644 --- a/python/tvm_ffi/libinfo.py +++ b/python/tvm_ffi/libinfo.py @@ -46,25 +46,24 @@ def split_env_var(env_var: str, split: str) -> list[str]: def get_dll_directories() -> list[str]: """Get the possible dll directories.""" ffi_dir = Path(__file__).expanduser().resolve().parent - dll_path = [ffi_dir / "lib"] - dll_path += [ffi_dir / ".." / ".." / "build" / "lib"] + dll_path: list[Path] = [ffi_dir / "lib"] + dll_path.append(ffi_dir / ".." / ".." / "build" / "lib") # in source build from parent if needed - dll_path += [ffi_dir / ".." / ".." / ".." / "build" / "lib"] - + dll_path.append(ffi_dir / ".." / ".." / ".." / "build" / "lib") if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): - dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":")) - dll_path.extend(split_env_var("PATH", ":")) + dll_path.extend(Path(p) for p in split_env_var("LD_LIBRARY_PATH", ":")) + dll_path.extend(Path(p) for p in split_env_var("PATH", ":")) elif sys.platform.startswith("darwin"): - dll_path.extend(split_env_var("DYLD_LIBRARY_PATH", ":")) - dll_path.extend(split_env_var("PATH", ":")) + dll_path.extend(Path(p) for p in split_env_var("DYLD_LIBRARY_PATH", ":")) + dll_path.extend(Path(p) for p in split_env_var("PATH", ":")) elif sys.platform.startswith("win32"): - dll_path.extend(split_env_var("PATH", ";")) - return [str(Path(x).resolve()) for x in dll_path if Path(x).is_dir()] + dll_path.extend(Path(p) for p in split_env_var("PATH", ";")) + return [str(path.resolve()) for path in dll_path if path.is_dir()] def find_libtvm_ffi() -> str: """Find libtvm_ffi.""" - dll_path = get_dll_directories() + dll_path = [Path(p) for p in get_dll_directories()] if sys.platform.startswith("win32"): lib_dll_names = ["tvm_ffi.dll"] elif sys.platform.startswith("darwin"): @@ -72,14 +71,18 @@ def find_libtvm_ffi() -> str: else: lib_dll_names = ["libtvm_ffi.so"] - name = lib_dll_names - lib_dll_path = [str(Path(p) / name) for name in lib_dll_names for p in dll_path] - lib_found = [p for p in lib_dll_path if Path(p).exists() and Path(p).is_file()] + lib_dll_path = [p / name for name in lib_dll_names for p in dll_path] + lib_found = [p for p in lib_dll_path if p.exists() and p.is_file()] if not lib_found: - raise RuntimeError(f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}") - - return lib_found[0] + candidate_list = "\n".join(str(p) for p in lib_dll_path) + raise RuntimeError( + "Cannot find library: {}\nList of candidates:\n{}".format( + ", ".join(lib_dll_names), candidate_list + ) + ) + + return str(lib_found[0]) def find_source_path() -> str: diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py index acdc11ef..768463df 100644 --- a/python/tvm_ffi/module.py +++ b/python/tvm_ffi/module.py @@ -73,7 +73,7 @@ def imports(self) -> list["Module"]: The module """ - return self.imports_ + return self.imports_ # type: ignore[return-value] def implements_function(self, name: str, query_imports: bool = False) -> bool: """Return True if the module defines a global function. @@ -255,7 +255,7 @@ def system_lib(symbol_prefix: str = "") -> Module: Parameters ---------- - symbol_prefix: Optional[str] + symbol_prefix: str = "" Optional symbol prefix that can be used for search. When we lookup a symbol symbol_prefix + name will first be searched, then the name without symbol_prefix. diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 6bf08f64..3ef40395 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -16,8 +16,10 @@ # under the License. """FFI registry to register function and objects.""" +from __future__ import annotations + import sys -from typing import Any, Callable, Optional +from typing import Any, Callable, Literal, overload from . import core from .core import TypeInfo @@ -26,7 +28,7 @@ _SKIP_UNKNOWN_OBJECTS = False -def register_object(type_key: str | type | None = None) -> Any: +def register_object(type_key: str | type | None = None) -> Callable[[type], type] | type: """Register object type. Parameters @@ -46,9 +48,8 @@ class MyObject(Object): pass """ - object_name = type_key if isinstance(type_key, str) else type_key.__name__ - def register(cls: type) -> type: + def _register(cls: type, object_name: str) -> type: """Register the object type with the FFI core.""" type_index = core._object_type_key_to_index(object_name) if type_index is None: @@ -60,14 +61,25 @@ def register(cls: type) -> type: return cls if isinstance(type_key, str): - return register - return register(type_key) + def _decorator_with_name(cls: type) -> type: + return _register(cls, type_key) + + return _decorator_with_name + + def _decorator_default(cls: type) -> type: + return _register(cls, cls.__name__) + + if type_key is None: + return _decorator_default + if isinstance(type_key, type): + return _decorator_default(type_key) + raise TypeError("type_key must be a string, type, or None") def register_global_func( func_name: str | Callable[..., Any], - f: Optional[Callable[..., Any]] = None, + f: Callable[..., Any] | None = None, override: bool = False, ) -> Any: """Register global function. @@ -124,12 +136,20 @@ def register(myf: Callable[..., Any]) -> Any: """Register the global function with the FFI core.""" return core._register_global_func(func_name, myf, override) - if f: + if f is not None: return register(f) return register -def get_global_func(name: str, allow_missing: bool = False) -> Optional[core.Function]: +@overload +def get_global_func(name: str, allow_missing: Literal[True]) -> core.Function | None: ... + + +@overload +def get_global_func(name: str, allow_missing: Literal[False] = False) -> core.Function: ... + + +def get_global_func(name: str, allow_missing: bool = False) -> core.Function | None: """Get a global function by name. Parameters @@ -179,7 +199,7 @@ def remove_global_func(name: str) -> None: get_global_func("ffi.FunctionRemoveGlobal")(name) -def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> None: +def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None: """Initialize register ffi api functions into a given module. Parameters @@ -225,8 +245,8 @@ def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> No continue f = get_global_func(name) - f.__name__ = fname - setattr(target_module, f.__name__, f) + setattr(f, "__name__", fname) + setattr(target_module, fname, f) def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]: @@ -253,16 +273,16 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type: doc = method.doc if method.doc else None method_func = method.func if method.is_static: - method_pyfunc = staticmethod(method_func) + if doc is not None: + method_func.__doc__ = doc + method_func.__name__ = name + method_pyfunc: Any = staticmethod(method_func) else: - # must call into another method instead of direct capture - # to avoid the same method_func variable being used - # across multiple loop iterations - method_pyfunc = _member_method_wrapper(method_func) - - if doc is not None: - method_pyfunc.__doc__ = doc - method_pyfunc.__name__ = name + wrapped_func = _member_method_wrapper(method_func) + if doc is not None: + wrapped_func.__doc__ = doc + wrapped_func.__name__ = name + method_pyfunc = wrapped_func if hasattr(type_cls, name): # skip already defined attributes diff --git a/python/tvm_ffi/serialization.py b/python/tvm_ffi/serialization.py index 2bc0d146..6960eda8 100644 --- a/python/tvm_ffi/serialization.py +++ b/python/tvm_ffi/serialization.py @@ -16,12 +16,14 @@ # under the License. """Serialization related utilities to enable some object can be pickled.""" -from typing import Any, Optional +from __future__ import annotations + +from typing import Any from . import _ffi_api -def to_json_graph_str(obj: Any, metadata: Optional[dict] = None) -> str: +def to_json_graph_str(obj: Any, metadata: dict[str, Any] | None = None) -> str: """Dump an object to a JSON graph string. The JSON graph string is a string representation of of the object diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py index 3ce9d943..7f2dde52 100644 --- a/python/tvm_ffi/stream.py +++ b/python/tvm_ffi/stream.py @@ -18,7 +18,7 @@ """Stream context.""" from ctypes import c_void_p -from typing import Any, NoReturn, Optional, Union +from typing import Any, Union from . import core from ._tensor import device @@ -72,7 +72,7 @@ def __exit__(self, *args: Any) -> None: class TorchStreamContext: """Context manager that syncs Torch and FFI stream contexts.""" - def __init__(self, context: Optional[Any]) -> None: + def __init__(self, context: Any) -> None: """Initialize with an optional Torch stream/graph context wrapper.""" self.torch_context = context @@ -93,14 +93,14 @@ def __exit__(self, *args: Any) -> None: self.torch_context.__exit__(*args) self.ffi_context.__exit__(*args) - def use_torch_stream(context: Optional[Any] = None) -> "TorchStreamContext": + def use_torch_stream(context: Any = None) -> "TorchStreamContext": """Create an FFI stream context with a Torch stream or graph. cuda graph or current stream if `None` provided. Parameters ---------- - context : Optional[Any] + context : Any = None The wrapped torch stream or cuda graph. Returns @@ -129,7 +129,7 @@ def use_torch_stream(context: Optional[Any] = None) -> "TorchStreamContext": except ImportError: - def use_torch_stream(context: Optional[Any] = None) -> NoReturn: + def use_torch_stream(context: Any = None) -> "TorchStreamContext": """Raise an informative error when Torch is unavailable.""" raise ImportError("Cannot import torch") diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py index 825f9cfc..97691584 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing.py @@ -16,9 +16,12 @@ # under the License. """Testing utilities.""" +from __future__ import annotations + from typing import Any, ClassVar from . import _ffi_api +from .container import Array, Map from .core import Object from .dataclasses import c_class, field from .registry import register_object @@ -28,11 +31,18 @@ class TestObjectBase(Object): """Test object base class.""" + v_i64: int + v_f64: float + v_str: str + @register_object("testing.TestIntPair") class TestIntPair(Object): """Test Int Pair.""" + a: int + b: int + def __init__(self, a: int, b: int) -> None: """Construct the object.""" self.__ffi_init__(a, b) @@ -42,6 +52,9 @@ def __init__(self, a: int, b: int) -> None: class TestObjectDerived(TestObjectBase): """Test object derived class.""" + v_map: Map + v_array: Array + def create_object(type_key: str, **kwargs: Any) -> Object: """Make an object by reflection. @@ -79,7 +92,7 @@ class _TestCxxClassBase: not_field_2: ClassVar[int] = 2 def __init__(self, v_i64: int, v_i32: int) -> None: - self.__ffi_init__(v_i64 + 1, v_i32 + 2) + self.__ffi_init__(v_i64 + 1, v_i32 + 2) # type: ignore[attr-defined] @c_class("testing.TestCxxClassDerived") @@ -90,5 +103,5 @@ class _TestCxxClassDerived(_TestCxxClassBase): @c_class("testing.TestCxxClassDerivedDerived") class _TestCxxClassDerivedDerived(_TestCxxClassDerived): - v_str: str = field(default_factory=lambda: "default") - v_bool: bool + v_str: str = field(default_factory=lambda: "default") # type: ignore[assignment] + v_bool: bool # type: ignore[misc] diff --git a/python/tvm_ffi/utils/lockfile.py b/python/tvm_ffi/utils/lockfile.py index 243a319e..6efe80f8 100644 --- a/python/tvm_ffi/utils/lockfile.py +++ b/python/tvm_ffi/utils/lockfile.py @@ -16,10 +16,12 @@ # under the License. """Simple cross-platform advisory file lock utilities.""" +from __future__ import annotations + import os import sys import time -from typing import Any, Optional +from typing import Any, Literal # Platform-specific imports for file locking if sys.platform == "win32": @@ -38,9 +40,9 @@ class FileLock: def __init__(self, lock_file_path: str) -> None: """Initialize a file lock using the given lock file path.""" self.lock_file_path = lock_file_path - self._file_descriptor = None + self._file_descriptor: int | None = None - def __enter__(self) -> "FileLock": + def __enter__(self) -> FileLock: """Acquire the lock upon entering the context. This method blocks until the lock is acquired. @@ -48,12 +50,12 @@ def __enter__(self) -> "FileLock": self.blocking_acquire() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: """Context manager protocol: release the lock upon exiting the 'with' block.""" self.release() return False # Propagate exceptions, if any - def acquire(self) -> Optional[bool]: + def acquire(self) -> bool: """Acquire an exclusive, non-blocking lock on the file. Returns True if the lock was acquired, False otherwise. @@ -79,9 +81,7 @@ def acquire(self) -> Optional[bool]: self._file_descriptor = None raise RuntimeError(f"An unexpected error occurred: {e}") - def blocking_acquire( - self, timeout: Optional[float] = None, poll_interval: float = 0.1 - ) -> Optional[bool]: + def blocking_acquire(self, timeout: float | None = None, poll_interval: float = 0.1) -> bool: """Wait until an exclusive lock can be acquired, with an optional timeout. Args: diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py index da304642..713520fa 100644 --- a/tests/lint/check_asf_header.py +++ b/tests/lint/check_asf_header.py @@ -170,7 +170,7 @@ } # Files and patterns to skip during header checking -SKIP_LIST = [] +SKIP_LIST: list[str] = [] def should_skip_file(filepath: str) -> bool: diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 9d08f9a3..c776b209 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -183,8 +183,8 @@ def main() -> None: cmd = ["git", "ls-files"] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() - assert proc.returncode == 0, f"{' '.join(cmd)} errored: {out}" res = out.decode("utf-8") + assert proc.returncode == 0, f"{' '.join(cmd)} errored: {res}" flist = res.split() error_list = [] diff --git a/tests/python/test_container.py b/tests/python/test_container.py index f5f28f34..f42e9b31 100644 --- a/tests/python/test_container.py +++ b/tests/python/test_container.py @@ -38,10 +38,10 @@ def test_bad_constructor_init_state() -> None: proper repr code """ with pytest.raises(TypeError): - tvm_ffi.Array(1) + tvm_ffi.Array(1) # type: ignore[arg-type] with pytest.raises(AttributeError): - tvm_ffi.Map(1) + tvm_ffi.Map(1) # type: ignore[arg-type] def test_array_of_array_map() -> None: diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py index a2fa80eb..0050e6c0 100644 --- a/tests/python/test_dataclasses_c_class.py +++ b/tests/python/test_dataclasses_c_class.py @@ -57,7 +57,7 @@ def test_cxx_class_derived_derived() -> None: def test_cxx_class_derived_derived_default() -> None: - obj = _TestCxxClassDerivedDerived(123, 456, 4, True) + obj = _TestCxxClassDerivedDerived(123, 456, 4, True) # type: ignore[call-arg,misc] assert obj.v_i64 == 123 assert obj.v_i32 == 456 assert isinstance(obj.v_f64, float) and obj.v_f64 == 4 diff --git a/tests/python/test_device.py b/tests/python/test_device.py index 9441c9fa..7a3638bf 100644 --- a/tests/python/test_device.py +++ b/tests/python/test_device.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import pickle import pytest @@ -87,7 +89,7 @@ def test_deive_type_error(dev_type: str, dev_id: int | None) -> None: def test_deive_id_error() -> None: with pytest.raises(TypeError): - tvm_ffi.device("cpu", "?") + tvm_ffi.device("cpu", "?") # type: ignore[arg-type] def test_device_pickle() -> None: diff --git a/tests/python/test_error.py b/tests/python/test_error.py index dda436c8..dd94cf39 100644 --- a/tests/python/test_error.py +++ b/tests/python/test_error.py @@ -39,9 +39,9 @@ def test_error_from_cxx() -> None: try: test_raise_error("ValueError", "error XYZ") except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 + assert e.__tvm_ffi_error__.kind == "ValueError" # type: ignore[attr-defined] + assert e.__tvm_ffi_error__.message == "error XYZ" # type: ignore[attr-defined] + assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # type: ignore[attr-defined] fapply = tvm_ffi.convert(lambda f, *args: f(*args)) @@ -64,17 +64,17 @@ def raise_error() -> None: try: fapply(cxx_test_raise_error, "ValueError", "error XYZ") except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 - record_object.append(e.__tvm_ffi_error__) + assert e.__tvm_ffi_error__.kind == "ValueError" # type: ignore[attr-defined] + assert e.__tvm_ffi_error__.message == "error XYZ" # type: ignore[attr-defined] + assert e.__tvm_ffi_error__.backtrace.find("TestRaiseError") != -1 # type: ignore[attr-defined] + record_object.append(e.__tvm_ffi_error__) # type: ignore[attr-defined] raise e try: cxx_test_apply(raise_error) except ValueError as e: - backtrace = e.__tvm_ffi_error__.backtrace - assert e.__tvm_ffi_error__.same_as(record_object[0]) + backtrace = e.__tvm_ffi_error__.backtrace # type: ignore[attr-defined] + assert e.__tvm_ffi_error__.same_as(record_object[0]) # type: ignore[attr-defined] assert backtrace.count("TestRaiseError") == 1 # The following lines may fail if debug symbols are missing try: @@ -108,7 +108,7 @@ def raise_cxx_error() -> None: try: raise_cxx_error() except ValueError as e: - assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1 + assert e.__tvm_ffi_error__.backtrace.find("raise_cxx_error") == -1 # type: ignore[attr-defined] ffi_error1 = tvm_ffi.convert(e) ffi_error2 = fecho(e) assert ffi_error1.backtrace.find("raise_cxx_error") != -1 diff --git a/tests/python/test_function.py b/tests/python/test_function.py index d0afd5f6..edc9ffde 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -113,11 +113,11 @@ def test_string_bytes_passing() -> None: # small bytes assert fecho(b"hello") == b"hello" # large bytes - x = b"hello" * 100 - y = fecho(x) - assert y == x - assert y.__tvm_ffi_object__ is not None - fecho(y) == 1 + x2 = b"hello" * 100 + y2 = fecho(x2) + assert y2 == x2 + assert y2.__tvm_ffi_object__ is not None + fecho(y2) == 1 def test_nested_container_passing() -> None: diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py index 5454284f..cd46bf5f 100644 --- a/tests/python/test_load_inline.py +++ b/tests/python/test_load_inline.py @@ -15,12 +15,16 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from types import ModuleType import numpy import pytest +torch: ModuleType | None try: - import torch + import torch # type: ignore[no-redef] except ImportError: torch = None @@ -197,6 +201,7 @@ def test_load_inline_cuda() -> None: @pytest.mark.skipif(torch is None, reason="Requires torch") def test_load_inline_with_env_tensor_allocator() -> None: + assert torch is not None if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"): pytest.skip("Torch does not support __c_dlpack_tensor_allocator__") mod: Module = tvm_ffi.cpp.load_inline( @@ -241,6 +246,7 @@ def test_load_inline_with_env_tensor_allocator() -> None: torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) def test_load_inline_both() -> None: + assert torch is not None mod: Module = tvm_ffi.cpp.load_inline( name="hello", cpp_sources=r""" diff --git a/tests/python/test_object.py b/tests/python/test_object.py index ea54adf6..aa1a791b 100644 --- a/tests/python/test_object.py +++ b/tests/python/test_object.py @@ -24,6 +24,7 @@ def test_make_object() -> None: # with default values obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase") + assert isinstance(obj0, tvm_ffi.testing.TestObjectBase) assert obj0.v_i64 == 10 assert obj0.v_f64 == 10.0 assert obj0.v_str == "hello" @@ -37,14 +38,16 @@ def test_make_object_via_init() -> None: def test_method() -> None: obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) - assert obj0.add_i64(1) == 13 - assert type(obj0).add_i64.__doc__ == "add_i64 method" - assert type(obj0).v_i64.__doc__ == "i64 field" + assert isinstance(obj0, tvm_ffi.testing.TestObjectBase) + assert obj0.add_i64(1) == 13 # type: ignore[attr-defined] + assert type(obj0).add_i64.__doc__ == "add_i64 method" # type: ignore[attr-defined] + assert type(obj0).v_i64.__doc__ == "i64 field" # type: ignore[attr-defined] def test_setter() -> None: # test setter obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello") + assert isinstance(obj0, tvm_ffi.testing.TestObjectBase) assert obj0.v_i64 == 10 obj0.v_i64 = 11 assert obj0.v_i64 == 11 @@ -52,10 +55,10 @@ def test_setter() -> None: assert obj0.v_str == "world" with pytest.raises(TypeError): - obj0.v_str = 1 + obj0.v_str = 1 # type: ignore[assignment] with pytest.raises(TypeError): - obj0.v_i64 = "hello" + obj0.v_i64 = "hello" # type: ignore[assignment] def test_derived_object() -> None: @@ -68,6 +71,7 @@ def test_derived_object() -> None: obj0 = tvm_ffi.testing.create_object( "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array ) + assert isinstance(obj0, tvm_ffi.testing.TestObjectDerived) assert obj0.v_map.same_as(v_map) assert obj0.v_array.same_as(v_array) assert obj0.v_i64 == 20 diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py index cfaf6506..9280aabb 100644 --- a/tests/python/test_stream.py +++ b/tests/python/test_stream.py @@ -15,12 +15,17 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from types import ModuleType + import pytest import tvm_ffi import tvm_ffi.cpp +torch: ModuleType | None try: - import torch + import torch # type: ignore[no-redef] except ImportError: torch = None @@ -56,6 +61,7 @@ def test_raw_stream() -> None: torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) def test_torch_stream() -> None: + assert torch is not None mod = gen_check_stream_mod() device_id = torch.cuda.current_device() device = tvm_ffi.device("cuda", device_id) @@ -78,6 +84,7 @@ def test_torch_stream() -> None: torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) def test_torch_current_stream() -> None: + assert torch is not None mod = gen_check_stream_mod() device_id = torch.cuda.current_device() device = tvm_ffi.device("cuda", device_id) @@ -103,6 +110,7 @@ def test_torch_current_stream() -> None: torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) def test_torch_graph() -> None: + assert torch is not None mod = gen_check_stream_mod() device_id = torch.cuda.current_device() device = tvm_ffi.device("cuda", device_id) diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py index 4c2e9a87..186d91b8 100644 --- a/tests/python/test_tensor.py +++ b/tests/python/test_tensor.py @@ -14,10 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + +from types import ModuleType + import pytest +torch: ModuleType | None try: - import torch + import torch # type: ignore[no-redef] except ImportError: torch = None @@ -45,18 +51,19 @@ def test_shape_object() -> None: assert shape == (10, 8, 4, 2) fecho = tvm_ffi.convert(lambda x: x) - shape2 = fecho(shape) + shape2: tvm_ffi.Shape = fecho(shape) assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) assert isinstance(shape2, tvm_ffi.Shape) assert isinstance(shape2, tuple) - shape3 = tvm_ffi.convert(shape) + shape3: tvm_ffi.Shape = tvm_ffi.convert(shape) assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) assert isinstance(shape3, tvm_ffi.Shape) @pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled") def test_tensor_auto_dlpack() -> None: + assert torch is not None x = torch.arange(128) fecho = tvm_ffi.get_global_func("testing.echo") y = fecho(x) diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py index 2d9b296b..4366b583 100644 --- a/tests/scripts/benchmark_dlpack.py +++ b/tests/scripts/benchmark_dlpack.py @@ -254,9 +254,9 @@ def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str) x = tvm_ffi.from_dlpack(torch.arange(1, device=device)) y = tvm_ffi.from_dlpack(torch.arange(1, device=device)) z = tvm_ffi.from_dlpack(torch.arange(1, device=device)) - x = tvm_ffi.core.DLTensorTestWrapper(x) - y = tvm_ffi.core.DLTensorTestWrapper(y) - z = tvm_ffi.core.DLTensorTestWrapper(z) + x = tvm_ffi.core.DLTensorTestWrapper(x) # type: ignore[assignment] + y = tvm_ffi.core.DLTensorTestWrapper(y) # type: ignore[assignment] + z = tvm_ffi.core.DLTensorTestWrapper(z) # type: ignore[assignment] bench_tvm_ffi_nop_autodlpack( f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, repeat )