diff --git a/eagerpy/tensor/pytorch.py b/eagerpy/tensor/pytorch.py index 17f082f..97b8015 100644 --- a/eagerpy/tensor/pytorch.py +++ b/eagerpy/tensor/pytorch.py @@ -554,8 +554,11 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) - elif isinstance(index, Tensor): - index = index.raw + else: + if isinstance(index, Tensor): + index = index.raw + if isinstance(index, np.ndarray): + index = torch.as_tensor(index) return type(self)(self.raw[index]) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: