diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index f179fc9f..d4aa24f5 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -117,14 +117,12 @@ def __init__(self, input_list: Iterable[T]) -> None: def __getitem__(self, idx: SupportsIndex, /) -> T: ... @overload - def __getitem__(self, idx: slice, /) -> Array[T]: ... + def __getitem__(self, idx: slice, /) -> list[T]: ... - def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]: - """Return one element or a new :class:`Array` for a slice.""" + def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: + """Return one element or a list for a slice.""" length = len(self) result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) - if isinstance(result, list): - return cast(Array[T], type(self)(result)) return result def __len__(self) -> int: diff --git a/tests/python/test_container.py b/tests/python/test_container.py index 54b41b70..3f3465ea 100644 --- a/tests/python/test_container.py +++ b/tests/python/test_container.py @@ -28,6 +28,7 @@ def test_array() -> None: assert len(a) == 3 assert a[-1] == 3 a_slice = a[-3:-1] + assert isinstance(a_slice, list) # TVM array slicing returns a list[T] instead of Array[T] assert (a_slice[0], a_slice[1]) == (1, 2)