Skip to content

Commit

Permalink
Uses broadcast indices to remove D::Smaller: Copy trait bound
Browse files Browse the repository at this point in the history
  • Loading branch information
akern40 committed Aug 8, 2024
1 parent 5ed317a commit bb7e02f
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/tri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::{
Axis,
Data,
Dimension,
IntoDimension,
Zip,
};

Expand All @@ -26,7 +25,6 @@ where
S: Data<Elem = A>,
D: Dimension,
A: Clone + Zero,
D::Smaller: Copy,
{
/// Upper triangular of an array.
///
Expand Down Expand Up @@ -74,10 +72,12 @@ where

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
Zip::indexed(self.rows())
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
let mut lower = match k >= 0 {
true => row_num.saturating_add(k as usize), // Avoid overflow
false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
Expand Down Expand Up @@ -135,10 +135,13 @@ where

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
Zip::indexed(self.rows())
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
// let row_num = i.into_dimension().last_elem();
let mut upper = match k >= 0 {
true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow
Expand Down

0 comments on commit bb7e02f

Please sign in to comment.