diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py index d2aff296..2b64ed63 100644 --- a/python/tvm_ffi/container.py +++ b/python/tvm_ffi/container.py @@ -192,6 +192,10 @@ def __contains__(self, value: object) -> bool: """Check if the array contains a value.""" return _ffi_api.ArrayContains(self, value) + def __bool__(self) -> bool: + """Return True if the array is non-empty.""" + return len(self) > 0 + def __add__(self, other: Iterable[T]) -> Array[T]: """Concatenate two arrays.""" return type(self)(itertools.chain(self, other)) @@ -337,6 +341,10 @@ def __len__(self) -> int: """Return the number of items in the map.""" return _ffi_api.MapSize(self) + def __bool__(self) -> bool: + """Return True if the map is non-empty.""" + return len(self) > 0 + def __iter__(self) -> Iterator[K]: """Iterate over the map's keys.""" return iter(self.keys()) diff --git a/tests/python/test_container.py b/tests/python/test_container.py index 8650e080..37f7432b 100644 --- a/tests/python/test_container.py +++ b/tests/python/test_container.py @@ -227,3 +227,31 @@ def test_large_map_get() -> None: def test_array_contains(arr: list[Any], value: Any, expected: bool) -> None: a = tvm_ffi.convert(arr) assert (value in a) == expected + + +@pytest.mark.parametrize( + "arr, expected", + [ + ([1, 2, 3], True), + ([1], True), + ([], False), + (["hello"], True), + ], +) +def test_array_bool(arr: list[Any], expected: bool) -> None: + a = tvm_ffi.Array(arr) + assert bool(a) is expected + + +@pytest.mark.parametrize( + "mapping, expected", + [ + ({"a": 1, "b": 2}, True), + ({"a": 1}, True), + ({}, False), + ({1: "one"}, True), + ], +) +def test_map_bool(mapping: dict[Any, Any], expected: bool) -> None: + m = tvm_ffi.Map(mapping) + assert bool(m) is expected