32 changes: 28 additions & 4 deletions include/tvm/ffi/container/tensor.h
@@ -325,7 +325,12 @@ class Tensor : public ObjectRef {
*/
int64_t size(int64_t idx) const {
const TensorObj* ptr = get();
return ptr->shape[idx >= 0 ? idx : (ptr->ndim + idx)];
int64_t adjusted_idx = idx >= 0 ? idx : (ptr->ndim + idx);
if (adjusted_idx < 0 || adjusted_idx >= ptr->ndim) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " << ptr->ndim
<< " dimensions";
}
return ptr->shape[adjusted_idx];
}
Comment on lines 326 to 334
Severity: medium

The index adjustment and bounds checking logic is duplicated in Tensor::size, Tensor::stride, TensorView::size, and TensorView::stride. To improve maintainability and reduce code duplication, consider extracting this logic into a common helper function.

For example, you could add a free helper function in a details namespace:

namespace tvm {
namespace ffi {
namespace details {
inline int64_t CheckAndAdjustTensorIndex(int64_t idx, int32_t ndim) {
    int64_t adjusted_idx = idx >= 0 ? idx : (ndim + idx);
    if (adjusted_idx < 0 || adjusted_idx >= ndim) {
      TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " << ndim
                                << " dimensions";
    }
    return adjusted_idx;
}
} // namespace details
} // namespace ffi
} // namespace tvm

Then Tensor::size and TensorView::size could be simplified to:

// Tensor::size
int64_t size(int64_t idx) const {
  const TensorObj* ptr = get();
  return ptr->shape[details::CheckAndAdjustTensorIndex(idx, ptr->ndim)];
}

// TensorView::size
int64_t size(int64_t idx) const {
  return tensor_.shape[details::CheckAndAdjustTensorIndex(idx, tensor_.ndim)];
}

A similar simplification would apply to the stride methods.
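
For instance, a sketch of the corresponding stride methods, mirroring the size examples above and assuming the same CheckAndAdjustTensorIndex helper:

// Tensor::stride
int64_t stride(int64_t idx) const {
  const TensorObj* ptr = get();
  return ptr->strides[details::CheckAndAdjustTensorIndex(idx, ptr->ndim)];
}

// TensorView::stride
int64_t stride(int64_t idx) const {
  return tensor_.strides[details::CheckAndAdjustTensorIndex(idx, tensor_.ndim)];
}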


/*!
@@ -336,7 +341,12 @@ class Tensor : public ObjectRef {
*/
int64_t stride(int64_t idx) const {
const TensorObj* ptr = get();
return ptr->strides[idx >= 0 ? idx : (ptr->ndim + idx)];
int64_t adjusted_idx = idx >= 0 ? idx : (ptr->ndim + idx);
if (adjusted_idx < 0 || adjusted_idx >= ptr->ndim) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with " << ptr->ndim
<< " dimensions";
}
return ptr->strides[adjusted_idx];
}

/*!
@@ -754,15 +764,29 @@ class TensorView {
* \param idx The index of the size.
* \return The size of the idx-th dimension.
*/
int64_t size(int64_t idx) const { return tensor_.shape[idx >= 0 ? idx : tensor_.ndim + idx]; }
int64_t size(int64_t idx) const {
int64_t adjusted_idx = idx >= 0 ? idx : (tensor_.ndim + idx);
if (adjusted_idx < 0 || adjusted_idx >= tensor_.ndim) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with "
<< tensor_.ndim << " dimensions";
}
return tensor_.shape[adjusted_idx];
}

/*!
* \brief Get the stride of the idx-th dimension. If the idx is negative,
* it gets the stride of last idx-th dimension.
* \param idx The index of the stride.
* \return The stride of the idx-th dimension.
*/
int64_t stride(int64_t idx) const { return tensor_.strides[idx >= 0 ? idx : tensor_.ndim + idx]; }
int64_t stride(int64_t idx) const {
int64_t adjusted_idx = idx >= 0 ? idx : (tensor_.ndim + idx);
if (adjusted_idx < 0 || adjusted_idx >= tensor_.ndim) {
TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds for tensor with "
<< tensor_.ndim << " dimensions";
}
return tensor_.strides[adjusted_idx];
}

/*!
* \brief Get the byte offset of the Tensor.
14 changes: 14 additions & 0 deletions tests/cpp/test_tensor.cc
@@ -365,4 +365,18 @@ TEST(Tensor, AsStrided) {
EXPECT_EQ(offset_data2[0 * 3 + 0 * 1], 2.0f); // Points to data[2]
}

TEST(Tensor, SizeStrideOutOfBounds) {
Tensor tensor = Empty({2, 3, 4}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
EXPECT_THROW({ tensor.size(3); }, tvm::ffi::Error);
EXPECT_THROW({ tensor.size(-4); }, tvm::ffi::Error);
EXPECT_THROW({ tensor.stride(3); }, tvm::ffi::Error);
EXPECT_THROW({ tensor.stride(-4); }, tvm::ffi::Error);

TensorView tensor_view = tensor;
EXPECT_THROW({ tensor_view.size(3); }, tvm::ffi::Error);
EXPECT_THROW({ tensor_view.size(-4); }, tvm::ffi::Error);
EXPECT_THROW({ tensor_view.stride(3); }, tvm::ffi::Error);
EXPECT_THROW({ tensor_view.stride(-4); }, tvm::ffi::Error);
}

} // namespace
9 changes: 5 additions & 4 deletions tests/python/test_tensor.py
@@ -20,6 +20,7 @@
from types import ModuleType
from typing import Any, NamedTuple

import numpy.typing as npt
import pytest

torch: ModuleType | None
@@ -33,7 +34,7 @@


def test_tensor_attributes() -> None:
data = np.zeros((10, 8, 4, 2), dtype="int16")
data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int16")
if not hasattr(data, "__dlpack__"):
return
x = tvm_ffi.from_dlpack(data)
@@ -84,7 +85,7 @@ class MyTensor(tvm_ffi.Tensor):
old_tensor = tvm_ffi.core._CLASS_TENSOR
tvm_ffi.core._set_class_tensor(MyTensor)

data = np.zeros((10, 8, 4, 2), dtype="int16")
data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int16")
if not hasattr(data, "__dlpack__"):
return
x = tvm_ffi.from_dlpack(data)
@@ -105,7 +106,7 @@ def __tvm_ffi_object__(self) -> tvm_ffi.Tensor:
"""Implement __tvm_ffi_object__ protocol."""
return self._tensor

data = np.zeros((10, 8, 4, 2), dtype="int32")
data: npt.NDArray[Any] = np.zeros((10, 8, 4, 2), dtype="int32")
if not hasattr(data, "__dlpack__"):
return
x = tvm_ffi.from_dlpack(data)
@@ -159,7 +160,7 @@ def test_optional_tensor_view() -> None:
"testing.optional_tensor_view_has_value"
)
assert not optional_tensor_view_has_value(None)
x = np.zeros((128,), dtype="float32")
x: npt.NDArray[Any] = np.zeros((128,), dtype="float32")
if not hasattr(x, "__dlpack__"):
return
assert optional_tensor_view_has_value(x)