diff --git a/include/tvm/ffi/container/array.h b/include/tvm/ffi/container/array.h index 60fa6fe4..32c1a228 100644 --- a/include/tvm/ffi/container/array.h +++ b/include/tvm/ffi/container/array.h @@ -70,7 +70,7 @@ class ArrayObj : public Object, public details::InplaceArrayBase= size_) { + if (i < 0 || i >= size_) { TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; } return static_cast(data_)[i]; @@ -91,7 +91,7 @@ class ArrayObj : public Object, public details::InplaceArrayBase= size_) { + if (i < 0 || i >= size_) { TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; } static_cast(data_)[i] = std::move(item); diff --git a/tests/cpp/test_array.cc b/tests/cpp/test_array.cc index a86cc294..b7d1fa3d 100644 --- a/tests/cpp/test_array.cc +++ b/tests/cpp/test_array.cc @@ -312,4 +312,35 @@ TEST(Array, Contains) { EXPECT_FALSE(f(str_arr, String("foo")).cast()); } +TEST(Array, NegativeIndexThrows) { + Array arr = {1, 2, 3}; + // Directly test ArrayObj methods, which are the ones modified in this PR. + // The Array wrapper methods already had negative index checks. + ArrayObj* arr_obj = arr.GetArrayObj(); + + // Test ArrayObj::at (which calls operator[]) + EXPECT_THROW( + { + try { + [[maybe_unused]] const auto& val = arr_obj->at(-1); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "IndexError"); + throw; + } + }, + ::tvm::ffi::Error); + + // Test ArrayObj::SetItem + EXPECT_THROW( + { + try { + arr_obj->SetItem(-1, Any(42)); + } catch (const Error& error) { + EXPECT_EQ(error.kind(), "IndexError"); + throw; + } + }, + ::tvm::ffi::Error); +} + } // namespace