diff --git a/src/function.rs b/src/function.rs index 512c8ba70..434a895a5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -428,11 +428,11 @@ where fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { match self .sync_table - .try_claim(zalsa, key_index, Reentrancy::Deny) + .peek_claim(zalsa, key_index, Reentrancy::Deny) { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner }, - ClaimResult::Claimed(_) => WaitForResult::Available, + ClaimResult::Claimed(()) => WaitForResult::Available, } } diff --git a/src/function/sync.rs b/src/function/sync.rs index 02f1bffd0..c9a74a307 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -20,7 +20,7 @@ pub(crate) struct SyncTable { ingredient: IngredientIndex, } -pub(crate) enum ClaimResult<'a> { +pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. @@ -31,7 +31,7 @@ pub(crate) enum ClaimResult<'a> { inner: bool, }, /// Successfully claimed the query. - Claimed(ClaimGuard<'a>), + Claimed(Guard), } pub(crate) struct SyncState { @@ -87,10 +87,7 @@ impl SyncTable { } }; - let &mut SyncState { - ref mut anyone_waiting, - .. - } = occupied_entry.into_mut(); + let SyncState { anyone_waiting, .. } = occupied_entry.into_mut(); // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this @@ -125,6 +122,51 @@ impl SyncTable { } } + /// Claims the given key index, or blocks if it is running on another thread. + pub(crate) fn peek_claim<'me>( + &'me self, + zalsa: &'me Zalsa, + key_index: Id, + reentrant: Reentrancy, + ) -> ClaimResult<'me, ()> { + let mut write = self.syncs.lock(); + match write.entry(key_index) { + std::collections::hash_map::Entry::Occupied(occupied_entry) => { + let id = match occupied_entry.get().id { + SyncOwner::Thread(id) => id, + SyncOwner::Transferred => { + return match self.peek_claim_transferred(zalsa, occupied_entry, reentrant) { + Ok(claimed) => claimed, + Err(other_thread) => match other_thread.block(write) { + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + BlockResult::Running(running) => ClaimResult::Running(running), + }, + } + } + }; + + let SyncState { anyone_waiting, .. } = occupied_entry.into_mut(); + + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. + *anyone_waiting = true; + match zalsa.runtime().block( + DatabaseKeyIndex::new(self.ingredient, key_index), + id, + write, + ) { + BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + } + } + std::collections::hash_map::Entry::Vacant(_) => ClaimResult::Claimed(()), + } + } + #[cold] #[inline(never)] fn try_claim_transferred<'me>( @@ -179,6 +221,34 @@ impl SyncTable { } } + #[cold] + #[inline(never)] + fn peek_claim_transferred<'me>( + &'me self, + zalsa: &'me Zalsa, + mut entry: OccupiedEntry, + reentrant: Reentrancy, + ) -> Result, Box>> { + let key_index = *entry.key(); + let database_key_index = DatabaseKeyIndex::new(self.ingredient, key_index); + let thread_id = thread::current().id(); + + match zalsa + .runtime() + .block_transferred(database_key_index, thread_id) + { + BlockTransferredResult::ImTheOwner if reentrant.is_allow() => { + Ok(ClaimResult::Claimed(())) + } + BlockTransferredResult::ImTheOwner => Ok(ClaimResult::Cycle { inner: true }), + BlockTransferredResult::OwnedBy(other_thread) => { + entry.get_mut().anyone_waiting = true; + Err(other_thread) + } + BlockTransferredResult::Released => Ok(ClaimResult::Claimed(())), + } + } + /// Marks `key_index` as a transfer target. /// /// Returns the `SyncOwnerId` of the thread that currently owns this query.