diff --git a/src/tensor/index.rs b/src/tensor/index.rs index d8ad2fa5..2bed7772 100644 --- a/src/tensor/index.rs +++ b/src/tensor/index.rs @@ -52,7 +52,7 @@ //! shape mismatch error due to advanced indexing rule. Another distinction //! is that `i` guarantees the input and result tensor shares the same //! underlying storage, while NumPy may copy the tensor in certain scenarios. -use crate::{TchError, Tensor}; +use crate::{Result, TchError, Tensor}; use std::ops::{ Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, }; @@ -133,6 +133,7 @@ impl_from_range!(RangeToInclusive); pub trait IndexOp { fn i(&self, index: T) -> Tensor; + fn f_i(&self, index: T) -> Result; } impl IndexOp for Tensor @@ -140,7 +141,11 @@ where A: Into, { fn i(&self, index: A) -> Tensor { - self.indexer(&[index.into()]) + self.f_i(index).unwrap() + } + + fn f_i(&self, index: A) -> Result { + self.f_indexer(&[index.into()]) } } @@ -149,8 +154,12 @@ where A: Into, { fn i(&self, index: (A,)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A,)) -> Result { let idx_a = index.0.into(); - self.indexer(&[idx_a]) + self.f_indexer(&[idx_a]) } } @@ -160,9 +169,13 @@ where B: Into, { fn i(&self, index: (A, B)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); - self.indexer(&[idx_a, idx_b]) + self.f_indexer(&[idx_a, idx_b]) } } @@ -173,10 +186,14 @@ where C: Into, { fn i(&self, index: (A, B, C)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B, C)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); - self.indexer(&[idx_a, idx_b, idx_c]) + self.f_indexer(&[idx_a, idx_b, idx_c]) } } @@ -188,11 +205,15 @@ where D: Into, { fn i(&self, index: (A, B, C, D)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B, C, D)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d]) } } @@ -205,12 +226,16 @@ where E: Into, { fn i(&self, index: (A, B, C, D, E)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B, C, D, E)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); let idx_e = index.4.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e]) } } @@ -224,13 +249,17 @@ where F: Into, { fn i(&self, index: (A, B, C, D, E, F)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B, C, D, E, F)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); let idx_e = index.4.into(); let idx_f = index.5.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f]) } } @@ -245,6 +274,10 @@ where G: Into, { fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor { + self.f_i(index).unwrap() + } + + fn f_i(&self, index: (A, B, C, D, E, F, G)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -252,12 +285,12 @@ where let idx_e = index.4.into(); let idx_f = index.5.into(); let idx_g = index.6.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g]) } } impl Tensor { - fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result { + fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result { use std::ops::Bound::*; use TensorIndexer::*; @@ -321,7 +354,7 @@ impl Tensor { (Excluded(start), Included(end)) => Some((*start + 1, *end - *start)), (Excluded(start), Excluded(end)) => Some((*start + 1, *end - *start - 1)), } { - (curr_tensor.narrow(curr_idx, start, length.max(0)), curr_idx + 1) + (curr_tensor.f_narrow(curr_idx, start, length.max(0))?, curr_idx + 1) } else { (curr_tensor, curr_idx + 1) } @@ -331,15 +364,9 @@ impl Tensor { (curr_tensor.index_select(curr_idx, &index_tensor), curr_idx + 1) } }; - curr_tensor = next_tensor; curr_idx = next_idx; } - Ok(curr_tensor) } - - fn indexer(&self, index_spec: &[TensorIndexer]) -> Tensor { - self.f_indexer(index_spec).unwrap() - } } diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 1da8b1db..fc8e585a 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::unnecessary_fallible_conversions)] use anyhow::Result; use half::f16; use std::convert::{TryFrom, TryInto};