From e9ec104f627112f34b7354be4576950b35d3c264 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Tue, 2 Jan 2024 15:02:20 +0100 Subject: [PATCH 1/5] add tch-rs fork implement result indexing --- src/tensor/index.rs | 123 ++++++++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 50 deletions(-) diff --git a/src/tensor/index.rs b/src/tensor/index.rs index d8ad2fa5..ee921ac8 100644 --- a/src/tensor/index.rs +++ b/src/tensor/index.rs @@ -131,120 +131,147 @@ impl_from_range!(RangeInclusive); impl_from_range!(RangeTo); impl_from_range!(RangeToInclusive); +type IndexResult = Result; + pub trait IndexOp { fn i(&self, index: T) -> Tensor; + fn f_i(&self, index: T) -> IndexResult; } impl IndexOp for Tensor -where - A: Into, + where + A: Into, { fn i(&self, index: A) -> Tensor { self.indexer(&[index.into()]) } + fn f_i(&self, index: A) -> IndexResult { + self.f_indexer(&[index.into()]) + } } -impl IndexOp<(A,)> for Tensor -where - A: Into, +impl IndexOp<(A, )> for Tensor + where + A: Into, { - fn i(&self, index: (A,)) -> Tensor { + fn i(&self, index: (A, )) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, )) -> IndexResult { let idx_a = index.0.into(); - self.indexer(&[idx_a]) + self.f_indexer(&[idx_a]) } } impl IndexOp<(A, B)> for Tensor -where - A: Into, - B: Into, + where + A: Into, + B: Into, { fn i(&self, index: (A, B)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); - self.indexer(&[idx_a, idx_b]) + self.f_indexer(&[idx_a, idx_b]) } } impl IndexOp<(A, B, C)> for Tensor -where - A: Into, - B: Into, - C: Into, + where + A: Into, + B: Into, + C: Into, { fn i(&self, index: (A, B, C)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B, C)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); - self.indexer(&[idx_a, idx_b, idx_c]) + self.f_indexer(&[idx_a, idx_b, idx_c]) } } impl IndexOp<(A, B, C, D)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, + where + A: Into, + B: Into, + C: Into, + D: Into, { fn i(&self, index: (A, B, C, D)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B, C, D)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d]) } } impl IndexOp<(A, B, C, D, E)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, + where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, { fn i(&self, index: (A, B, C, D, E)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B, C, D, E)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); let idx_e = index.4.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e]) } } impl IndexOp<(A, B, C, D, E, F)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, + where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, + F: Into, { fn i(&self, index: (A, B, C, D, E, F)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B, C, D, E, F)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); let idx_d = index.3.into(); let idx_e = index.4.into(); let idx_f = index.5.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f]) } } impl IndexOp<(A, B, C, D, E, F, G)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, - G: Into, + where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, + F: Into, + G: Into, { fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor { + self.f_i(index).unwrap() + } + fn f_i(&self, index: (A, B, C, D, E, F, G)) -> IndexResult { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -252,7 +279,7 @@ where let idx_e = index.4.into(); let idx_f = index.5.into(); let idx_g = index.6.into(); - self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g]) + self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g]) } } @@ -338,8 +365,4 @@ impl Tensor { Ok(curr_tensor) } - - fn indexer(&self, index_spec: &[TensorIndexer]) -> Tensor { - self.f_indexer(index_spec).unwrap() - } } From f63558d7d911557783681aab454875a3c88af1c4 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Tue, 2 Jan 2024 15:17:42 +0100 Subject: [PATCH 2/5] remove unwraps --- src/tensor/index.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tensor/index.rs b/src/tensor/index.rs index ee921ac8..6e714074 100644 --- a/src/tensor/index.rs +++ b/src/tensor/index.rs @@ -131,7 +131,7 @@ impl_from_range!(RangeInclusive); impl_from_range!(RangeTo); impl_from_range!(RangeToInclusive); -type IndexResult = Result; +pub type IndexResult = Result; pub trait IndexOp { fn i(&self, index: T) -> Tensor; @@ -143,7 +143,7 @@ impl IndexOp for Tensor A: Into, { fn i(&self, index: A) -> Tensor { - self.indexer(&[index.into()]) + self.f_i(index).unwrap() } fn f_i(&self, index: A) -> IndexResult { self.f_indexer(&[index.into()]) @@ -348,7 +348,7 @@ impl Tensor { (Excluded(start), Included(end)) => Some((*start + 1, *end - *start)), (Excluded(start), Excluded(end)) => Some((*start + 1, *end - *start - 1)), } { - (curr_tensor.narrow(curr_idx, start, length.max(0)), curr_idx + 1) + (curr_tensor.f_narrow(curr_idx, start, length.max(0))?, curr_idx + 1) } else { (curr_tensor, curr_idx + 1) } From bf3afe0a0aa4a44cc3aa758930e252b61b4feced Mon Sep 17 00:00:00 2001 From: David Schwab Date: Tue, 2 Jan 2024 15:36:44 +0100 Subject: [PATCH 3/5] fmt --- src/tensor/index.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensor/index.rs b/src/tensor/index.rs index 6e714074..9ec0e7a8 100644 --- a/src/tensor/index.rs +++ b/src/tensor/index.rs @@ -358,11 +358,9 @@ impl Tensor { (curr_tensor.index_select(curr_idx, &index_tensor), curr_idx + 1) } }; - curr_tensor = next_tensor; curr_idx = next_idx; } - Ok(curr_tensor) } } From 12b97a6231a826e5ff034ce0ec31bd391ee73df0 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 6 Jan 2024 11:16:08 +0100 Subject: [PATCH 4/5] Fix rustfmt + minor tweaks. --- src/tensor/index.rs | 110 +++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/src/tensor/index.rs b/src/tensor/index.rs index 9ec0e7a8..2bed7772 100644 --- a/src/tensor/index.rs +++ b/src/tensor/index.rs @@ -52,7 +52,7 @@ //! shape mismatch error due to advanced indexing rule. Another distinction //! is that `i` guarantees the input and result tensor shares the same //! underlying storage, while NumPy may copy the tensor in certain scenarios. -use crate::{TchError, Tensor}; +use crate::{Result, TchError, Tensor}; use std::ops::{ Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, }; @@ -131,47 +131,48 @@ impl_from_range!(RangeInclusive); impl_from_range!(RangeTo); impl_from_range!(RangeToInclusive); -pub type IndexResult = Result; - pub trait IndexOp { fn i(&self, index: T) -> Tensor; - fn f_i(&self, index: T) -> IndexResult; + fn f_i(&self, index: T) -> Result; } impl IndexOp for Tensor - where - A: Into, +where + A: Into, { fn i(&self, index: A) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: A) -> IndexResult { + + fn f_i(&self, index: A) -> Result { self.f_indexer(&[index.into()]) } } -impl IndexOp<(A, )> for Tensor - where - A: Into, +impl IndexOp<(A,)> for Tensor +where + A: Into, { - fn i(&self, index: (A, )) -> Tensor { + fn i(&self, index: (A,)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, )) -> IndexResult { + + fn f_i(&self, index: (A,)) -> Result { let idx_a = index.0.into(); self.f_indexer(&[idx_a]) } } impl IndexOp<(A, B)> for Tensor - where - A: Into, - B: Into, +where + A: Into, + B: Into, { fn i(&self, index: (A, B)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B)) -> IndexResult { + + fn f_i(&self, index: (A, B)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); self.f_indexer(&[idx_a, idx_b]) @@ -179,15 +180,16 @@ impl IndexOp<(A, B)> for Tensor } impl IndexOp<(A, B, C)> for Tensor - where - A: Into, - B: Into, - C: Into, +where + A: Into, + B: Into, + C: Into, { fn i(&self, index: (A, B, C)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B, C)) -> IndexResult { + + fn f_i(&self, index: (A, B, C)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -196,16 +198,17 @@ impl IndexOp<(A, B, C)> for Tensor } impl IndexOp<(A, B, C, D)> for Tensor - where - A: Into, - B: Into, - C: Into, - D: Into, +where + A: Into, + B: Into, + C: Into, + D: Into, { fn i(&self, index: (A, B, C, D)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B, C, D)) -> IndexResult { + + fn f_i(&self, index: (A, B, C, D)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -215,17 +218,18 @@ impl IndexOp<(A, B, C, D)> for Tensor } impl IndexOp<(A, B, C, D, E)> for Tensor - where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, +where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, { fn i(&self, index: (A, B, C, D, E)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B, C, D, E)) -> IndexResult { + + fn f_i(&self, index: (A, B, C, D, E)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -236,18 +240,19 @@ impl IndexOp<(A, B, C, D, E)> for Tensor } impl IndexOp<(A, B, C, D, E, F)> for Tensor - where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, +where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, + F: Into, { fn i(&self, index: (A, B, C, D, E, F)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B, C, D, E, F)) -> IndexResult { + + fn f_i(&self, index: (A, B, C, D, E, F)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -259,19 +264,20 @@ impl IndexOp<(A, B, C, D, E, F)> for Tensor } impl IndexOp<(A, B, C, D, E, F, G)> for Tensor - where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, - G: Into, +where + A: Into, + B: Into, + C: Into, + D: Into, + E: Into, + F: Into, + G: Into, { fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor { self.f_i(index).unwrap() } - fn f_i(&self, index: (A, B, C, D, E, F, G)) -> IndexResult { + + fn f_i(&self, index: (A, B, C, D, E, F, G)) -> Result { let idx_a = index.0.into(); let idx_b = index.1.into(); let idx_c = index.2.into(); @@ -284,7 +290,7 @@ impl IndexOp<(A, B, C, D, E, F, G)> for Tensor } impl Tensor { - fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result { + fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result { use std::ops::Bound::*; use TensorIndexer::*; From 7f41f93e056a6aa3ffb7055f2b310a551dd3acfc Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 6 Jan 2024 11:22:04 +0100 Subject: [PATCH 5/5] Clippy fix. --- tests/tensor_tests.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 1da8b1db..fc8e585a 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::unnecessary_fallible_conversions)] use anyhow::Result; use half::f16; use std::convert::{TryFrom, TryInto};