diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h index 3675bb5a..eb6c9fe5 100644 --- a/include/tvm/ffi/container/tensor.h +++ b/include/tvm/ffi/container/tensor.h @@ -325,7 +325,12 @@ class Tensor : public ObjectRef { */ int64_t size(int64_t idx) const { const TensorObj* ptr = get(); - return ptr->shape[idx >= 0 ? idx : (ptr->ndim + idx)]; + int64_t adjusted_idx = idx >= 0 ? idx : (ptr->ndim + idx); + if (adjusted_idx < 0 || adjusted_idx >= ptr->ndim) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " << ptr->ndim + << " dimensions"; + } + return ptr->shape[adjusted_idx]; } /*! @@ -336,7 +341,12 @@ class Tensor : public ObjectRef { */ int64_t stride(int64_t idx) const { const TensorObj* ptr = get(); - return ptr->strides[idx >= 0 ? idx : (ptr->ndim + idx)]; + int64_t adjusted_idx = idx >= 0 ? idx : (ptr->ndim + idx); + if (adjusted_idx < 0 || adjusted_idx >= ptr->ndim) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " << ptr->ndim + << " dimensions"; + } + return ptr->strides[adjusted_idx]; } /*! @@ -754,7 +764,14 @@ class TensorView { * \param idx The index of the size. * \return The size of the idx-th dimension. */ - int64_t size(int64_t idx) const { return tensor_.shape[idx >= 0 ? idx : tensor_.ndim + idx]; } + int64_t size(int64_t idx) const { + int64_t adjusted_idx = idx >= 0 ? idx : (tensor_.ndim + idx); + if (adjusted_idx < 0 || adjusted_idx >= tensor_.ndim) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " + << tensor_.ndim << " dimensions"; + } + return tensor_.shape[adjusted_idx]; + } /*! * \brief Get the stride of the idx-th dimension. If the idx is negative, @@ -762,7 +779,14 @@ class TensorView { * \param idx The index of the stride. * \return The stride of the idx-th dimension. */ - int64_t stride(int64_t idx) const { return tensor_.strides[idx >= 0 ? idx : tensor_.ndim + idx]; } + int64_t stride(int64_t idx) const { + int64_t adjusted_idx = idx >= 0 ? idx : (tensor_.ndim + idx); + if (adjusted_idx < 0 || adjusted_idx >= tensor_.ndim) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " + << tensor_.ndim << " dimensions"; + } + return tensor_.strides[adjusted_idx]; + } /*! * \brief Get the byte offset of the Tensor. diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc index b5c82bc2..1b14e6fb 100644 --- a/tests/cpp/test_tensor.cc +++ b/tests/cpp/test_tensor.cc @@ -365,4 +365,18 @@ TEST(Tensor, AsStrided) { EXPECT_EQ(offset_data2[0 * 3 + 0 * 1], 2.0f); // Points to data[2] } +TEST(Tensor, SizeStrideOutOfBounds) { + Tensor tensor = Empty({2, 3, 4}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + EXPECT_THROW({ tensor.size(3); }, tvm::ffi::Error); + EXPECT_THROW({ tensor.size(-4); }, tvm::ffi::Error); + EXPECT_THROW({ tensor.stride(3); }, tvm::ffi::Error); + EXPECT_THROW({ tensor.stride(-4); }, tvm::ffi::Error); + + TensorView tensor_view = tensor; + EXPECT_THROW({ tensor_view.size(3); }, tvm::ffi::Error); + EXPECT_THROW({ tensor_view.size(-4); }, tvm::ffi::Error); + EXPECT_THROW({ tensor_view.stride(3); }, tvm::ffi::Error); + EXPECT_THROW({ tensor_view.stride(-4); }, tvm::ffi::Error); +} + } // namespace diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py index 9c938a8c..d091d859 100644 --- a/tests/python/test_tensor.py +++ b/tests/python/test_tensor.py @@ -20,6 +20,7 @@ from types import ModuleType from typing import Any, NamedTuple +import numpy.typing as npt import pytest torch: ModuleType | None @@ -33,7 +34,7 @@ def test_tensor_attributes() -> None: - data = np.zeros((10, 8, 4, 2), dtype="int16") + data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int16") if not hasattr(data, "__dlpack__"): return x = tvm_ffi.from_dlpack(data) @@ -84,7 +85,7 @@ class MyTensor(tvm_ffi.Tensor): old_tensor = tvm_ffi.core._CLASS_TENSOR tvm_ffi.core._set_class_tensor(MyTensor) - data = np.zeros((10, 8, 4, 2), dtype="int16") + data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int16") if not hasattr(data, "__dlpack__"): return x = tvm_ffi.from_dlpack(data) @@ -105,7 +106,7 @@ def __tvm_ffi_object__(self) -> tvm_ffi.Tensor: """Implement __tvm_ffi_object__ protocol.""" return self._tensor - data = np.zeros((10, 8, 4, 2), dtype="int32") + data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int32") if not hasattr(data, "__dlpack__"): return x = tvm_ffi.from_dlpack(data) @@ -159,7 +160,7 @@ def test_optional_tensor_view() -> None: "testing.optional_tensor_view_has_value" ) assert not optional_tensor_view_has_value(None) - x = np.zeros((128,), dtype="float32") + x: npt.NDArray[Any] = np.zeros((128,), dtype="float32") if not hasattr(x, "__dlpack__"): return assert optional_tensor_view_has_value(x)