Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,30 @@ repos:
rev: v0.10.0.1
hooks:
- id: shellcheck
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.17.0"
hooks:
- id: mypy
name: mypy for `python/` and `tests/`
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
args: [--show-error-codes, --python-version=3.9]
exclude: ^.*/_ffi_api\.py$
files: ^(python/|tests/).*\.py$
- id: mypy
name: mypy for `examples/inline_module`
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
args: [--show-error-codes, --python-version=3.9]
exclude: ^.*/_ffi_api\.py$
files: ^examples/inline_module/.*\.py$
- id: mypy
name: mypy for `examples/packaging`
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
args: [--show-error-codes, --python-version=3.9]
exclude: ^.*/_ffi_api\.py$
files: ^examples/packaging/.*\.py$
- id: mypy
name: mypy for `examples/quick_start`
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "typing-extensions>=4.5"]
args: [--show-error-codes, --python-version=3.9]
exclude: ^.*/_ffi_api\.py$
files: ^examples/quick_start/.*\.py$
2 changes: 1 addition & 1 deletion examples/packaging/python/my_ffi_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def raise_error(msg: str) -> None:
The error raised by the function.

"""
return _ffi_api.raise_error(msg)
return _ffi_api.raise_error(msg) # type: ignore[attr-defined]
18 changes: 3 additions & 15 deletions examples/quick_start/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,9 @@
# under the License.
"""Quick start script to run tvm-ffi examples from prebuilt libraries."""

import tvm_ffi

try:
import torch
except ImportError:
torch = None


import numpy
import torch
import tvm_ffi


def run_add_one_cpu() -> None:
Expand All @@ -40,9 +34,6 @@ def run_add_one_cpu() -> None:
print("numpy.result after add_one(x, y)")
print(x)

if torch is None:
return

x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
# tvm-ffi automatically handles DLPack compatible tensors
Expand All @@ -63,9 +54,6 @@ def run_add_one_c() -> None:
print("numpy.result after add_one_c(x, y)")
print(x)

if torch is None:
return

x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
mod.add_one_c(x, y)
Expand All @@ -75,7 +63,7 @@ def run_add_one_c() -> None:

def run_add_one_cuda() -> None:
"""Load the add_one_cuda module and call the add_one_cuda function."""
if torch is None or not torch.cuda.is_available():
if not torch.cuda.is_available():
return

mod = tvm_ffi.load_module("build/add_one_cuda.so")
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,7 @@ environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" }

[tool.cibuildwheel.windows]
archs = ["AMD64"]

[tool.mypy]
allow_redefinition = true
ignore_missing_imports = true
11 changes: 8 additions & 3 deletions python/tvm_ffi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@
# under the License.
"""Conversion utilities to bring python objects into ffi values."""

from __future__ import annotations

from numbers import Number
from types import ModuleType
from typing import Any

from . import container, core

torch: ModuleType | None = None
try:
import torch
import torch # type: ignore[no-redef]
except ImportError:
torch = None
pass

numpy: ModuleType | None = None
try:
import numpy
except ImportError:
numpy = None
pass


def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class dtype(str):
"""

__slots__ = ["__tvm_ffi_dtype__"]
__tvm_ffi_dtype__: core.DataType

_NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}

Expand Down
43 changes: 43 additions & 0 deletions python/tvm_ffi/_ffi_api.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI API."""

from typing import Any

def ModuleGetKind(*args: Any) -> Any: ...
def ModuleImplementsFunction(*args: Any) -> Any: ...
def ModuleGetFunction(*args: Any) -> Any: ...
def ModuleImportModule(*args: Any) -> Any: ...
def ModuleInspectSource(*args: Any) -> Any: ...
def ModuleGetWriteFormats(*args: Any) -> Any: ...
def ModuleGetPropertyMask(*args: Any) -> Any: ...
def ModuleClearImports(*args: Any) -> Any: ...
def ModuleWriteToFile(*args: Any) -> Any: ...
def ModuleLoadFromFile(*args: Any) -> Any: ...
def SystemLib(*args: Any) -> Any: ...
def Array(*args: Any) -> Any: ...
def ArrayGetItem(*args: Any) -> Any: ...
def ArraySize(*args: Any) -> Any: ...
def MapForwardIterFunctor(*args: Any) -> Any: ...
def Map(*args: Any) -> Any: ...
def MapGetItem(*args: Any) -> Any: ...
def MapCount(*args: Any) -> Any: ...
def MapSize(*args: Any) -> Any: ...
def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
def ToJSONGraphString(*args: Any) -> Any: ...
def FromJSONGraphString(*args: Any) -> Any: ...
def Shape(*args: Any) -> Any: ...
12 changes: 6 additions & 6 deletions python/tvm_ffi/_optional_torch_c_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
"""

import warnings
from typing import Any, Optional
from typing import Any

from . import libinfo


def load_torch_c_dlpack_extension() -> Optional[Any]:
def load_torch_c_dlpack_extension() -> Any:
"""Load the torch c dlpack extension."""
cpp_source = """
#include <dlpack/dlpack.h>
Expand Down Expand Up @@ -556,17 +556,17 @@ def load_torch_c_dlpack_extension() -> Optional[Any]:
extra_include_paths=include_paths,
)
# set the dlpack related flags
torch.Tensor.__c_dlpack_from_pyobject__ = mod.TorchDLPackFromPyObjectPtr()
torch.Tensor.__c_dlpack_to_pyobject__ = mod.TorchDLPackToPyObjectPtr()
torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr()
setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
return mod
except ImportError:
pass
except Exception as e:
warnings.warn(
f"Failed to load torch c dlpack extension: {e},EnvTensorAllocator will not be enabled."
)
return None
return None


# keep alive
Expand Down
20 changes: 12 additions & 8 deletions python/tvm_ffi/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Tensor related objects and functions."""

from __future__ import annotations

# we name it as _tensor.py to avoid potential future case
# if we also want to expose a tensor function in the root namespace

from numbers import Integral
from typing import Any, Optional, Union
from typing import Any

from . import _ffi_api, core, registry
from .core import (
Expand All @@ -43,23 +45,25 @@ class Shape(tuple, PyNativeObject):

"""

def __new__(cls, content: tuple[int, ...]) -> "Shape":
__tvm_ffi_object__: Any

def __new__(cls, content: tuple[int, ...]) -> Shape:
if any(not isinstance(x, Integral) for x in content):
raise ValueError("Shape must be a tuple of integers")
val = tuple.__new__(cls, content)
val: Shape = tuple.__new__(cls, content)
val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content)
return val

# pylint: disable=no-self-argument
def __from_tvm_ffi_object__(cls, obj: Any) -> "Shape":
def __from_tvm_ffi_object__(cls, obj: Any) -> Shape:
"""Construct from a given tvm object."""
content = _shape_obj_get_py_tuple(obj)
val = tuple.__new__(cls, content)
val.__tvm_ffi_object__ = obj
val: Shape = tuple.__new__(cls, content) # type: ignore[arg-type]
val.__tvm_ffi_object__ = obj # type: ignore[attr-defined]
return val


def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = None) -> Device:
def device(device_type: str | int | DLDeviceType, index: int | None = None) -> Device:
"""Construct a TVM FFI device with given device type and index.

Parameters
Expand Down
27 changes: 16 additions & 11 deletions python/tvm_ffi/access_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class AccessKind(IntEnum):
class AccessStep(core.Object):
"""Access step container."""

kind: AccessKind
key: Any


@register_object("ffi.reflection.AccessPath")
class AccessPath(core.Object):
"""Access path container."""

parent: "AccessPath"

def __init__(self) -> None:
"""Disallow direct construction; use `AccessPath.root()` instead."""
super().__init__()
Expand All @@ -55,19 +60,19 @@ def __init__(self) -> None:
@staticmethod
def root() -> "AccessPath":
"""Create a root access path."""
return AccessPath._root()
return AccessPath._root() # type: ignore[attr-defined]

def __eq__(self, other: Any) -> bool:
"""Return whether two access paths are equal."""
if not isinstance(other, AccessPath):
return False
return self._path_equal(other)
return self._path_equal(other) # type: ignore[attr-defined]

def __ne__(self, other: Any) -> bool:
"""Return whether two access paths are not equal."""
if not isinstance(other, AccessPath):
return True
return not self._path_equal(other)
return not self._path_equal(other) # type: ignore[attr-defined]

def is_prefix_of(self, other: "AccessPath") -> bool:
"""Check if this access path is a prefix of another access path.
Expand All @@ -83,7 +88,7 @@ def is_prefix_of(self, other: "AccessPath") -> bool:
True if this access path is a prefix of the other access path, False otherwise

"""
return self._is_prefix_of(other)
return self._is_prefix_of(other) # type: ignore[attr-defined]

def attr(self, attr_key: str) -> "AccessPath":
"""Create an access path to the attribute of the current object.
Expand All @@ -99,7 +104,7 @@ def attr(self, attr_key: str) -> "AccessPath":
The extended access path

"""
return self._attr(attr_key)
return self._attr(attr_key) # type: ignore[attr-defined]

def attr_missing(self, attr_key: str) -> "AccessPath":
"""Create an access path that indicate an attribute is missing.
Expand All @@ -115,7 +120,7 @@ def attr_missing(self, attr_key: str) -> "AccessPath":
The extended access path

"""
return self._attr_missing(attr_key)
return self._attr_missing(attr_key) # type: ignore[attr-defined]

def array_item(self, index: int) -> "AccessPath":
"""Create an access path to the item of the current array.
Expand All @@ -131,7 +136,7 @@ def array_item(self, index: int) -> "AccessPath":
The extended access path

"""
return self._array_item(index)
return self._array_item(index) # type: ignore[attr-defined]

def array_item_missing(self, index: int) -> "AccessPath":
"""Create an access path that indicate an array item is missing.
Expand All @@ -147,7 +152,7 @@ def array_item_missing(self, index: int) -> "AccessPath":
The extended access path

"""
return self._array_item_missing(index)
return self._array_item_missing(index) # type: ignore[attr-defined]

def map_item(self, key: Any) -> "AccessPath":
"""Create an access path to the item of the current map.
Expand All @@ -163,7 +168,7 @@ def map_item(self, key: Any) -> "AccessPath":
The extended access path

"""
return self._map_item(key)
return self._map_item(key) # type: ignore[attr-defined]

def map_item_missing(self, key: Any) -> "AccessPath":
"""Create an access path that indicate a map item is missing.
Expand All @@ -179,7 +184,7 @@ def map_item_missing(self, key: Any) -> "AccessPath":
The extended access path

"""
return self._map_item_missing(key)
return self._map_item_missing(key) # type: ignore[attr-defined]

def to_steps(self) -> list["AccessStep"]:
"""Convert the access path to a list of access steps.
Expand All @@ -190,6 +195,6 @@ def to_steps(self) -> list["AccessStep"]:
The list of access steps

"""
return self._to_steps()
return self._to_steps() # type: ignore[attr-defined]

__hash__ = core.Object.__hash__
Loading
Loading