Skip to content

Commit 1df6c32

Browse files
authored
Fix infinite recursion, overflow, and off-by-one error in triu/tril (#1418)
* Fixes infinite recursion and off-by-one error * Avoids overflow using saturating arithmetic * Removes unused import * Fixes bug for isize::MAX for triu * Fix formatting * Uses broadcast indices to remove D::Smaller: Copy trait bound
1 parent f563af0 commit 1df6c32

File tree

1 file changed

+129
-54
lines changed

1 file changed

+129
-54
lines changed

src/tri.rs

+129-54
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use core::cmp::{max, min};
9+
use core::cmp::min;
1010

1111
use num_traits::Zero;
1212

13-
use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip};
13+
use crate::{
14+
dimension::{is_layout_c, is_layout_f},
15+
Array,
16+
ArrayBase,
17+
Axis,
18+
Data,
19+
Dimension,
20+
Zip,
21+
};
1422

1523
impl<S, A, D> ArrayBase<S, D>
1624
where
1725
S: Data<Elem = A>,
1826
D: Dimension,
1927
A: Clone + Zero,
20-
D::Smaller: Copy,
2128
{
2229
/// Upper triangular of an array.
2330
///
@@ -30,38 +37,56 @@ where
3037
/// ```
3138
/// use ndarray::array;
3239
///
33-
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
34-
/// let res = arr.triu(0);
35-
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
40+
/// let arr = array![
41+
/// [1, 2, 3],
42+
/// [4, 5, 6],
43+
/// [7, 8, 9]
44+
/// ];
45+
/// assert_eq!(
46+
/// arr.triu(0),
47+
/// array![
48+
/// [1, 2, 3],
49+
/// [0, 5, 6],
50+
/// [0, 0, 9]
51+
/// ]
52+
/// );
3653
/// ```
3754
pub fn triu(&self, k: isize) -> Array<A, D>
3855
{
3956
if self.ndim() <= 1 {
4057
return self.to_owned();
4158
}
42-
match is_layout_f(&self.dim, &self.strides) {
43-
true => {
44-
let n = self.ndim();
45-
let mut x = self.view();
46-
x.swap_axes(n - 2, n - 1);
47-
let mut tril = x.tril(-k);
48-
tril.swap_axes(n - 2, n - 1);
49-
50-
tril
51-
}
52-
false => {
53-
let mut res = Array::zeros(self.raw_dim());
54-
Zip::indexed(self.rows())
55-
.and(res.rows_mut())
56-
.for_each(|i, src, mut dst| {
57-
let row_num = i.into_dimension().last_elem();
58-
let lower = max(row_num as isize + k, 0);
59-
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
60-
});
61-
62-
res
63-
}
59+
60+
// Performance optimization for F-order arrays.
61+
// C-order array check prevents infinite recursion in edge cases like [[1]].
62+
// k-size check prevents underflow when k == isize::MIN
63+
let n = self.ndim();
64+
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
65+
let mut x = self.view();
66+
x.swap_axes(n - 2, n - 1);
67+
let mut tril = x.tril(-k);
68+
tril.swap_axes(n - 2, n - 1);
69+
70+
return tril;
6471
}
72+
73+
let mut res = Array::zeros(self.raw_dim());
74+
let ncols = self.len_of(Axis(n - 1));
75+
let nrows = self.len_of(Axis(n - 2));
76+
let indices = Array::from_iter(0..nrows);
77+
Zip::from(self.rows())
78+
.and(res.rows_mut())
79+
.and_broadcast(&indices)
80+
.for_each(|src, mut dst, row_num| {
81+
let mut lower = match k >= 0 {
82+
true => row_num.saturating_add(k as usize), // Avoid overflow
83+
false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
84+
};
85+
lower = min(lower, ncols);
86+
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
87+
});
88+
89+
res
6590
}
6691

6792
/// Lower triangular of an array.
@@ -75,45 +100,65 @@ where
75100
/// ```
76101
/// use ndarray::array;
77102
///
78-
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
79-
/// let res = arr.tril(0);
80-
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
103+
/// let arr = array![
104+
/// [1, 2, 3],
105+
/// [4, 5, 6],
106+
/// [7, 8, 9]
107+
/// ];
108+
/// assert_eq!(
109+
/// arr.tril(0),
110+
/// array![
111+
/// [1, 0, 0],
112+
/// [4, 5, 0],
113+
/// [7, 8, 9]
114+
/// ]
115+
/// );
81116
/// ```
82117
pub fn tril(&self, k: isize) -> Array<A, D>
83118
{
84119
if self.ndim() <= 1 {
85120
return self.to_owned();
86121
}
87-
match is_layout_f(&self.dim, &self.strides) {
88-
true => {
89-
let n = self.ndim();
90-
let mut x = self.view();
91-
x.swap_axes(n - 2, n - 1);
92-
let mut tril = x.triu(-k);
93-
tril.swap_axes(n - 2, n - 1);
94-
95-
tril
96-
}
97-
false => {
98-
let mut res = Array::zeros(self.raw_dim());
99-
let ncols = self.len_of(Axis(self.ndim() - 1)) as isize;
100-
Zip::indexed(self.rows())
101-
.and(res.rows_mut())
102-
.for_each(|i, src, mut dst| {
103-
let row_num = i.into_dimension().last_elem();
104-
let upper = min(row_num as isize + k, ncols) + 1;
105-
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
106-
});
107-
108-
res
109-
}
122+
123+
// Performance optimization for F-order arrays.
124+
// C-order array check prevents infinite recursion in edge cases like [[1]].
125+
// k-size check prevents underflow when k == isize::MIN
126+
let n = self.ndim();
127+
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
128+
let mut x = self.view();
129+
x.swap_axes(n - 2, n - 1);
130+
let mut tril = x.triu(-k);
131+
tril.swap_axes(n - 2, n - 1);
132+
133+
return tril;
110134
}
135+
136+
let mut res = Array::zeros(self.raw_dim());
137+
let ncols = self.len_of(Axis(n - 1));
138+
let nrows = self.len_of(Axis(n - 2));
139+
let indices = Array::from_iter(0..nrows);
140+
Zip::from(self.rows())
141+
.and(res.rows_mut())
142+
.and_broadcast(&indices)
143+
.for_each(|src, mut dst, row_num| {
144+
// let row_num = i.into_dimension().last_elem();
145+
let mut upper = match k >= 0 {
146+
true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
147+
false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow
148+
};
149+
upper = min(upper, ncols);
150+
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
151+
});
152+
153+
res
111154
}
112155
}
113156

114157
#[cfg(test)]
115158
mod tests
116159
{
160+
use core::isize;
161+
117162
use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
118163
use alloc::vec;
119164

@@ -188,6 +233,19 @@ mod tests
188233
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
189234
}
190235

236+
#[test]
237+
fn test_2d_single()
238+
{
239+
let x = array![[1]];
240+
241+
assert_eq!(x.triu(0), array![[1]]);
242+
assert_eq!(x.tril(0), array![[1]]);
243+
assert_eq!(x.triu(1), array![[0]]);
244+
assert_eq!(x.tril(1), array![[1]]);
245+
assert_eq!(x.triu(-1), array![[1]]);
246+
assert_eq!(x.tril(-1), array![[0]]);
247+
}
248+
191249
#[test]
192250
fn test_3d()
193251
{
@@ -285,8 +343,25 @@ mod tests
285343
let res = x.triu(0);
286344
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);
287345

346+
let res = x.tril(0);
347+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);
348+
288349
let x = array![[1, 2], [3, 4], [5, 6]];
289350
let res = x.triu(0);
290351
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
352+
353+
let res = x.tril(0);
354+
assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
355+
}
356+
357+
#[test]
358+
fn test_odd_k()
359+
{
360+
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
361+
let z = Array2::zeros([3, 3]);
362+
assert_eq!(x.triu(isize::MIN), x);
363+
assert_eq!(x.tril(isize::MIN), z);
364+
assert_eq!(x.triu(isize::MAX), z);
365+
assert_eq!(x.tril(isize::MAX), x);
291366
}
292367
}

0 commit comments

Comments
 (0)