Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 113 additions & 72 deletions python/tvm_ffi/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,78 @@

from __future__ import annotations

import collections.abc
import operator
from collections.abc import ItemsView as ItemsViewBase
from collections.abc import Iterator, Mapping, Sequence
from typing import Any, Callable
from collections.abc import KeysView as KeysViewBase
from collections.abc import ValuesView as ValuesViewBase
from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload

from . import _ffi_api, core
from .registry import register_object

__all__ = ["Array", "Map"]


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
_DefaultT = TypeVar("_DefaultT")


def getitem_helper(
obj: Any,
elem_getter: Callable[[Any, int], Any],
elem_getter: Callable[[Any, int], T],
length: int,
idx: int | slice,
) -> Any:
idx: SupportsIndex | slice,
) -> T | list[T]:
"""Implement a pythonic __getitem__ helper.

Parameters
----------
obj: Any
The original object

elem_getter : Callable[[Any, int], Any]
elem_getter : Callable[[Any, int], T]
A simple function that takes index and return a single element.

length : int
The size of the array

idx : int or slice
idx : SupportsIndex or slice
The argument passed to getitem

Returns
-------
result : object
The result of getitem
The element for integer indices or a ``list`` for slices.

"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
start, stop, step = idx.indices(length)
return [elem_getter(obj, i) for i in range(start, stop, step)]

if idx < -length or idx >= length:
raise IndexError(f"Index out of range. size: {length}, got index {idx}")
if idx < 0:
idx += length
return elem_getter(obj, idx)
try:
index = operator.index(idx)
except TypeError as exc: # pragma: no cover - defensive, matches list behaviour
raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from exc

if index < -length or index >= length:
raise IndexError(f"Index out of range. size: {length}, got index {index}")
if index < 0:
index += length
return elem_getter(obj, index)


@register_object("ffi.Array")
class Array(core.Object, collections.abc.Sequence):
class Array(core.Object, Sequence[T]):
"""Array container that represents a sequence of values in ffi.

:py:func:`tvm_ffi.convert` will map python list/tuple to this class.

Parameters
----------
input_list : Sequence[Any]
input_list : Sequence[T]
The list of values to be stored in the array.

See Also
Expand All @@ -100,18 +108,34 @@ class Array(core.Object, collections.abc.Sequence):

"""

def __init__(self, input_list: Sequence[Any]) -> None:
def __init__(self, input_list: Sequence[T]) -> None:
"""Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)

def __getitem__(self, idx: int | slice) -> Any:
"""Return one element or a Python list for a slice."""
return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx)
@overload
def __getitem__(self, idx: SupportsIndex, /) -> T: ...

@overload
def __getitem__(self, idx: slice, /) -> Array[T]: ...

def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]:
"""Return one element or a new :class:`Array` for a slice."""
length = len(self)
result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx)
if isinstance(result, list):
return cast(Array[T], type(self)(result))
return result

def __len__(self) -> int:
"""Return the number of elements in the array."""
return _ffi_api.ArraySize(self)

def __iter__(self) -> Iterator[T]:
"""Iterate over the elements in the array."""
length = len(self)
for i in range(length):
yield self[i]

def __repr__(self) -> str:
"""Return a string representation of the array."""
# exception safety handling for chandle=None
Expand All @@ -120,79 +144,87 @@ def __repr__(self) -> str:
return "[" + ", ".join([x.__repr__() for x in self]) + "]"


class KeysView(collections.abc.KeysView):
class KeysView(KeysViewBase[K]):
"""Helper class to return keys view."""

def __init__(self, backend_map: Map) -> None:
def __init__(self, backend_map: Map[K, V]) -> None:
self._backend_map = backend_map

def __len__(self) -> int:
return len(self._backend_map)

def __iter__(self) -> Iterator[Any]:
if self.__len__() == 0:
return
functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
while True:
k = functor(0)
yield k
def __iter__(self) -> Iterator[K]:
size = len(self._backend_map)
functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
for _ in range(size):
key = cast(K, functor(0))
yield key
if not functor(2):
break

def __contains__(self, k: Any) -> bool:
return self._backend_map.__contains__(k)
def __contains__(self, k: object) -> bool:
return k in self._backend_map


class ValuesView(collections.abc.ValuesView):
class ValuesView(ValuesViewBase[V]):
"""Helper class to return values view."""

def __init__(self, backend_map: Map) -> None:
def __init__(self, backend_map: Map[K, V]) -> None:
self._backend_map = backend_map

def __len__(self) -> int:
return len(self._backend_map)

def __iter__(self) -> Iterator[Any]:
if self.__len__() == 0:
return
functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
while True:
v = functor(1)
yield v
def __iter__(self) -> Iterator[V]:
size = len(self._backend_map)
functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
for _ in range(size):
value = cast(V, functor(1))
yield value
if not functor(2):
break


class ItemsView(collections.abc.ItemsView):
class ItemsView(ItemsViewBase[K, V]):
"""Helper class to return items view."""

def __init__(self, backend_map: Map) -> None:
self.backend_map = backend_map
def __init__(self, backend_map: Map[K, V]) -> None:
self._backend_map = backend_map

def __len__(self) -> int:
return len(self.backend_map)

def __iter__(self) -> Iterator[tuple[Any, Any]]:
if self.__len__() == 0:
return
functor = _ffi_api.MapForwardIterFunctor(self.backend_map)
while True:
k = functor(0)
v = functor(1)
yield (k, v)
return len(self._backend_map)

def __iter__(self) -> Iterator[tuple[K, V]]:
size = len(self._backend_map)
functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map)
for _ in range(size):
key = cast(K, functor(0))
value = cast(V, functor(1))
yield (key, value)
if not functor(2):
break

def __contains__(self, item: object) -> bool:
if not isinstance(item, tuple) or len(item) != 2:
return False
key, value = item
try:
existing_value = self._backend_map[key]
except KeyError:
return False
else:
return existing_value == value


@register_object("ffi.Map")
class Map(core.Object, collections.abc.Mapping):
class Map(core.Object, Mapping[K, V]):
"""Map container.

:py:func:`tvm_ffi.convert` will map python dict to this class.

Parameters
----------
input_dict : Mapping[Any, Any]
input_dict : Mapping[K, V]
The dictionary of values to be stored in the map.

See Also
Expand All @@ -213,43 +245,49 @@ class Map(core.Object, collections.abc.Mapping):

"""

def __init__(self, input_dict: Mapping[Any, Any]) -> None:
def __init__(self, input_dict: Mapping[K, V]) -> None:
"""Construct a Map from a Python mapping."""
list_kvs = []
list_kvs: list[Any] = []
for k, v in input_dict.items():
list_kvs.append(k)
list_kvs.append(v)
self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs)

def __getitem__(self, k: Any) -> Any:
def __getitem__(self, k: K) -> V:
"""Return the value for key `k` or raise KeyError."""
return _ffi_api.MapGetItem(self, k)
return cast(V, _ffi_api.MapGetItem(self, k))

def __contains__(self, k: Any) -> bool:
def __contains__(self, k: object) -> bool:
"""Return True if the map contains key `k`."""
return _ffi_api.MapCount(self, k) != 0

def keys(self) -> KeysView:
def keys(self) -> KeysView[K]:
"""Return a dynamic view of the map's keys."""
return KeysView(self)

def values(self) -> ValuesView:
def values(self) -> ValuesView[V]:
"""Return a dynamic view of the map's values."""
return ValuesView(self)

def items(self) -> ItemsView:
def items(self) -> ItemsView[K, V]:
"""Get the items from the map."""
return ItemsView(self)

def __len__(self) -> int:
"""Return the number of items in the map."""
return _ffi_api.MapSize(self)

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[K]:
"""Iterate over the map's keys."""
return iter(self.keys())

def get(self, key: Any, default: Any | None = None) -> Any:
@overload
def get(self, key: K) -> V | None: ...

@overload
def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ...

def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None:
"""Get an element with a default value.

Parameters
Expand All @@ -266,7 +304,10 @@ def get(self, key: Any, default: Any | None = None) -> Any:
The result value.

"""
return self[key] if key in self else default
try:
return self[key]
except KeyError:
return default

def __repr__(self) -> str:
"""Return a string representation of the map."""
Expand Down
4 changes: 2 additions & 2 deletions python/tvm_ffi/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def __init__(self, a: int, b: int) -> None:
class TestObjectDerived(TestObjectBase):
"""Test object derived class."""

v_map: Map
v_array: Array
v_map: Map[Any, Any]
v_array: Array[Any]


def create_object(type_key: str, **kwargs: Any) -> Object:
Expand Down
Loading