Skip to content
226 changes: 225 additions & 1 deletion lock_api/src/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use core::ops::{Deref, DerefMut};
#[cfg(feature = "arc_lock")]
use alloc::sync::Arc;
#[cfg(feature = "arc_lock")]
use core::any::Any;
#[cfg(feature = "arc_lock")]
use core::mem::ManuallyDrop;
#[cfg(feature = "arc_lock")]
use core::ptr;
Expand Down Expand Up @@ -308,7 +310,7 @@ impl<R: RawMutex, T: ?Sized> Mutex<R, T> {
#[inline]
unsafe fn make_arc_guard_unchecked(self: &Arc<Self>) -> ArcMutexGuard<R, T> {
ArcMutexGuard {
mutex: self.clone(),
mutex: Arc::clone(self),
marker: PhantomData,
}
}
Expand Down Expand Up @@ -754,6 +756,62 @@ impl<R: RawMutex, T: ?Sized> ArcMutexGuard<R, T> {
&s.mutex
}

/// Makes a new `MappedArcMutexGuard` for a component of the locked data.
///
/// This operation cannot fail as the `ArcMutexGuard` passed
/// in already locked the mutex.
///
/// This is an associated function that needs to be
/// used as `ArcMutexGuard::map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn map<U: ?Sized, F>(s: Self, f: F) -> MappedArcMutexGuard<R, U>
where
F: FnOnce(&mut T) -> &mut U,
T: Sized + 'static,
{
let data = f(unsafe { &mut *s.mutex.data.get() });
// Safety: this reference is outlived by the Arc itself, which ensures it stays valid.
let raw = unsafe { mem::transmute(&s.mutex.raw) };
// Safety: we are "cloning" the Arc without bumping the refcount,
// because we're about to forget the original along with `s`.
let mutex: Arc<_> = unsafe { ptr::read(&s.mutex) };

// We do not want to unlock the mutex, and we do not want to drop s.mutex, so just forget
// the entire thing.
mem::forget(s);

MappedArcMutexGuard { mutex, raw, data }
}

/// Attempts to make a new `MappedArcMutexGuard` for a component of the
/// locked data. The original guard is returned if the closure returns `None`.
///
/// This operation cannot fail as the `ArcMutexGuard` passed
/// in already locked the mutex.
///
/// This is an associated function that needs to be
/// used as `ArcMutexGuard::try_map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn try_map<U: ?Sized, F>(s: Self, f: F) -> Result<MappedArcMutexGuard<R, U>, Self>
where
F: FnOnce(&mut T) -> Option<&mut U>,
T: Sized + 'static,
{
let data = match f(unsafe { &mut *s.mutex.data.get() }) {
Some(data) => data,
None => return Err(s),
};
// Safety: this reference is outlived by the Arc itself, which ensures it stays valid.
let raw = unsafe { mem::transmute(&s.mutex.raw) };
// Safety: we are "cloning" the Arc without bumping the refcount,
// because we're about to forget the original along with `s`.
let mutex: Arc<_> = unsafe { ptr::read(&s.mutex) };
mem::forget(s);
Ok(MappedArcMutexGuard { mutex, raw, data })
}

/// Unlocks the mutex and returns the `Arc` that was held by the [`ArcMutexGuard`].
#[inline]
#[track_caller]
Expand Down Expand Up @@ -859,6 +917,7 @@ impl<R: RawMutex, T: ?Sized> Drop for ArcMutexGuard<R, T> {
#[inline]
fn drop(&mut self) {
// Safety: A MutexGuard always holds the lock.
// Safety: The dropped mutex is not accessible after this method returns.
unsafe {
self.mutex.raw.unlock();
}
Expand Down Expand Up @@ -1037,3 +1096,168 @@ impl<'a, R: RawMutex + 'a, T: fmt::Display + ?Sized + 'a> fmt::Display

#[cfg(feature = "owning_ref")]
unsafe impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> StableAddress for MappedMutexGuard<'a, R, T> {}

/// An RAII mutex guard returned by `ArcMutexGuard::map`, which can point to a
/// subfield of the protected data.
///
/// The main difference between `MappedArcMutexGuard` and `ArcMutexGuard` is that the
/// former doesn't support temporarily unlocking and re-locking, since that
/// could introduce soundness issues if the locked object is modified by another
/// thread.
#[cfg(feature = "arc_lock")]
#[clippy::has_significant_drop]
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MappedArcMutexGuard<R: RawMutex + 'static, U: ?Sized> {
// This actually stores a `Arc<Mutex<R, T>>` for some `T`.
// We don't _really_ care about it, but we need it to stay alive so the raw reference below
// stays valid.
mutex: Arc<dyn Any>,

// Note: the `&'static` is a lie.
// It should be outlived by the mutex right above.
raw: &'static R,
data: *mut U,
}

#[cfg(feature = "arc_lock")]
unsafe impl<R: RawMutex + Sync, U: ?Sized + Sync> Sync for MappedArcMutexGuard<R, U> {}
#[cfg(feature = "arc_lock")]
unsafe impl<R: RawMutex, U: ?Sized + Sync> Send for MappedArcMutexGuard<R, U> where
R::GuardMarker: Send
{
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: ?Sized> MappedArcMutexGuard<R, U> {
/// Drop the content (mostly the Arc) without unlocking the mutex.
#[inline]
fn forget(s: Self) {
// SAFETY: make sure the Arc gets it reference decremented
let mut s = ManuallyDrop::new(s);
unsafe { ptr::drop_in_place(&mut s.mutex) };
}

/// Makes a new `MappedArcMutexGuard` for a component of the locked data.
///
/// This operation cannot fail as the `MappedArcMutexGuard` passed
/// in already locked the mutex.
///
/// This is an associated function that needs to be
/// used as `MappedArcMutexGuard::map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn map<V: ?Sized, F>(s: Self, f: F) -> MappedArcMutexGuard<R, V>
where
F: FnOnce(&mut U) -> &mut V,
{
// Can't drop `s` or it will unlock the mutex.
let mut s = ManuallyDrop::new(s);

let data = f(unsafe { &mut *s.data });
let raw = s.raw;
// Safety: we are about to forget `s.mutex` along with `s`, so making a copy here can be
// considered a "move".
let mutex: Arc<dyn Any> = unsafe { ptr::read(&s.mutex) };

MappedArcMutexGuard { mutex, raw, data }
}

/// Attempts to make a new `MappedArcMutexGuard` for a component of the
/// locked data. The original guard is returned if the closure returns `None`.
///
/// This operation cannot fail as the `MappedArcMutexGuard` passed
/// in already locked the mutex.
///
/// This is an associated function that needs to be
/// used as `MappedArcMutexGuard::try_map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn try_map<V: ?Sized, F>(s: Self, f: F) -> Result<MappedArcMutexGuard<R, V>, Self>
where
F: FnOnce(&mut U) -> Option<&mut V>,
{
let data = match f(unsafe { &mut *s.data }) {
Some(data) => data,
None => return Err(s),
};
let raw = s.raw;

// Safety: we are about to forget `s.mutex` along with `s`, so making a copy here can be
// considered a "move".
let mutex: Arc<dyn Any> = unsafe { ptr::read(&s.mutex) };
// Can't drop `s` or it will unlock the mutex.
mem::forget(s);

Ok(MappedArcMutexGuard { mutex, raw, data })
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutexFair, U: ?Sized> MappedArcMutexGuard<R, U> {
/// Unlocks the mutex using a fair unlock protocol.
///
/// By default, mutexes are unfair and allow the current thread to re-lock
/// the mutex before another has the chance to acquire the lock, even if
/// that thread has been blocked on the mutex for a long time. This is the
/// default because it allows much higher throughput as it avoids forcing a
/// context switch on every mutex unlock. This can result in one thread
/// acquiring a mutex many more times than other threads.
///
/// However, in some cases it can be beneficial to ensure fairness by forcing
/// the lock to pass on to a waiting thread if there is one. This is done by
/// using this method instead of dropping the `MutexGuard` normally.
#[inline]
pub fn unlock_fair(s: Self) {
// Safety: A MutexGuard always holds the lock.
unsafe {
s.raw.unlock_fair();
}
Self::forget(s);
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: ?Sized> Deref for MappedArcMutexGuard<R, U> {
type Target = U;
#[inline]
fn deref(&self) -> &U {
unsafe { &*self.data }
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: ?Sized> DerefMut for MappedArcMutexGuard<R, U> {
#[inline]
fn deref_mut(&mut self) -> &mut U {
unsafe { &mut *self.data }
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: ?Sized> Drop for MappedArcMutexGuard<R, U> {
#[inline]
fn drop(&mut self) {
// Safety: A MappedArcMutexGuard always holds the lock.
// Safety: self.mutex will not be reachable after this function returns.
unsafe {
self.raw.unlock();
}
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: fmt::Debug + ?Sized> fmt::Debug for MappedArcMutexGuard<R, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}

#[cfg(feature = "arc_lock")]
impl<R: RawMutex, U: fmt::Display + ?Sized> fmt::Display for MappedArcMutexGuard<R, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}

#[cfg(all(feature = "arc_lock", feature = "owning_ref"))]
unsafe impl<R: RawMutex, U: ?Sized> StableAddress for MappedArcMutexGuard<R, U> {}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ pub use self::rwlock::{
};
pub use ::lock_api;

#[cfg(feature = "arc_lock")]
pub use self::mutex::MappedArcMutexGuard;

#[cfg(feature = "arc_lock")]
pub use self::lock_api::{
ArcMutexGuard, ArcReentrantMutexGuard, ArcRwLockReadGuard, ArcRwLockUpgradableReadGuard,
Expand Down
35 changes: 35 additions & 0 deletions src/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
/// thread.
pub type MappedMutexGuard<'a, T> = lock_api::MappedMutexGuard<'a, RawMutex, T>;

/// An RAII mutex guard returned by `ArcMutexGuard::map`, which can point to a
/// subfield of the protected data.
///
/// The main difference between `MappedArcMutexGuard` and `ArcMutexGuard` is that the
/// former doesn't support temporarily unlocking and re-locking, since that
/// could introduce soundness issues if the locked object is modified by another
/// thread.
#[cfg(feature = "arc_lock")]
pub type MappedArcMutexGuard<U> = lock_api::MappedArcMutexGuard<RawMutex, U>;

#[cfg(test)]
mod tests {
use crate::{Condvar, MappedMutexGuard, Mutex, MutexGuard};
Expand Down Expand Up @@ -311,6 +321,31 @@ mod tests {
assert_eq!(contents, *(deserialized.lock()));
}

#[cfg(feature = "arc_lock")]
#[test]
fn test_arc_map() {
use lock_api::{ArcMutexGuard, MappedArcMutexGuard};
use std::sync::Arc;

let contents: Vec<u8> = vec![0, 1, 2];
let mutex: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(contents));

let guard = mutex.lock_arc();

// Example of a failible mapping function: getting a chunk
let guard = ArcMutexGuard::try_map(guard, |contents| contents.first_chunk_mut::<3>())
.ok()
.expect("Could not get the first 3 elements as a chunk.");

// Example of chained mapping: accessing a value.
let guard = MappedArcMutexGuard::map(guard, |contents| &mut contents[1]);

// The point of the ArcMutexGuard is that we don't borrow the mutex, so we can drop it.
drop(mutex);

assert_eq!(*guard, 1);
}

#[test]
fn test_map_or_err_not_mapped() {
let mut map = HashMap::new();
Expand Down