Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/tvm_ffi/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,12 @@ def __init__(self, input_list: Iterable[T]) -> None:
def __getitem__(self, idx: SupportsIndex, /) -> T: ...

@overload
def __getitem__(self, idx: slice, /) -> Array[T]: ...
def __getitem__(self, idx: slice, /) -> Array[T] | list[T]: ...

def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]:
def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T] | list[T]:
"""Return one element or a new :class:`Array` for a slice."""
length = len(self)
result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx)
if isinstance(result, list):
return cast(Array[T], type(self)(result))
return result

def __len__(self) -> int:
Expand Down
1 change: 1 addition & 0 deletions tests/python/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_array() -> None:
assert len(a) == 3
assert a[-1] == 3
a_slice = a[-3:-1]
assert isinstance(a_slice, list) # TVM array slicing returns a list[T] instead of Array[T]
assert (a_slice[0], a_slice[1]) == (1, 2)


Expand Down
Loading