Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make iterators covariant in element type #1417

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/impl_owned_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use alloc::vec::Vec;
use std::mem;
use std::mem::MaybeUninit;

#[allow(unused_imports)]
#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::imp_prelude::*;
Expand Down Expand Up @@ -907,7 +907,7 @@ where D: Dimension

// iter is a raw pointer iterator traversing the array in memory order now with the
// sorted axes.
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review note: it's evident that Baseiter is constructed from a NonNull everywhere, which means that its non-null requirement is easily fulfilled.

let mut dropped_elements = 0;

let mut last_ptr = data_ptr;
Expand Down
8 changes: 4 additions & 4 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -209,7 +209,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -220,7 +220,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down Expand Up @@ -262,7 +262,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down
5 changes: 2 additions & 3 deletions src/iterators/into_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
where D: Dimension
{
/// Create a new by-value iterator that consumes `array`
pub(crate) fn new(mut array: Array<A, D>) -> Self
pub(crate) fn new(array: Array<A, D>) -> Self
{
unsafe {
let array_head_ptr = array.ptr;
let ptr = array.as_mut_ptr();
let mut array_data = array.data;
let data_len = array_data.release_all_elements();
debug_assert!(data_len >= array.dim.size());
let has_unreachable_elements = array.dim.size() != data_len;
let inner = Baseiter::new(ptr, array.dim, array.strides);
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);

IntoIter {
array_data,
Expand Down
23 changes: 14 additions & 9 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ use alloc::vec::Vec;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr;
use std::ptr::NonNull;

#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::Ix1;

Expand All @@ -38,7 +42,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
#[derive(Debug)]
pub struct Baseiter<A, D>
{
ptr: *mut A,
ptr: NonNull<A>,
dim: D,
strides: D,
index: Option<D>,
Expand All @@ -50,7 +54,7 @@ impl<A, D: Dimension> Baseiter<A, D>
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
{
Baseiter {
ptr,
Expand All @@ -74,7 +78,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
};
let offset = D::stride_offset(&index, &self.strides);
self.index = self.dim.next_for(index);
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -99,7 +103,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
let mut i = 0;
let i_end = len - elem_index;
while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr());
i += 1;
}
}
Expand Down Expand Up @@ -140,12 +144,12 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
Some(ix) => ix,
};
self.dim[0] -= 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}

unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
}

fn nth_back(&mut self, n: usize) -> Option<*mut A>
Expand All @@ -154,11 +158,11 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
let len = self.dim[0] - index[0];
if n < len {
self.dim[0] -= n + 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}
unsafe { Some(self.ptr.offset(offset)) }
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
} else {
self.index = None;
None
Expand All @@ -178,7 +182,8 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
accum = g(
accum,
self.ptr
.offset(Ix1::stride_offset(&self.dim, &self.strides)),
.offset(Ix1::stride_offset(&self.dim, &self.strides))
.as_ptr(),
);
}
}
Expand Down
34 changes: 31 additions & 3 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#![allow(
clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names
)]
#![allow(clippy::deref_addrof, clippy::unreadable_literal)]

use ndarray::prelude::*;
use ndarray::{arr3, indices, s, Slice, Zip};
Expand Down Expand Up @@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_>
self.drops.set(self.drops.get() + 1);
}
}

#[test]
fn test_impl_iter_compiles()
{
// Requires that the iterators are covariant in the element type

// base case: std
fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = slice_iter_non_empty_indices;

// ndarray case
fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator<Item = usize> + 'a
{
array
.iter()
.enumerate()
.filter(|(_index, elem)| !elem.is_empty())
.map(|(index, _elem)| index)
}

let _ = array_iter_non_empty_indices;
}