From cffd5094800f523b54e4f673238907627698cbb8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 23 Sep 2025 12:23:20 -0700 Subject: [PATCH 1/4] fix(Array): Make array slicing behavior consistent to TVM --- python/tvm_ffi/container.py | 4 +--- tests/python/test_container.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index f179fc9f..a9de497a 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -119,12 +119,10 @@ def __getitem__(self, idx: SupportsIndex, /) -> T: ... @overload def __getitem__(self, idx: slice, /) -> Array[T]: ... - def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T]: + def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T] | list[T]: """Return one element or a new :class:`Array` for a slice.""" length = len(self) result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) - if isinstance(result, list): - return cast(Array[T], type(self)(result)) return result def __len__(self) -> int: diff --git a/tests/python/test_container.py b/tests/python/test_container.py index 54b41b70..3f3465ea 100644 --- a/tests/python/test_container.py +++ b/tests/python/test_container.py @@ -28,6 +28,7 @@ def test_array() -> None: assert len(a) == 3 assert a[-1] == 3 a_slice = a[-3:-1] + assert isinstance(a_slice, list) # TVM array slicing returns a list[T] instead of Array[T] assert (a_slice[0], a_slice[1]) == (1, 2) From e95a0a6e4638374d353116d83fdd78d5c2baa1a4 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 23 Sep 2025 12:31:11 -0700 Subject: [PATCH 2/4] Update __getitem__ to support list return type --- python/tvm_ffi/container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index a9de497a..86d2c311 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -117,7 +117,7 @@ def __init__(self, input_list: Iterable[T]) -> None: def __getitem__(self, idx: SupportsIndex, /) -> T: ... @overload - def __getitem__(self, idx: slice, /) -> Array[T]: ... + def __getitem__(self, idx: slice, /) -> Array[T] | list[T]: ... def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T] | list[T]: """Return one element or a new :class:`Array` for a slice.""" From e6a5e00a9f2c642734a468204c6052a77a156dc0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 23 Sep 2025 13:19:49 -0700 Subject: [PATCH 3/4] Update __getitem__ method to return list instead of Array --- python/tvm_ffi/container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 86d2c311..3fe67ee1 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -117,9 +117,9 @@ def __init__(self, input_list: Iterable[T]) -> None: def __getitem__(self, idx: SupportsIndex, /) -> T: ... @overload - def __getitem__(self, idx: slice, /) -> Array[T] | list[T]: ... + def __getitem__(self, idx: slice, /) -> list[T]: ... - def __getitem__(self, idx: SupportsIndex | slice, /) -> T | Array[T] | list[T]: + def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: """Return one element or a new :class:`Array` for a slice.""" length = len(self) result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) From cbd725c8bc877b99d62615983dbb0386be658b21 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 23 Sep 2025 13:40:39 -0700 Subject: [PATCH 4/4] Update container.py --- python/tvm_ffi/container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index 3fe67ee1..d4aa24f5 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -120,7 +120,7 @@ def __getitem__(self, idx: SupportsIndex, /) -> T: ... def __getitem__(self, idx: slice, /) -> list[T]: ... def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: - """Return one element or a new :class:`Array` for a slice.""" + """Return one element or a list for a slice.""" length = len(self) result = getitem_helper(self, _ffi_api.ArrayGetItem, length, idx) return result