6
6
// option. This file may not be copied, modified, or distributed
7
7
// except according to those terms.
8
8
9
- use core:: cmp:: { max , min} ;
9
+ use core:: cmp:: min;
10
10
11
11
use num_traits:: Zero ;
12
12
13
- use crate :: { dimension:: is_layout_f, Array , ArrayBase , Axis , Data , Dimension , IntoDimension , Zip } ;
13
+ use crate :: {
14
+ dimension:: { is_layout_c, is_layout_f} ,
15
+ Array ,
16
+ ArrayBase ,
17
+ Axis ,
18
+ Data ,
19
+ Dimension ,
20
+ Zip ,
21
+ } ;
14
22
15
23
impl < S , A , D > ArrayBase < S , D >
16
24
where
17
25
S : Data < Elem = A > ,
18
26
D : Dimension ,
19
27
A : Clone + Zero ,
20
- D :: Smaller : Copy ,
21
28
{
22
29
/// Upper triangular of an array.
23
30
///
@@ -30,38 +37,56 @@ where
30
37
/// ```
31
38
/// use ndarray::array;
32
39
///
33
- /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
34
- /// let res = arr.triu(0);
35
- /// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
40
+ /// let arr = array![
41
+ /// [1, 2, 3],
42
+ /// [4, 5, 6],
43
+ /// [7, 8, 9]
44
+ /// ];
45
+ /// assert_eq!(
46
+ /// arr.triu(0),
47
+ /// array![
48
+ /// [1, 2, 3],
49
+ /// [0, 5, 6],
50
+ /// [0, 0, 9]
51
+ /// ]
52
+ /// );
36
53
/// ```
37
54
pub fn triu ( & self , k : isize ) -> Array < A , D >
38
55
{
39
56
if self . ndim ( ) <= 1 {
40
57
return self . to_owned ( ) ;
41
58
}
42
- match is_layout_f ( & self . dim , & self . strides ) {
43
- true => {
44
- let n = self . ndim ( ) ;
45
- let mut x = self . view ( ) ;
46
- x. swap_axes ( n - 2 , n - 1 ) ;
47
- let mut tril = x. tril ( -k) ;
48
- tril. swap_axes ( n - 2 , n - 1 ) ;
49
-
50
- tril
51
- }
52
- false => {
53
- let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
54
- Zip :: indexed ( self . rows ( ) )
55
- . and ( res. rows_mut ( ) )
56
- . for_each ( |i, src, mut dst| {
57
- let row_num = i. into_dimension ( ) . last_elem ( ) ;
58
- let lower = max ( row_num as isize + k, 0 ) ;
59
- dst. slice_mut ( s ! [ lower..] ) . assign ( & src. slice ( s ! [ lower..] ) ) ;
60
- } ) ;
61
-
62
- res
63
- }
59
+
60
+ // Performance optimization for F-order arrays.
61
+ // C-order array check prevents infinite recursion in edge cases like [[1]].
62
+ // k-size check prevents underflow when k == isize::MIN
63
+ let n = self . ndim ( ) ;
64
+ if is_layout_f ( & self . dim , & self . strides ) && !is_layout_c ( & self . dim , & self . strides ) && k > isize:: MIN {
65
+ let mut x = self . view ( ) ;
66
+ x. swap_axes ( n - 2 , n - 1 ) ;
67
+ let mut tril = x. tril ( -k) ;
68
+ tril. swap_axes ( n - 2 , n - 1 ) ;
69
+
70
+ return tril;
64
71
}
72
+
73
+ let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
74
+ let ncols = self . len_of ( Axis ( n - 1 ) ) ;
75
+ let nrows = self . len_of ( Axis ( n - 2 ) ) ;
76
+ let indices = Array :: from_iter ( 0 ..nrows) ;
77
+ Zip :: from ( self . rows ( ) )
78
+ . and ( res. rows_mut ( ) )
79
+ . and_broadcast ( & indices)
80
+ . for_each ( |src, mut dst, row_num| {
81
+ let mut lower = match k >= 0 {
82
+ true => row_num. saturating_add ( k as usize ) , // Avoid overflow
83
+ false => row_num. saturating_sub ( k. unsigned_abs ( ) ) , // Avoid underflow, go to 0
84
+ } ;
85
+ lower = min ( lower, ncols) ;
86
+ dst. slice_mut ( s ! [ lower..] ) . assign ( & src. slice ( s ! [ lower..] ) ) ;
87
+ } ) ;
88
+
89
+ res
65
90
}
66
91
67
92
/// Lower triangular of an array.
@@ -75,45 +100,65 @@ where
75
100
/// ```
76
101
/// use ndarray::array;
77
102
///
78
- /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
79
- /// let res = arr.tril(0);
80
- /// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
103
+ /// let arr = array![
104
+ /// [1, 2, 3],
105
+ /// [4, 5, 6],
106
+ /// [7, 8, 9]
107
+ /// ];
108
+ /// assert_eq!(
109
+ /// arr.tril(0),
110
+ /// array![
111
+ /// [1, 0, 0],
112
+ /// [4, 5, 0],
113
+ /// [7, 8, 9]
114
+ /// ]
115
+ /// );
81
116
/// ```
82
117
pub fn tril ( & self , k : isize ) -> Array < A , D >
83
118
{
84
119
if self . ndim ( ) <= 1 {
85
120
return self . to_owned ( ) ;
86
121
}
87
- match is_layout_f ( & self . dim , & self . strides ) {
88
- true => {
89
- let n = self . ndim ( ) ;
90
- let mut x = self . view ( ) ;
91
- x. swap_axes ( n - 2 , n - 1 ) ;
92
- let mut tril = x. triu ( -k) ;
93
- tril. swap_axes ( n - 2 , n - 1 ) ;
94
-
95
- tril
96
- }
97
- false => {
98
- let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
99
- let ncols = self . len_of ( Axis ( self . ndim ( ) - 1 ) ) as isize ;
100
- Zip :: indexed ( self . rows ( ) )
101
- . and ( res. rows_mut ( ) )
102
- . for_each ( |i, src, mut dst| {
103
- let row_num = i. into_dimension ( ) . last_elem ( ) ;
104
- let upper = min ( row_num as isize + k, ncols) + 1 ;
105
- dst. slice_mut ( s ! [ ..upper] ) . assign ( & src. slice ( s ! [ ..upper] ) ) ;
106
- } ) ;
107
-
108
- res
109
- }
122
+
123
+ // Performance optimization for F-order arrays.
124
+ // C-order array check prevents infinite recursion in edge cases like [[1]].
125
+ // k-size check prevents underflow when k == isize::MIN
126
+ let n = self . ndim ( ) ;
127
+ if is_layout_f ( & self . dim , & self . strides ) && !is_layout_c ( & self . dim , & self . strides ) && k > isize:: MIN {
128
+ let mut x = self . view ( ) ;
129
+ x. swap_axes ( n - 2 , n - 1 ) ;
130
+ let mut tril = x. triu ( -k) ;
131
+ tril. swap_axes ( n - 2 , n - 1 ) ;
132
+
133
+ return tril;
110
134
}
135
+
136
+ let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
137
+ let ncols = self . len_of ( Axis ( n - 1 ) ) ;
138
+ let nrows = self . len_of ( Axis ( n - 2 ) ) ;
139
+ let indices = Array :: from_iter ( 0 ..nrows) ;
140
+ Zip :: from ( self . rows ( ) )
141
+ . and ( res. rows_mut ( ) )
142
+ . and_broadcast ( & indices)
143
+ . for_each ( |src, mut dst, row_num| {
144
+ // let row_num = i.into_dimension().last_elem();
145
+ let mut upper = match k >= 0 {
146
+ true => row_num. saturating_add ( k as usize ) . saturating_add ( 1 ) , // Avoid overflow
147
+ false => row_num. saturating_sub ( ( k + 1 ) . unsigned_abs ( ) ) , // Avoid underflow
148
+ } ;
149
+ upper = min ( upper, ncols) ;
150
+ dst. slice_mut ( s ! [ ..upper] ) . assign ( & src. slice ( s ! [ ..upper] ) ) ;
151
+ } ) ;
152
+
153
+ res
111
154
}
112
155
}
113
156
114
157
#[ cfg( test) ]
115
158
mod tests
116
159
{
160
+ use core:: isize;
161
+
117
162
use crate :: { array, dimension, Array0 , Array1 , Array2 , Array3 , ShapeBuilder } ;
118
163
use alloc:: vec;
119
164
@@ -188,6 +233,19 @@ mod tests
188
233
assert_eq ! ( res, array![ [ 1 , 0 , 0 ] , [ 4 , 5 , 0 ] , [ 7 , 8 , 9 ] ] ) ;
189
234
}
190
235
236
+ #[ test]
237
+ fn test_2d_single ( )
238
+ {
239
+ let x = array ! [ [ 1 ] ] ;
240
+
241
+ assert_eq ! ( x. triu( 0 ) , array![ [ 1 ] ] ) ;
242
+ assert_eq ! ( x. tril( 0 ) , array![ [ 1 ] ] ) ;
243
+ assert_eq ! ( x. triu( 1 ) , array![ [ 0 ] ] ) ;
244
+ assert_eq ! ( x. tril( 1 ) , array![ [ 1 ] ] ) ;
245
+ assert_eq ! ( x. triu( -1 ) , array![ [ 1 ] ] ) ;
246
+ assert_eq ! ( x. tril( -1 ) , array![ [ 0 ] ] ) ;
247
+ }
248
+
191
249
#[ test]
192
250
fn test_3d ( )
193
251
{
@@ -285,8 +343,25 @@ mod tests
285
343
let res = x. triu ( 0 ) ;
286
344
assert_eq ! ( res, array![ [ 1 , 2 , 3 ] , [ 0 , 5 , 6 ] ] ) ;
287
345
346
+ let res = x. tril ( 0 ) ;
347
+ assert_eq ! ( res, array![ [ 1 , 0 , 0 ] , [ 4 , 5 , 0 ] ] ) ;
348
+
288
349
let x = array ! [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ;
289
350
let res = x. triu ( 0 ) ;
290
351
assert_eq ! ( res, array![ [ 1 , 2 ] , [ 0 , 4 ] , [ 0 , 0 ] ] ) ;
352
+
353
+ let res = x. tril ( 0 ) ;
354
+ assert_eq ! ( res, array![ [ 1 , 0 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ) ;
355
+ }
356
+
357
+ #[ test]
358
+ fn test_odd_k ( )
359
+ {
360
+ let x = array ! [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] ;
361
+ let z = Array2 :: zeros ( [ 3 , 3 ] ) ;
362
+ assert_eq ! ( x. triu( isize :: MIN ) , x) ;
363
+ assert_eq ! ( x. tril( isize :: MIN ) , z) ;
364
+ assert_eq ! ( x. triu( isize :: MAX ) , z) ;
365
+ assert_eq ! ( x. tril( isize :: MAX ) , x) ;
291
366
}
292
367
}
0 commit comments