Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "slab"
version = "0.4.9"
authors = ["Carl Lerche <me@carllerche.com>"]
edition = "2018"
rust-version = "1.46"
rust-version = "1.51"
license = "MIT"
description = "Pre-allocated storage for a uniform data type"
repository = "https://github.com/tokio-rs/slab"
Expand Down
71 changes: 71 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ mod builder;

use alloc::vec::{self, Vec};
use core::iter::{self, FromIterator, FusedIterator};
use core::mem::MaybeUninit;
use core::{fmt, mem, ops, slice};

/// Pre-allocated storage for a uniform data type
Expand Down Expand Up @@ -169,6 +170,33 @@ impl<T> Default for Slab<T> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
/// The error type returned by [`Slab::get_disjoint_mut`].
pub enum GetDisjointMutError {
/// An index provided was not associated with a value.
IndexVacant,

/// An index provided was out-of-bounds for the slab.
IndexOutOfBounds,

/// Two indices provided were overlapping.
OverlappingIndices,
}

impl fmt::Display for GetDisjointMutError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let msg = match self {
GetDisjointMutError::IndexVacant => "an index is vacant",
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
};
fmt::Display::fmt(msg, f)
}
}

#[cfg(feature = "std")]
impl std::error::Error for GetDisjointMutError {}

/// A handle to a vacant entry in a `Slab`.
///
/// `VacantEntry` allows constructing values with the key that they will be
Expand Down Expand Up @@ -776,6 +804,49 @@ impl<T> Slab<T> {
}
}

/// Returns mutable references to many indices at once.
///
/// Returns [`GetDisjointMutError`] if the indices are out of bounds,
/// overlapping, or vacant.
pub fn get_disjoint_mut<const N: usize>(
&mut self,
keys: [usize; N],
) -> Result<[&mut T; N], GetDisjointMutError> {
// NB: The optimizer should inline the loops into a sequence
// of instructions without additional branching.
for (i, &key) in keys.iter().enumerate() {
for &prev_key in &keys[..i] {
if key == prev_key {
return Err(GetDisjointMutError::OverlappingIndices);
}
}
}

let entries_ptr = self.entries.as_mut_ptr();
let entries_cap = self.entries.capacity();

let mut res = MaybeUninit::<[&mut T; N]>::uninit();
let res_ptr = res.as_mut_ptr() as *mut &mut T;

for (i, &key) in keys.iter().enumerate() {
if key >= entries_cap {
return Err(GetDisjointMutError::IndexOutOfBounds);
}
// SAFETY: we made sure above that this key is in bounds.
match unsafe { &mut *entries_ptr.add(key) } {
Entry::Vacant(_) => return Err(GetDisjointMutError::IndexVacant),
Entry::Occupied(entry) => {
// SAFETY: `res` and `keys` both have N elements so `i` must be in bounds.
// We checked above that all selected `entry`s are distinct.
unsafe { res_ptr.add(i).write(entry) };
}
}
}
// SAFETY: the loop above only terminates successfully if it initialized
// all elements of this array.
Ok(unsafe { res.assume_init() })
}

/// Return a reference to the value associated with the given key without
/// performing bounds checking.
///
Expand Down
39 changes: 38 additions & 1 deletion tests/slab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

use slab::*;

use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::{
iter::FromIterator,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
};

#[test]
fn insert_get_remove_one() {
Expand Down Expand Up @@ -730,3 +733,37 @@ fn clone_from() {
assert_eq!(iter2.next(), None);
assert!(slab2.capacity() >= 10);
}

#[test]
fn get_disjoint_mut() {
let mut slab = Slab::from_iter((0..5).enumerate());
slab.remove(1);
slab.remove(3);

assert_eq!(slab.get_disjoint_mut([]), Ok([]));

assert_eq!(
slab.get_disjoint_mut([4, 2, 0]).unwrap().map(|x| *x),
[4, 2, 0]
);

assert_eq!(
slab.get_disjoint_mut([42, 2, 1, 2]),
Err(GetDisjointMutError::OverlappingIndices)
);

assert_eq!(
slab.get_disjoint_mut([1, 5]),
Err(GetDisjointMutError::IndexVacant)
);

assert_eq!(
slab.get_disjoint_mut([5, 1]),
Err(GetDisjointMutError::IndexOutOfBounds)
);

let [a, b] = slab.get_disjoint_mut([0, 4]).unwrap();
(*a, *b) = (*b, *a);
assert_eq!(slab[0], 4);
assert_eq!(slab[4], 0);
}