303 changes: 297 additions & 6 deletions src/lib.rs
@@ -4,6 +4,8 @@

use core::fmt;
use core::fmt::Debug;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl,
ShlAssign, Shr, ShrAssign, Sub, SubAssign,
@@ -72,7 +74,8 @@ pub trait MemoryAddress:
+ BitXor<Output = Self::RAW>
+ Debug
+ From<u8>
+ TryInto<usize, Error: Debug>;
+ TryInto<usize, Error: Debug>
+ TryFrom<usize, Error: Debug>;

/// Get the raw underlying address value.
fn raw(self) -> Self::RAW;
@@ -124,7 +127,8 @@ impl<T: MemoryAddress> AddrRange<T> {
pub fn iter(&self) -> AddrIter<T> {
AddrIter {
current: self.start,
end: self.end,
end: Some(self.end),
_phantom: PhantomData,
}
}

@@ -144,21 +148,224 @@ impl<T: MemoryAddress> AddrRange<T> {
}

/// An iterator over a memory range
pub struct AddrIter<T: MemoryAddress> {
#[allow(private_bounds)]
pub struct AddrIter<T: MemoryAddress, I: IterInclusivity = NonInclusive> {
current: T,
end: T,
end: Option<T>, // `None` indicates that the iterator is exhausted
_phantom: PhantomData<I>,
}
impl<T: MemoryAddress> Iterator for AddrIter<T> {

trait IterInclusivity: 'static {
fn exhausted<T: Ord>(start: &T, end: &T) -> bool;
}
pub enum NonInclusive {}
Comment on lines +160 to +178 (Member):

Minor nitpicks: Could you please add a doc comment to these pub items? Also, the line spacing is a little inconsistent.


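One possible shape for the requested doc comments, sketched here as a suggestion only; the wording is an assumption about the intended semantics, read off the impls below:

/// Marker trait selecting whether an `AddrIter` yields its `end` address.
trait IterInclusivity: 'static {
    /// Returns `true` once the iterator has no more addresses to yield.
    fn exhausted<T: Ord>(start: &T, end: &T) -> bool;
}

/// Marker type for an iterator that stops one address before `end` (like `a..b`).
pub enum NonInclusive {}

/// Marker type for an iterator that also yields `end` itself (like `a..=b`).
pub enum Inclusive {}
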
impl IterInclusivity for NonInclusive {
fn exhausted<T: Ord>(start: &T, end: &T) -> bool {
start >= end
}
}

pub enum Inclusive {}

impl IterInclusivity for Inclusive {
fn exhausted<T: Ord>(start: &T, end: &T) -> bool {
start > end
}
}

impl<T: MemoryAddress, I: IterInclusivity> Iterator for AddrIter<T, I> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.current >= self.end {
if I::exhausted(&self.current, &self.end?) {
None
} else {
let ret = Some(self.current);
self.current += 1.into();
ret
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
let Some(end) = self.end else {
return (0, Some(0));
};
// `current` can step past `end` (after the final `next()` of an inclusive
// iterator, or an overshooting `nth()`), so bail out before `end - current`
// can underflow.
if I::exhausted(&self.current, &end) {
return (0, Some(0));
}
let ni_count = (end - self.current)
.try_into()
.expect("address range is larger than the architecture's usize");
if core::any::TypeId::of::<I>() == core::any::TypeId::of::<NonInclusive>() {
(ni_count, Some(ni_count))
} else if core::any::TypeId::of::<I>() == core::any::TypeId::of::<Inclusive>() {
(ni_count + 1, Some(ni_count + 1))
} else {
unreachable!()
}
}

fn last(self) -> Option<Self::Item> {
self.max()
}

fn nth(&mut self, n: usize) -> Option<Self::Item> {
let Ok(n): Result<T::RAW, _> = n.try_into() else {
// A failed cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
self.end.take();
return None;
};
match self.current.checked_add(n) {
Some(n) => self.current = n,
None if self.current.raw() < n => {
self.end.take();
return None;
}
None => panic!("Attempted to iterate over invalid address"),
}
// Finish via `next()` so the returned address is consumed as well.
self.next()
}

fn max(self) -> Option<Self::Item>
where
Self: Sized,
Self::Item: Ord,
{
// `None` once every remaining address has already been yielded.
let end = self.end?;
if I::exhausted(&self.current, &end) {
return None;
}
let one: T::RAW = 1u8.into();
if core::any::TypeId::of::<I>() == core::any::TypeId::of::<Inclusive>() {
Some(end)
} else {
end.checked_sub(one) // non-inclusive: the last address is `end - 1`
}
}

fn min(self) -> Option<Self::Item> {
let end = self.end?;
if I::exhausted(&self.current, &end) { None } else { Some(self.current) }
}

fn is_sorted(self) -> bool
where
Self: Sized,
Self::Item: PartialOrd,
{
true
}
}

impl<T: MemoryAddress> DoubleEndedIterator for AddrIter<T, NonInclusive> {
fn next_back(&mut self) -> Option<Self::Item> {
if NonInclusive::exhausted(&self.current, &self.end?) {
None
} else {
let one: T::RAW = 1u8.into();
self.end = Some(self.end? - one);
self.end
}
}
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
if n == 0 {
return self.next_back(); // Avoids sub-with-overflow below
}
let Ok(n): Result<T::RAW, _> = n.try_into() else {
// A failed cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
self.end.take();
return None;
};
let Some(ret) = self.end?.checked_sub(n) else {
// Underflow past the first address just exhausts the iterator;
// anything else means a non-canonical address.
if self.end?.raw() >= n {
panic!("Attempted to iterate over invalid address")
}
self.end.take();
return None;
};
self.end = Some(ret);
self.next_back()
}
}

impl<T: MemoryAddress> DoubleEndedIterator for AddrIter<T, Inclusive> {
fn next_back(&mut self) -> Option<Self::Item> {
if Inclusive::exhausted(&self.current, &self.end?) {
None
} else {
let ret = self.end?;

// We need to be able to step back to `0`: we return `0` while `self.end`
// is currently `0`, but decrementing `0` by `1` would trigger a
// sub-with-overflow. When that happens we return early, set `self.end` to
// `None` instead of decrementing it, and the next call to `next()` or
// `next_back()` sees the `None` and reports the iterator as exhausted.
let Some(step) = self.end?.checked_sub(1.into()) else {
// Check if this was an underflow or a non-canonical address
// Panic on non-canonical
// We can eat the overhead here because this branch is rare
if self.end?.raw() != 0u8.into() {
panic!("Attempted to iterate over invalid address")
}
self.end = None;
return Some(ret);
};
self.end = Some(step);
Some(ret)
}
}

fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
if n == 0 {
return self.next_back();
}
let Ok(n): Result<T::RAW, _> = n.try_into() else {
// A failed cast indicates that n > T::RAW::MAX, so we explicitly exhaust self.
self.end.take();
return None;
};

let Some(ret) = self.end?.checked_sub(n) else {
// Underflow past the first address just exhausts the iterator;
// anything else means a non-canonical address.
if self.end?.raw() >= n {
panic!("Attempted to iterate over invalid address")
}
self.end.take();
return None;
};
self.end = Some(ret);
// Finish via `next_back()` so the returned address is consumed as well.
self.next_back()
}
}

impl<T: MemoryAddress> ExactSizeIterator for AddrIter<T, Inclusive> {
fn len(&self) -> usize {
let Some(end) = self.end else { return 0 };
// Forward iteration can leave `current` one past `end`; avoid the underflow.
if Inclusive::exhausted(&self.current, &end) {
return 0;
}
(end - self.current)
.try_into()
.expect("address range is larger than the architecture's usize")
+ 1
}
}

impl<T: MemoryAddress> ExactSizeIterator for AddrIter<T, NonInclusive> {
fn len(&self) -> usize {
let Some(end) = self.end else { return 0 };
// An overshooting `nth()` can leave `current` past `end`; avoid the underflow.
if NonInclusive::exhausted(&self.current, &end) {
return 0;
}
(end - self.current)
.try_into()
.expect("address range is larger than the architecture's usize")
}
}

impl<T: MemoryAddress> FusedIterator for AddrIter<T> {}

impl<T: MemoryAddress> From<core::ops::Range<T>> for AddrIter<T, NonInclusive> {
fn from(range: core::ops::Range<T>) -> Self {
Self {
current: range.start,
end: Some(range.end),
_phantom: PhantomData,
}
}
}

impl<T: MemoryAddress> From<core::ops::RangeInclusive<T>> for AddrIter<T, Inclusive> {
fn from(range: core::ops::RangeInclusive<T>) -> Self {
Self {
current: *range.start(),
end: Some(*range.end()),
_phantom: PhantomData,
}
}
}
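
For orientation, a minimal usage sketch of the two conversions; the `VirtAddr` values are illustrative only and mirror the tests below:

// A sketch, not part of the change itself.
fn _addr_iter_usage() {
    // Non-inclusive conversion: yields 0x0, 0x1, 0x2.
    let ni = AddrIter::from(VirtAddr::new(0x0)..VirtAddr::new(0x3));
    assert_eq!(ni.count(), 3);

    // Inclusive conversion: additionally yields the end address 0x3.
    let inc = AddrIter::from(VirtAddr::new(0x0)..=VirtAddr::new(0x3));
    assert_eq!(inc.count(), 4);
}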

#[cfg(test)]
@@ -194,6 +401,31 @@ mod tests {
assert_eq!(a.raw() as usize, i);
}

assert_eq!(r.iter().nth(0), Some(VirtAddr::new(0x0)));
assert_eq!(r.iter().nth(1), Some(VirtAddr::new(0x1)));
assert_eq!(r.iter().nth(2), Some(VirtAddr::new(0x2)));
assert_eq!(r.iter().nth(3), None);

{
let mut range = r.iter();
assert_eq!(range.next_back(), Some(VirtAddr::new(0x2)));
assert_eq!(range.next_back(), Some(VirtAddr::new(0x1)));
assert_eq!(range.next_back(), Some(VirtAddr::new(0x0)));
assert_eq!(range.next_back(), None);
assert_eq!(range.next(), None);

let mut range = r.iter();
assert_eq!(range.next(), Some(VirtAddr::new(0x0)));
assert_eq!(range.next_back(), Some(VirtAddr::new(0x2)));
assert_eq!(range.next(), Some(VirtAddr::new(0x1)));
assert_eq!(range.next_back(), None);

assert_eq!(r.iter().nth_back(0), Some(VirtAddr::new(0x2)));
assert_eq!(r.iter().nth_back(1), Some(VirtAddr::new(0x1)));
assert_eq!(r.iter().nth_back(2), Some(VirtAddr::new(0x0)));
assert_eq!(r.iter().nth_back(3), None);
}

let r = AddrRange::new(PhysAddr::new(0x2), PhysAddr::new(0x4)).unwrap();
let mut i = r.iter();
assert_eq!(i.next().unwrap(), PhysAddr::new(0x2));
@@ -202,4 +434,63 @@

assert_eq!(r.iter().map(|a| a.raw() as usize).sum::<usize>(), 0x5);
}

#[test]
fn test_iter_incl() {
let range = VirtAddr::new(0x0)..=VirtAddr::new(0x3);
let mut i = AddrIter::from(range.clone());
assert_eq!(i.next().unwrap(), VirtAddr::new(0x0));
assert_eq!(i.next().unwrap(), VirtAddr::new(0x1));
assert_eq!(i.next().unwrap(), VirtAddr::new(0x2));
assert_eq!(i.next().unwrap(), VirtAddr::new(0x3));

let mut i = AddrIter::from(range.clone());
assert_eq!(i.next_back(), Some(VirtAddr::new(0x3)));
assert_eq!(i.next_back(), Some(VirtAddr::new(0x2)));
assert_eq!(i.next_back(), Some(VirtAddr::new(0x1)));
assert_eq!(i.next_back(), Some(VirtAddr::new(0x0)));
assert_eq!(i.next_back(), None);

let mut i = AddrIter::from(range.clone());
assert_eq!(i.next_back(), Some(VirtAddr::new(0x3)));
assert_eq!(i.next(), Some(VirtAddr::new(0x0)));
assert_eq!(i.next_back(), Some(VirtAddr::new(0x2)));
assert_eq!(i.next(), Some(VirtAddr::new(0x1)));
assert_eq!(i.next_back(), None);

assert_eq!(
AddrIter::from(range.clone()).nth(0),
Some(VirtAddr::new(0x0))
);
assert_eq!(
AddrIter::from(range.clone()).nth(1),
Some(VirtAddr::new(0x1))
);
assert_eq!(
AddrIter::from(range.clone()).nth(2),
Some(VirtAddr::new(0x2))
);
assert_eq!(
AddrIter::from(range.clone()).nth(3),
Some(VirtAddr::new(0x3))
);
assert_eq!(AddrIter::from(range.clone()).nth(4), None);
}

#[test]
fn iterator_assert_sizes() {
let range_incl = VirtAddr::new(0x0)..=VirtAddr::new(0x3);
assert_eq!(
AddrIter::from(range_incl.clone()).count(),
AddrIter::from(range_incl.clone()).len()
);
assert_eq!(
AddrIter::from(range_incl.clone()).count(),
AddrIter::from(range_incl.clone()).size_hint().0
);
assert_eq!(
AddrIter::from(range_incl.clone()).count(),
AddrIter::from(range_incl.clone()).size_hint().1.unwrap()
);
}
}