Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 139 additions & 19 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,27 @@ where
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
/// form must match those for the key type.
pub fn remove<'g, Q>(&'g self, key: &Q, guard: &'g Guard) -> Option<&'g V>
where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.replace_node(key, None, None, guard)
}

/// Replaces node value with v, conditional upon match of cv.
/// If resulting value does not exist it removes the key (and its corresponding value) from this map.
/// This method does nothing if the key is not in the map.
/// Returns the previous value associated with the given key.
///
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
/// form must match those for the key type.
fn replace_node<'g, Q>(
&'g self,
key: &Q,
new_value: Option<V>,
observed_value: Option<Shared<'g, V>>,
guard: &'g Guard,
) -> Option<&'g V>
where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
Expand Down Expand Up @@ -1152,26 +1173,36 @@ where
let next = n.next.load(Ordering::SeqCst, guard);
if n.hash == h && n.key.borrow() == key {
let ev = n.value.load(Ordering::SeqCst, guard);
old_val = Some(ev);

// remove the BinEntry containing the removed key value pair from the bucket
if !pred.is_null() {
// either by changing the pointer of the previous BinEntry, if present
// safety: as above
unsafe { pred.deref() }
.as_node()
.unwrap()
.next
.store(next, Ordering::SeqCst);
} else {
// or by setting the next node as the first BinEntry if there is no previous entry
t.store_bin(i, next);
}

// in either case, mark the BinEntry as garbage, since it was just removed
// safety: as for val below / in put
unsafe { guard.defer_destroy(e) };

// just remove the node if the value is the one we expected at method call
if observed_value.map(|ov| ov == ev).unwrap_or(true) {
// found the node but we have a new value to replace the old one
if let Some(nv) = new_value {
n.value.store(Owned::new(nv), Ordering::SeqCst);
// we are just replacing entry value and we do not want to remove the node
// so we stop iterating here
break;
}
// we remember the old value so that we can return it and mark it for deletion below
old_val = Some(ev);
// remove the BinEntry containing the removed key value pair from the bucket
if !pred.is_null() {
// either by changing the pointer of the previous BinEntry, if present
// safety: as above
unsafe { pred.deref() }
.as_node()
.unwrap()
.next
.store(next, Ordering::SeqCst);
} else {
// or by setting the next node as the first BinEntry if there is no previous entry
t.store_bin(i, next);
}

// in either case, mark the BinEntry as garbage, since it was just removed
// safety: as for val below / in put
unsafe { guard.defer_destroy(e) };
}
// since the key was found and only one node exists per key, we can break here
break;
}
Expand Down Expand Up @@ -1218,6 +1249,46 @@ where
None
}

/// Retains only the elements specified by the predicate.
///
/// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
///
/// If `f` returns `false` for a given key/value pair, but the value for that pair is concurrently
/// modified before the removal takes place, the entry will not be removed.
/// If you want the removal to happen even in the case of concurrent modification, use [`HashMap::retain_force`].
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&K, &V) -> bool,
{
let guard = epoch::pin();
// removed selected keys
for (k, v) in self.iter(&guard) {
if !f(k, v) {
let old_value: Shared<'_, V> = Shared::from(v as *const V);
self.replace_node(k, None, Some(old_value), &guard);
}
}
}

/// Retains only the elements specified by the predicate.
///
/// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
///
/// This method always deletes any key/value pair that `f` returns `false` for,
/// even if if the value is updated concurrently. If you do not want that behavior, use [`HashMap::retain`].
pub fn retain_force<F>(&mut self, mut f: F)
where
F: FnMut(&K, &V) -> bool,
{
let guard = epoch::pin();
// removed selected keys
for (k, v) in self.iter(&guard) {
if !f(k, v) {
self.replace_node(k, None, None, &guard);
}
}
}

/// An iterator visiting all key-value pairs in arbitrary order.
/// The iterator element type is `(&'g K, &'g V)`.
///
Expand Down Expand Up @@ -1604,5 +1675,54 @@ mod tests {
/// drop(guard);
/// drop(r);
/// ```

#[test]
fn replace_empty() {
let map = HashMap::<usize, usize>::new();

{
let guard = epoch::pin();
let old = map.replace_node(&42, None, None, &guard);
assert!(old.is_none());
}
}

#[test]
fn replace_existing() {
let map = HashMap::<usize, usize>::new();
{
let guard = epoch::pin();
map.insert(42, 42, &guard);
let old = map.replace_node(&42, Some(10), None, &guard);
assert!(old.is_none());
assert_eq!(*map.get(&42, &guard).unwrap(), 10);
}
}

#[test]
fn replace_existing_observed_value_matching() {
let map = HashMap::<usize, usize>::new();
{
let guard = epoch::pin();
map.insert(42, 42, &guard);
let observed_value = Shared::from(map.get(&42, &guard).unwrap() as *const _);
let old = map.replace_node(&42, Some(10), Some(observed_value), &guard);
assert!(old.is_none());
assert_eq!(*map.get(&42, &guard).unwrap(), 10);
}
}

#[test]
fn replace_existing_observed_value_non_matching() {
let map = HashMap::<usize, usize>::new();
{
let guard = epoch::pin();
map.insert(42, 42, &guard);
let old = map.replace_node(&42, Some(10), Some(Shared::null()), &guard);
assert!(old.is_none());
assert_eq!(*map.get(&42, &guard).unwrap(), 42);
}
}

#[allow(dead_code)]
struct CompileFailTests;
47 changes: 47 additions & 0 deletions tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,50 @@ fn from_iter_empty() {

assert_eq!(map.len(), 0)
}

#[test]
fn retain_empty() {
let mut map = HashMap::<&'static str, u32>::new();
map.retain(|_, _| false);
assert_eq!(map.len(), 0);
}

#[test]
fn retain_all_false() {
let mut map: HashMap<u32, u32> = (0..10 as u32).map(|x| (x, x)).collect();
map.retain(|_, _| false);
assert_eq!(map.len(), 0);
}

#[test]
fn retain_all_true() {
let size = 10usize;
let mut map: HashMap<usize, usize> = (0..size).map(|x| (x, x)).collect();
map.retain(|_, _| true);
assert_eq!(map.len(), size);
}

#[test]
fn retain_some() {
let mut map: HashMap<u32, u32> = (0..10).map(|x| (x, x)).collect();
let expected_map: HashMap<u32, u32> = (5..10).map(|x| (x, x)).collect();
map.retain(|_, v| *v >= 5);
assert_eq!(map.len(), 5);
assert_eq!(map, expected_map);
}

#[test]
fn retain_force_empty() {
let mut map = HashMap::<&'static str, u32>::new();
map.retain_force(|_, _| false);
assert_eq!(map.len(), 0);
}

#[test]
fn retain_force_some() {
let mut map: HashMap<u32, u32> = (0..10).map(|x| (x, x)).collect();
let expected_map: HashMap<u32, u32> = (5..10).map(|x| (x, x)).collect();
map.retain_force(|_, v| *v >= 5);
assert_eq!(map.len(), 5);
assert_eq!(map, expected_map);
}