diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 6f29dfdf..e3b2fb73 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -18,9 +18,12 @@ from __future__ import annotations -import collections.abc +import operator +from collections.abc import ItemsView as ItemsViewBase from collections.abc import Iterator, Mapping, Sequence -from typing import Any, Callable +from collections.abc import KeysView as KeysViewBase +from collections.abc import ValuesView as ValuesViewBase +from typing import Any, Callable, SupportsIndex, TypeVar, cast, overload from . import _ffi_api, core from .registry import register_object @@ -28,12 +31,18 @@ __all__ = ["Array", "Map"] +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") +_DefaultT = TypeVar("_DefaultT") + + def getitem_helper( obj: Any, - elem_getter: Callable[[Any, int], Any], + elem_getter: Callable[[Any, int], T], length: int, - idx: int | slice, -) -> Any: + idx: SupportsIndex | slice, +) -> T | list[T]: """Implement a pythonic __getitem__ helper. Parameters @@ -41,47 +50,46 @@ def getitem_helper( obj: Any The original object - elem_getter : Callable[[Any, int], Any] + elem_getter : Callable[[Any, int], T] A simple function that takes index and return a single element. length : int The size of the array - idx : int or slice + idx : SupportsIndex or slice The argument passed to getitem Returns ------- result : object - The result of getitem + The element for integer indices or a ``list`` for slices. """ if isinstance(idx, slice): - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else length - step = idx.step if idx.step is not None else 1 - if start < 0: - start += length - if stop < 0: - stop += length + start, stop, step = idx.indices(length) return [elem_getter(obj, i) for i in range(start, stop, step)] - if idx < -length or idx >= length: - raise IndexError(f"Index out of range. size: {length}, got index {idx}") - if idx < 0: - idx += length - return elem_getter(obj, idx) + try: + index = operator.index(idx) + except TypeError as exc: # pragma: no cover - defensive, matches list behaviour + raise TypeError(f"indices must be integers or slices, not {type(idx).__name__}") from exc + + if index < -length or index >= length: + raise IndexError(f"Index out of range. size: {length}, got index {index}") + if index < 0: + index += length + return elem_getter(obj, index) @register_object("ffi.Array") -class Array(core.Object, collections.abc.Sequence): +class Array(core.Object, Sequence[T]): """Array container that represents a sequence of values in ffi. :py:func:`tvm_ffi.convert` will map python list/tuple to this class. Parameters ---------- - input_list : Sequence[Any] + input_list : Sequence[T] The list of values to be stored in the array. See Also @@ -100,18 +108,34 @@ class Array(core.Object, collections.abc.Sequence): """ - def __init__(self, input_list: Sequence[Any]) -> None: + def __init__(self, input_list: Sequence[T]) -> None: """Construct an Array from a Python sequence.""" self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) - def __getitem__(self, idx: int | slice) -> Any: - """Return one element or a Python list for a slice.""" - return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx) + @overload + def __getitem__(self, idx: SupportsIndex, /) -> T: ... + + @overload + def __getitem__(self, idx: slice, /) -> Array[T]: ... + + def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]: + """Return one element or a new :class:`Array` for a slice.""" + length = len(self) + result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) + if isinstance(result, list): + return cast(Array[T], type(self)(result)) + return result def __len__(self) -> int: """Return the number of elements in the array.""" return _ffi_api.ArraySize(self) + def __iter__(self) -> Iterator[T]: + """Iterate over the elements in the array.""" + length = len(self) + for i in range(length): + yield self[i] + def __repr__(self) -> str: """Return a string representation of the array.""" # exception safety handling for chandle=None @@ -120,79 +144,87 @@ def __repr__(self) -> str: return "[" + ", ".join([x.__repr__() for x in self]) + "]" -class KeysView(collections.abc.KeysView): +class KeysView(KeysViewBase[K]): """Helper class to return keys view.""" - def __init__(self, backend_map: Map) -> None: + def __init__(self, backend_map: Map[K, V]) -> None: self._backend_map = backend_map def __len__(self) -> int: return len(self._backend_map) - def __iter__(self) -> Iterator[Any]: - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - k = functor(0) - yield k + def __iter__(self) -> Iterator[K]: + size = len(self._backend_map) + functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) + for _ in range(size): + key = cast(K, functor(0)) + yield key if not functor(2): break - def __contains__(self, k: Any) -> bool: - return self._backend_map.__contains__(k) + def __contains__(self, k: object) -> bool: + return k in self._backend_map -class ValuesView(collections.abc.ValuesView): +class ValuesView(ValuesViewBase[V]): """Helper class to return values view.""" - def __init__(self, backend_map: Map) -> None: + def __init__(self, backend_map: Map[K, V]) -> None: self._backend_map = backend_map def __len__(self) -> int: return len(self._backend_map) - def __iter__(self) -> Iterator[Any]: - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - v = functor(1) - yield v + def __iter__(self) -> Iterator[V]: + size = len(self._backend_map) + functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) + for _ in range(size): + value = cast(V, functor(1)) + yield value if not functor(2): break -class ItemsView(collections.abc.ItemsView): +class ItemsView(ItemsViewBase[K, V]): """Helper class to return items view.""" - def __init__(self, backend_map: Map) -> None: - self.backend_map = backend_map + def __init__(self, backend_map: Map[K, V]) -> None: + self._backend_map = backend_map def __len__(self) -> int: - return len(self.backend_map) - - def __iter__(self) -> Iterator[tuple[Any, Any]]: - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self.backend_map) - while True: - k = functor(0) - v = functor(1) - yield (k, v) + return len(self._backend_map) + + def __iter__(self) -> Iterator[tuple[K, V]]: + size = len(self._backend_map) + functor: Callable[[int], Any] = _ffi_api.MapForwardIterFunctor(self._backend_map) + for _ in range(size): + key = cast(K, functor(0)) + value = cast(V, functor(1)) + yield (key, value) if not functor(2): break + def __contains__(self, item: object) -> bool: + if not isinstance(item, tuple) or len(item) != 2: + return False + key, value = item + try: + existing_value = self._backend_map[key] + except KeyError: + return False + else: + return existing_value == value + @register_object("ffi.Map") -class Map(core.Object, collections.abc.Mapping): +class Map(core.Object, Mapping[K, V]): """Map container. :py:func:`tvm_ffi.convert` will map python dict to this class. Parameters ---------- - input_dict : Mapping[Any, Any] + input_dict : Mapping[K, V] The dictionary of values to be stored in the map. See Also @@ -213,31 +245,31 @@ class Map(core.Object, collections.abc.Mapping): """ - def __init__(self, input_dict: Mapping[Any, Any]) -> None: + def __init__(self, input_dict: Mapping[K, V]) -> None: """Construct a Map from a Python mapping.""" - list_kvs = [] + list_kvs: list[Any] = [] for k, v in input_dict.items(): list_kvs.append(k) list_kvs.append(v) self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs) - def __getitem__(self, k: Any) -> Any: + def __getitem__(self, k: K) -> V: """Return the value for key `k` or raise KeyError.""" - return _ffi_api.MapGetItem(self, k) + return cast(V, _ffi_api.MapGetItem(self, k)) - def __contains__(self, k: Any) -> bool: + def __contains__(self, k: object) -> bool: """Return True if the map contains key `k`.""" return _ffi_api.MapCount(self, k) != 0 - def keys(self) -> KeysView: + def keys(self) -> KeysView[K]: """Return a dynamic view of the map's keys.""" return KeysView(self) - def values(self) -> ValuesView: + def values(self) -> ValuesView[V]: """Return a dynamic view of the map's values.""" return ValuesView(self) - def items(self) -> ItemsView: + def items(self) -> ItemsView[K, V]: """Get the items from the map.""" return ItemsView(self) @@ -245,11 +277,17 @@ def __len__(self) -> int: """Return the number of items in the map.""" return _ffi_api.MapSize(self) - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[K]: """Iterate over the map's keys.""" return iter(self.keys()) - def get(self, key: Any, default: Any | None = None) -> Any: + @overload + def get(self, key: K) -> V | None: ... + + @overload + def get(self, key: K, default: V | _DefaultT) -> V | _DefaultT: ... + + def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | None: """Get an element with a default value. Parameters @@ -266,7 +304,10 @@ def get(self, key: Any, default: Any | None = None) -> Any: The result value. """ - return self[key] if key in self else default + try: + return self[key] + except KeyError: + return default def __repr__(self) -> str: """Return a string representation of the map.""" diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py index 97691584..5bbed339 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing.py @@ -52,8 +52,8 @@ def __init__(self, a: int, b: int) -> None: class TestObjectDerived(TestObjectBase): """Test object derived class.""" - v_map: Map - v_array: Array + v_map: Map[Any, Any] + v_array: Array[Any] def create_object(type_key: str, **kwargs: Any) -> Object: