Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ext/crates/once/src/multiindexed/kdtrie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,53 @@ impl<V> KdTrie<V> {
unsafe { node.try_set_value(coords[self.dimensions - 1], value) }
}

/// Retrieves a mutable reference to the value at the specified coordinates, if it exists.
///
/// This method can only be called if we have an exclusive reference to self. When you have
/// exclusive access, there's no possibility of concurrent access, so the atomic synchronization
/// used by `get` is unnecessary.
///
/// # Parameters
///
/// * `coords`: A slice of coordinates with length equal to `self.dimensions`
///
/// # Returns
///
/// * `Some(&mut V)` if a value exists at the specified coordinates
/// * `None` if no value exists at the specified coordinates
///
/// # Panics
///
/// Panics if the length of `coords` does not match the number of dimensions.
///
/// # Examples
///
/// ```
/// use once::multiindexed::KdTrie;
///
/// let mut trie = KdTrie::<i32>::new(2);
/// trie.insert(&[1, 2], 42);
///
/// // Modify the value through a mutable reference
/// if let Some(value) = trie.get_mut(&[1, 2]) {
/// *value = 100;
/// }
///
/// assert_eq!(trie.get(&[1, 2]), Some(&100));
/// assert_eq!(trie.get_mut(&[0, 0]), None);
/// ```
pub fn get_mut(&mut self, coords: &[i32]) -> Option<&mut V> {
assert!(coords.len() == self.dimensions);

let mut node = &mut self.root;

for &coord in coords.iter().take(self.dimensions - 1) {
node = unsafe { node.get_child_mut(coord)? };
}

unsafe { node.get_value_mut(coords[self.dimensions - 1]) }
}

pub fn dimensions(&self) -> usize {
self.dimensions
}
Expand Down
114 changes: 102 additions & 12 deletions ext/crates/once/src/multiindexed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,42 @@ impl<const K: usize, V> MultiIndexed<K, V> {
self.0.get(&coords)
}

/// Retrieves a mutable reference to the value at the specified coordinates, if it exists.
///
/// This method can only be called if we have an exclusive reference to self. Having an
/// exclusive reference prevents concurrent access, so the atomic synchronization used by `get`
/// is unnecessary. This makes it safe to return a mutable reference to the stored value.
///
/// # Parameters
///
/// * `coords`: An array of K integer coordinates
///
/// # Returns
///
/// * `Some(&mut V)` if a value exists at the specified coordinates
/// * `None` if no value exists at the specified coordinates
///
/// # Examples
///
/// ```
/// use once::MultiIndexed;
///
/// let mut array = MultiIndexed::<3, Vec<i32>>::new();
/// array.insert([1, 2, 3], vec![1, 2, 3]);
/// array.insert([4, 5, 6], vec![4, 5, 6]);
///
/// // Modify the vectors in place
/// if let Some(vec) = array.get_mut([1, 2, 3]) {
/// vec.push(4);
/// vec.push(5);
/// }
///
/// assert_eq!(array.get([1, 2, 3]), Some(&vec![1, 2, 3, 4, 5]));
/// ```
pub fn get_mut(&mut self, coords: [i32; K]) -> Option<&mut V> {
self.0.get_mut(&coords)
}

/// Inserts a value at the specified coordinates.
///
/// This operation is thread-safe and can be called from multiple threads. However, this method
Expand Down Expand Up @@ -325,6 +361,34 @@ mod tests {
assert_eq!(arr.get([1, 3, 4]), Some(&45));
}

#[test]
fn test_get_mut_basic() {
let mut arr = MultiIndexed::<3, i32>::new();

arr.insert([1, 2, 3], 42);
arr.insert([4, 5, 6], 100);
arr.insert([-1, -2, -3], 200);

// Modify values using get_mut
if let Some(value) = arr.get_mut([1, 2, 3]) {
*value = 1000;
}
if let Some(value) = arr.get_mut([4, 5, 6]) {
*value += 50;
}
if let Some(value) = arr.get_mut([-1, -2, -3]) {
*value *= 2;
}

// Verify the modifications
assert_eq!(arr.get([1, 2, 3]), Some(&1000));
assert_eq!(arr.get([4, 5, 6]), Some(&150));
assert_eq!(arr.get([-1, -2, -3]), Some(&400));

// Verify that get_mut returns None for non-existent coordinates
assert_eq!(arr.get_mut([0, 0, 0]), None);
}

// This is a bit too heavy for miri
#[cfg_attr(not(miri), test)]
fn test_large() {
Expand Down Expand Up @@ -481,23 +545,26 @@ mod tests {
enum Operation<const K: usize> {
Insert([i32; K], i32),
Get([i32; K]),
GetMut([i32; K]),
Modify([i32; K], i32), // Add this value to the existing value (requires get_mut)
}

fn insert_strategy<const K: usize>(max: u32) -> impl Strategy<Value = Operation<K>> {
// Generate a strategy for a single operation (insert, get, get_mut, or modify)
fn operation_strategy<const K: usize>(max: u32) -> impl Strategy<Value = Operation<K>> {
coords_strategy::<K>(max).prop_flat_map(move |coords| {
any::<i32>().prop_map(move |value| Operation::Insert(coords, value))
prop_oneof![
any::<i32>()
.prop_map(move |value| Operation::Insert(coords, value))
.boxed(),
Just(Operation::Get(coords)).boxed(),
Just(Operation::GetMut(coords)).boxed(),
any::<i32>()
.prop_map(move |delta| Operation::Modify(coords, delta))
.boxed(),
]
})
}

fn get_strategy<const K: usize>(max: u32) -> impl Strategy<Value = Operation<K>> {
coords_strategy::<K>(max).prop_map(Operation::Get)
}

// Generate a strategy for a single operation (insert or get)
fn operation_strategy<const K: usize>(max: u32) -> impl Strategy<Value = Operation<K>> {
prop_oneof![insert_strategy(max), get_strategy(max)]
}

// Generate a strategy for vectors of i32 coordinates
fn coords_vec_strategy<const K: usize>(
max_len: usize,
Expand All @@ -515,7 +582,7 @@ mod tests {
}

fn proptest_multiindexed_ops_kd<const K: usize>(ops: Vec<Operation<K>>) {
let arr = MultiIndexed::<K, i32>::new();
let mut arr = MultiIndexed::<K, i32>::new();
let mut reference = HashMap::new();

for op in ops {
Expand All @@ -538,8 +605,31 @@ mod tests {
let expected = reference.get(&coords);
assert_eq!(actual, expected);
}
Operation::GetMut(coords) => {
// Check that get_mut returns the same as our reference HashMap
let actual = arr.get_mut(coords).map(|v| &*v);
let expected = reference.get(&coords);
assert_eq!(actual, expected);
}
Operation::Modify(coords, delta) => {
// Try to modify the value using get_mut
if let Some(value) = arr.get_mut(coords) {
*value = value.wrapping_add(delta);
}
// Also modify in reference
if let Some(value) = reference.get_mut(&coords) {
*value = value.wrapping_add(delta);
}
// Verify they match
assert_eq!(arr.get(coords), reference.get(&coords));
}
}
}

// Final verification: all values should match
for (coords, expected_value) in &reference {
assert_eq!(arr.get(*coords), Some(expected_value));
}
}

fn proptest_multiindexed_iter_kd<const K: usize>(coords: Vec<[i32; K]>) {
Expand Down
36 changes: 36 additions & 0 deletions ext/crates/once/src/multiindexed/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ impl<V> Node<V> {
unsafe { self.inner.get(idx) }
}

/// Retrieves a mutable reference to the child node at the specified index, if it exists.
///
/// # Parameters
///
/// * `idx`: The index of the child node to retrieve
///
/// # Returns
///
/// * `Some(&mut Self)` if a child node exists at the specified index
/// * `None` if no child node exists at the specified index
///
/// # Safety
///
/// Can only be called on an inner node.
pub(super) unsafe fn get_child_mut(&mut self, idx: i32) -> Option<&mut Self> {
unsafe { self.inner.get_mut(idx) }
}

/// Retrieves a reference to the value at the specified index, if it exists.
///
/// # Parameters
Expand All @@ -126,6 +144,24 @@ impl<V> Node<V> {
unsafe { self.leaf.get(idx) }
}

/// Retrieves a mutable reference to the value at the specified index, if it exists.
///
/// # Parameters
///
/// * `idx`: The index of the value to retrieve
///
/// # Returns
///
/// * `Some(&mut V)` if a value exists at the specified index
/// * `None` if no value exists at the specified index
///
/// # Safety
///
/// Can only be called on a leaf node.
pub(super) unsafe fn get_value_mut(&mut self, idx: i32) -> Option<&mut V> {
unsafe { self.leaf.get_mut(idx) }
}

/// Sets the value at the specified index.
///
/// # Parameters
Expand Down