diff --git a/crates/trie/trie/src/forward_cursor.rs b/crates/trie/trie/src/forward_cursor.rs index b1b6c041289..c99b0d049ee 100644 --- a/crates/trie/trie/src/forward_cursor.rs +++ b/crates/trie/trie/src/forward_cursor.rs @@ -23,8 +23,9 @@ impl<'a, K, V> ForwardInMemoryCursor<'a, K, V> { self.is_empty } + /// Returns the current entry pointed to be the cursor, or `None` if no entries are left. #[inline] - fn peek(&self) -> Option<&(K, V)> { + pub fn current(&self) -> Option<&(K, V)> { self.entries.clone().next() } @@ -59,7 +60,7 @@ where fn advance_while(&mut self, predicate: impl Fn(&K) -> bool) -> Option<(K, V)> { let mut entry; loop { - entry = self.peek(); + entry = self.current(); if entry.is_some_and(|(k, _)| predicate(k)) { self.next(); } else { @@ -77,20 +78,21 @@ mod tests { #[test] fn test_cursor() { let mut cursor = ForwardInMemoryCursor::new(&[(1, ()), (2, ()), (3, ()), (4, ()), (5, ())]); + assert_eq!(cursor.current(), Some(&(1, ()))); assert_eq!(cursor.seek(&0), Some((1, ()))); - assert_eq!(cursor.peek(), Some(&(1, ()))); + assert_eq!(cursor.current(), Some(&(1, ()))); assert_eq!(cursor.seek(&3), Some((3, ()))); - assert_eq!(cursor.peek(), Some(&(3, ()))); + assert_eq!(cursor.current(), Some(&(3, ()))); assert_eq!(cursor.seek(&3), Some((3, ()))); - assert_eq!(cursor.peek(), Some(&(3, ()))); + assert_eq!(cursor.current(), Some(&(3, ()))); assert_eq!(cursor.seek(&4), Some((4, ()))); - assert_eq!(cursor.peek(), Some(&(4, ()))); + assert_eq!(cursor.current(), Some(&(4, ()))); assert_eq!(cursor.seek(&6), None); - assert_eq!(cursor.peek(), None); + assert_eq!(cursor.current(), None); } } diff --git a/crates/trie/trie/src/trie_cursor/in_memory.rs b/crates/trie/trie/src/trie_cursor/in_memory.rs index e76bf7b2be3..d9658150f3a 100644 --- a/crates/trie/trie/src/trie_cursor/in_memory.rs +++ b/crates/trie/trie/src/trie_cursor/in_memory.rs @@ -69,10 +69,15 @@ where pub struct InMemoryTrieCursor<'a, C> { /// The underlying cursor. If None then it is assumed there is no DB data. cursor: Option, + /// Entry that `cursor` is currently pointing to. + cursor_entry: Option<(Nibbles, BranchNodeCompact)>, /// Forward-only in-memory cursor over storage trie nodes. in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option>, - /// Last key returned by the cursor. + /// The key most recently returned from the Cursor. last_key: Option, + #[cfg(debug_assertions)] + /// Whether an initial seek was called. + seeked: bool, } impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { @@ -83,47 +88,84 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { trie_updates: &'a [(Nibbles, Option)], ) -> Self { let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates); - Self { cursor, in_memory_cursor, last_key: None } + Self { + cursor, + cursor_entry: None, + in_memory_cursor, + last_key: None, + #[cfg(debug_assertions)] + seeked: false, + } } - fn seek_inner( - &mut self, - key: Nibbles, - exact: bool, - ) -> Result, DatabaseError> { - let mut mem_entry = self.in_memory_cursor.seek(&key); - let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten(); - - // exact matching is easy, if overlay has a value then return that (updated or removed), or - // if db has a value then return that. - if exact { - return Ok(match (mem_entry, db_entry) { - (Some((mem_key, entry_inner)), _) if mem_key == key => { - entry_inner.map(|node| (key, node)) - } - (_, Some((db_key, node))) if db_key == key => Some((key, node)), - _ => None, - }) + /// Asserts that the next entry to be returned from the cursor is not previous to the last entry + /// returned. + fn set_last_key(&mut self, next_entry: &Option<(Nibbles, BranchNodeCompact)>) { + let next_key = next_entry.as_ref().map(|e| e.0); + debug_assert!( + self.last_key.is_none_or(|last| next_key.is_none_or(|next| next >= last)), + "Cannot return entry {:?} previous to the last returned entry at {:?}", + next_key, + self.last_key, + ); + self.last_key = next_key; + } + + /// Seeks the `cursor_entry` field of the struct using the cursor. + fn cursor_seek(&mut self, key: Nibbles) -> Result<(), DatabaseError> { + if let Some(entry) = self.cursor_entry.as_ref() && + entry.0 >= key + { + // If already seeked to the given key then don't do anything. Also if we're seeked past + // the given key then don't anything, because `TrieCursor` is specifically a + // forward-only cursor. + } else { + self.cursor_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten(); + } + + Ok(()) + } + + /// Seeks the `cursor_entry` field of the struct to the subsequent entry using the cursor. + fn cursor_next(&mut self) -> Result<(), DatabaseError> { + #[cfg(debug_assertions)] + { + debug_assert!(self.seeked); + } + + // If the previous entry is `None`, and we've done a seek previously, then the cursor is + // exhausted and we shouldn't call `next` again. + if self.cursor_entry.is_some() { + self.cursor_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten(); } + Ok(()) + } + + /// Compares the current in-memory entry with the current entry of the cursor, and applies the + /// in-memory entry to the cursor entry as an overlay. + // + /// This may consume and move forward the current entries when the overlay indicates a removed + /// node. + fn choose_next_entry(&mut self) -> Result, DatabaseError> { loop { - match (mem_entry, &db_entry) { + match (self.in_memory_cursor.current().cloned(), &self.cursor_entry) { (Some((mem_key, None)), _) - if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) => + if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) => { // If overlay has a removed node but DB cursor is exhausted or ahead of the // in-memory cursor then move ahead in-memory, as there might be further // non-removed overlay nodes. - mem_entry = self.in_memory_cursor.first_after(&mem_key); + self.in_memory_cursor.first_after(&mem_key); } (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => { // If overlay has a removed node which is returned from DB then move both // cursors ahead to the next key. - mem_entry = self.in_memory_cursor.first_after(&mem_key); - db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten(); + self.in_memory_cursor.first_after(&mem_key); + self.cursor_next()?; } (Some((mem_key, Some(node))), _) - if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) => + if self.cursor_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) => { // If overlay returns a node prior to the DB's node, or the DB is exhausted, // then we return the overlay's node. @@ -133,18 +175,10 @@ impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { // - mem_key > db_key // - overlay is exhausted // Return the db_entry. If DB is also exhausted then this returns None. - _ => return Ok(db_entry), + _ => return Ok(self.cursor_entry.clone()), } } } - - fn next_inner( - &mut self, - last: Nibbles, - ) -> Result, DatabaseError> { - let Some(key) = last.increment() else { return Ok(None) }; - self.seek_inner(key, false) - } } impl TrieCursor for InMemoryTrieCursor<'_, C> { @@ -152,8 +186,23 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { &mut self, key: Nibbles, ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, true)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); + self.cursor_seek(key)?; + let mem_entry = self.in_memory_cursor.seek(&key); + + #[cfg(debug_assertions)] + { + self.seeked = true; + } + + let entry = match (mem_entry, &self.cursor_entry) { + (Some((mem_key, entry_inner)), _) if mem_key == key => { + entry_inner.map(|node| (key, node)) + } + (_, Some((db_key, node))) if db_key == &key => Some((key, node.clone())), + _ => None, + }; + + self.set_last_key(&entry); Ok(entry) } @@ -161,22 +210,47 @@ impl TrieCursor for InMemoryTrieCursor<'_, C> { &mut self, key: Nibbles, ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, false)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); + self.cursor_seek(key)?; + self.in_memory_cursor.seek(&key); + + #[cfg(debug_assertions)] + { + self.seeked = true; + } + + let entry = self.choose_next_entry()?; + self.set_last_key(&entry); Ok(entry) } fn next(&mut self) -> Result, DatabaseError> { - let next = match &self.last_key { - Some(last) => { - let entry = self.next_inner(*last)?; - self.last_key = entry.as_ref().map(|entry| entry.0); - entry - } - // no previous entry was found - None => None, + #[cfg(debug_assertions)] + { + debug_assert!(self.seeked, "Cursor must be seek'd before next is called"); + } + + // A `last_key` of `None` indicates that the cursor is exhausted. + let Some(last_key) = self.last_key else { + return Ok(None); }; - Ok(next) + + // If either cursor is currently pointing to the last entry which was returned then consume + // that entry so that `choose_next_entry` is looking at the subsequent one. + if let Some((key, _)) = self.in_memory_cursor.current() && + key == &last_key + { + self.in_memory_cursor.first_after(&last_key); + } + + if let Some((key, _)) = &self.cursor_entry && + key == &last_key + { + self.cursor_next()?; + } + + let entry = self.choose_next_entry()?; + self.set_last_key(&entry); + Ok(entry) } fn current(&mut self) -> Result, DatabaseError> { @@ -218,8 +292,10 @@ mod tests { results.push(entry); } - while let Ok(Some(entry)) = cursor.next() { - results.push(entry); + if !test_case.expected_results.is_empty() { + while let Ok(Some(entry)) = cursor.next() { + results.push(entry); + } } assert_eq!( @@ -501,4 +577,238 @@ mod tests { cursor.next().unwrap(); assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3]))); } + + mod proptest_tests { + use super::*; + use itertools::Itertools; + use proptest::prelude::*; + + /// Merge `db_nodes` with `in_memory_nodes`, applying the in-memory overlay. + /// This properly handles deletions (None values in `in_memory_nodes`). + fn merge_with_overlay( + db_nodes: Vec<(Nibbles, BranchNodeCompact)>, + in_memory_nodes: Vec<(Nibbles, Option)>, + ) -> Vec<(Nibbles, BranchNodeCompact)> { + db_nodes + .into_iter() + .merge_join_by(in_memory_nodes, |db_entry, mem_entry| db_entry.0.cmp(&mem_entry.0)) + .filter_map(|entry| match entry { + // Only in db: keep it + itertools::EitherOrBoth::Left((key, node)) => Some((key, node)), + // Only in memory: keep if not a deletion + itertools::EitherOrBoth::Right((key, node_opt)) => { + node_opt.map(|node| (key, node)) + } + // In both: memory takes precedence (keep if not a deletion) + itertools::EitherOrBoth::Both(_, (key, node_opt)) => { + node_opt.map(|node| (key, node)) + } + }) + .collect() + } + + /// Generate a strategy for a `BranchNodeCompact` with simplified parameters. + /// The constraints are: + /// - `tree_mask` must be a subset of `state_mask` + /// - `hash_mask` must be a subset of `state_mask` + /// - `hash_mask.count_ones()` must equal `hashes.len()` + /// + /// To keep it simple, we use an empty hashes vec and `hash_mask` of 0. + fn branch_node_strategy() -> impl Strategy { + any::() + .prop_flat_map(|state_mask| { + let tree_mask_strategy = any::().prop_map(move |tree| tree & state_mask); + (Just(state_mask), tree_mask_strategy) + }) + .prop_map(|(state_mask, tree_mask)| { + BranchNodeCompact::new(state_mask, tree_mask, 0, vec![], None) + }) + } + + /// Generate a sorted vector of (Nibbles, `BranchNodeCompact`) entries + fn sorted_db_nodes_strategy() -> impl Strategy> { + prop::collection::vec( + (prop::collection::vec(any::(), 0..3), branch_node_strategy()), + 0..20, + ) + .prop_map(|entries| { + // Convert Vec to Nibbles and sort + let mut result: Vec<(Nibbles, BranchNodeCompact)> = entries + .into_iter() + .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node)) + .collect(); + result.sort_by(|a, b| a.0.cmp(&b.0)); + result.dedup_by(|a, b| a.0 == b.0); + result + }) + } + + /// Generate a sorted vector of (Nibbles, Option) entries + fn sorted_in_memory_nodes_strategy( + ) -> impl Strategy)>> { + prop::collection::vec( + ( + prop::collection::vec(any::(), 0..3), + prop::option::of(branch_node_strategy()), + ), + 0..20, + ) + .prop_map(|entries| { + // Convert Vec to Nibbles and sort + let mut result: Vec<(Nibbles, Option)> = entries + .into_iter() + .map(|(bytes, node)| (Nibbles::from_nibbles_unchecked(bytes), node)) + .collect(); + result.sort_by(|a, b| a.0.cmp(&b.0)); + result.dedup_by(|a, b| a.0 == b.0); + result + }) + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(1000))] + + #[test] + fn proptest_in_memory_trie_cursor( + db_nodes in sorted_db_nodes_strategy(), + in_memory_nodes in sorted_in_memory_nodes_strategy(), + op_choices in prop::collection::vec(any::(), 10..500), + ) { + reth_tracing::init_test_tracing(); + use tracing::debug; + + debug!("Starting proptest!"); + + // Create the expected results by merging the two sorted vectors, + // properly handling deletions (None values in in_memory_nodes) + let expected_combined = merge_with_overlay(db_nodes.clone(), in_memory_nodes.clone()); + + // Collect all keys for operation generation + let all_keys: Vec = expected_combined.iter().map(|(k, _)| *k).collect(); + + // Create a control cursor using the combined result with a mock cursor + let control_db_map: BTreeMap = + expected_combined.into_iter().collect(); + let control_db_arc = Arc::new(control_db_map); + let control_visited_keys = Arc::new(Mutex::new(Vec::new())); + let mut control_cursor = MockTrieCursor::new(control_db_arc, control_visited_keys); + + // Create the InMemoryTrieCursor being tested + let db_nodes_map: BTreeMap = + db_nodes.into_iter().collect(); + let db_nodes_arc = Arc::new(db_nodes_map); + let visited_keys = Arc::new(Mutex::new(Vec::new())); + let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys); + let mut test_cursor = InMemoryTrieCursor::new(Some(mock_cursor), &in_memory_nodes); + + // Test: seek to the beginning first + let control_first = control_cursor.seek(Nibbles::default()).unwrap(); + let test_first = test_cursor.seek(Nibbles::default()).unwrap(); + debug!( + control=?control_first.as_ref().map(|(k, _)| k), + test=?test_first.as_ref().map(|(k, _)| k), + "Initial seek returned", + ); + assert_eq!(control_first, test_first, "Initial seek mismatch"); + + // If both cursors returned None, nothing to test + if control_first.is_none() && test_first.is_none() { + return Ok(()); + } + + // Track the last key returned from the cursor + let mut last_returned_key = control_first.as_ref().map(|(k, _)| *k); + + // Execute a sequence of random operations + for choice in op_choices { + let op_type = choice % 3; + + match op_type { + 0 => { + // Next operation + let control_result = control_cursor.next().unwrap(); + let test_result = test_cursor.next().unwrap(); + debug!( + control=?control_result.as_ref().map(|(k, _)| k), + test=?test_result.as_ref().map(|(k, _)| k), + "Next returned", + ); + assert_eq!(control_result, test_result, "Next operation mismatch"); + + last_returned_key = control_result.as_ref().map(|(k, _)| *k); + + // Stop if both cursors are exhausted + if control_result.is_none() && test_result.is_none() { + break; + } + } + 1 => { + // Seek operation - choose a key >= last_returned_key + if all_keys.is_empty() { + continue; + } + + let valid_keys: Vec<_> = all_keys + .iter() + .filter(|k| last_returned_key.is_none_or(|last| **k >= last)) + .collect(); + + if valid_keys.is_empty() { + continue; + } + + let key = *valid_keys[(choice as usize / 3) % valid_keys.len()]; + + let control_result = control_cursor.seek(key).unwrap(); + let test_result = test_cursor.seek(key).unwrap(); + debug!( + control=?control_result.as_ref().map(|(k, _)| k), + test=?test_result.as_ref().map(|(k, _)| k), + ?key, + "Seek returned", + ); + assert_eq!(control_result, test_result, "Seek operation mismatch for key {:?}", key); + + last_returned_key = control_result.as_ref().map(|(k, _)| *k); + + // Stop if both cursors are exhausted + if control_result.is_none() && test_result.is_none() { + break; + } + } + _ => { + // SeekExact operation - choose a key >= last_returned_key + if all_keys.is_empty() { + continue; + } + + let valid_keys: Vec<_> = all_keys + .iter() + .filter(|k| last_returned_key.is_none_or(|last| **k >= last)) + .collect(); + + if valid_keys.is_empty() { + continue; + } + + let key = *valid_keys[(choice as usize / 3) % valid_keys.len()]; + + let control_result = control_cursor.seek_exact(key).unwrap(); + let test_result = test_cursor.seek_exact(key).unwrap(); + debug!( + control=?control_result.as_ref().map(|(k, _)| k), + test=?test_result.as_ref().map(|(k, _)| k), + ?key, + "SeekExact returned", + ); + assert_eq!(control_result, test_result, "SeekExact operation mismatch for key {:?}", key); + + // seek_exact updates the last_key internally but only if it found something + last_returned_key = control_result.as_ref().map(|(k, _)| *k); + } + } + } + } + } + } }