diff --git a/Cargo.toml b/Cargo.toml index 9e0c186..e237215 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ name = "slab" version = "0.4.9" authors = ["Carl Lerche "] edition = "2018" -rust-version = "1.46" +rust-version = "1.51" license = "MIT" description = "Pre-allocated storage for a uniform data type" repository = "https://github.com/tokio-rs/slab" diff --git a/src/lib.rs b/src/lib.rs index cd5fe0e..bd00e08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,7 @@ mod builder; use alloc::vec::{self, Vec}; use core::iter::{self, FromIterator, FusedIterator}; +use core::mem::MaybeUninit; use core::{fmt, mem, ops, slice}; /// Pre-allocated storage for a uniform data type @@ -169,6 +170,33 @@ impl Default for Slab { } } +#[derive(Debug, Clone, PartialEq, Eq)] +/// The error type returned by [`Slab::get_disjoint_mut`]. +pub enum GetDisjointMutError { + /// An index provided was not associated with a value. + IndexVacant, + + /// An index provided was out-of-bounds for the slab. + IndexOutOfBounds, + + /// Two indices provided were overlapping. + OverlappingIndices, +} + +impl fmt::Display for GetDisjointMutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg = match self { + GetDisjointMutError::IndexVacant => "an index is vacant", + GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds", + GetDisjointMutError::OverlappingIndices => "there were overlapping indices", + }; + fmt::Display::fmt(msg, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for GetDisjointMutError {} + /// A handle to a vacant entry in a `Slab`. /// /// `VacantEntry` allows constructing values with the key that they will be @@ -776,6 +804,49 @@ impl Slab { } } + /// Returns mutable references to many indices at once. + /// + /// Returns [`GetDisjointMutError`] if the indices are out of bounds, + /// overlapping, or vacant. + pub fn get_disjoint_mut( + &mut self, + keys: [usize; N], + ) -> Result<[&mut T; N], GetDisjointMutError> { + // NB: The optimizer should inline the loops into a sequence + // of instructions without additional branching. + for (i, &key) in keys.iter().enumerate() { + for &prev_key in &keys[..i] { + if key == prev_key { + return Err(GetDisjointMutError::OverlappingIndices); + } + } + } + + let entries_ptr = self.entries.as_mut_ptr(); + let entries_cap = self.entries.capacity(); + + let mut res = MaybeUninit::<[&mut T; N]>::uninit(); + let res_ptr = res.as_mut_ptr() as *mut &mut T; + + for (i, &key) in keys.iter().enumerate() { + if key >= entries_cap { + return Err(GetDisjointMutError::IndexOutOfBounds); + } + // SAFETY: we made sure above that this key is in bounds. + match unsafe { &mut *entries_ptr.add(key) } { + Entry::Vacant(_) => return Err(GetDisjointMutError::IndexVacant), + Entry::Occupied(entry) => { + // SAFETY: `res` and `keys` both have N elements so `i` must be in bounds. + // We checked above that all selected `entry`s are distinct. + unsafe { res_ptr.add(i).write(entry) }; + } + } + } + // SAFETY: the loop above only terminates successfully if it initialized + // all elements of this array. + Ok(unsafe { res.assume_init() }) + } + /// Return a reference to the value associated with the given key without /// performing bounds checking. /// diff --git a/tests/slab.rs b/tests/slab.rs index dd71f99..1b134f8 100644 --- a/tests/slab.rs +++ b/tests/slab.rs @@ -2,7 +2,10 @@ use slab::*; -use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; +use std::{ + iter::FromIterator, + panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, +}; #[test] fn insert_get_remove_one() { @@ -730,3 +733,37 @@ fn clone_from() { assert_eq!(iter2.next(), None); assert!(slab2.capacity() >= 10); } + +#[test] +fn get_disjoint_mut() { + let mut slab = Slab::from_iter((0..5).enumerate()); + slab.remove(1); + slab.remove(3); + + assert_eq!(slab.get_disjoint_mut([]), Ok([])); + + assert_eq!( + slab.get_disjoint_mut([4, 2, 0]).unwrap().map(|x| *x), + [4, 2, 0] + ); + + assert_eq!( + slab.get_disjoint_mut([42, 2, 1, 2]), + Err(GetDisjointMutError::OverlappingIndices) + ); + + assert_eq!( + slab.get_disjoint_mut([1, 5]), + Err(GetDisjointMutError::IndexVacant) + ); + + assert_eq!( + slab.get_disjoint_mut([5, 1]), + Err(GetDisjointMutError::IndexOutOfBounds) + ); + + let [a, b] = slab.get_disjoint_mut([0, 4]).unwrap(); + (*a, *b) = (*b, *a); + assert_eq!(slab[0], 4); + assert_eq!(slab[4], 0); +}