Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(trie): blinded provider integration #13066

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 129 additions & 19 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
//! State root task related functionality.

use alloy_primitives::map::{HashMap, HashSet};
use alloy_primitives::{
map::{HashMap, HashSet},
Bytes,
};
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
};
use reth_trie::{
proof::Proof, updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof, Nibbles,
TrieInput,
hashed_cursor::HashedPostStateCursorFactory,
proof::{Proof, StorageProof},
trie_cursor::InMemoryTrieCursorFactory,
updates::TrieUpdates,
HashedPostState, HashedStorage, MultiProof, Nibbles, TrieInput,
};
use reth_trie_db::DatabaseProof;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_sparse::{SparseStateTrie, SparseStateTrieResult, SparseTrieError};
use reth_trie_sparse::{
blinded::{pad_path_to_key, BlindedProvider, BlindedProviderFactory},
SparseStateTrie, SparseStateTrieResult, SparseTrieError,
};
use revm_primitives::{keccak256, EvmState, B256};
use std::{
collections::BTreeMap,
Expand Down Expand Up @@ -50,18 +59,112 @@ impl StateRootHandle {
}

/// Common configuration for state root tasks
#[derive(Debug)]
#[derive(Clone, Debug)]
pub(crate) struct StateRootConfig<Factory> {
/// View over the state in the database.
pub consistent_view: ConsistentDbView<Factory>,
/// Latest trie input.
pub input: Arc<TrieInput>,
}

impl<Factory> BlindedProviderFactory for StateRootConfig<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone,
{
type AccountNodeProvider = BlindedAccountNodeProvider<Factory>;
type StorageNodeProvider = BlindedStorageNodeProvider<Factory>;

fn account_node_provider(&self) -> Self::AccountNodeProvider {
BlindedAccountNodeProvider { view: self.consistent_view.clone(), input: self.input.clone() }
}

fn storage_node_provider(&self, hashed_address: B256) -> Self::StorageNodeProvider {
BlindedStorageNodeProvider {
view: self.consistent_view.clone(),
input: self.input.clone(),
hashed_address,
}
}
}

pub(crate) struct BlindedAccountNodeProvider<Factory> {
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
}

impl<Factory> BlindedProvider for BlindedAccountNodeProvider<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
type Error = SparseTrieError;

fn blinded_node(&mut self, path: Nibbles) -> Result<Option<Bytes>, Self::Error> {
let provider = self.view.provider_ro().unwrap();
let targets = HashMap::from_iter([(pad_path_to_key(&path), Default::default())]);
let proof = Proof::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&self.input.nodes.clone().into_sorted(),
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&self.input.state.clone().into_sorted(),
),
)
.with_prefix_sets_mut(self.input.prefix_sets.clone())
.multiproof(targets)
.unwrap(); // TODO:

Ok(proof.account_subtree.into_inner().remove(&path))
}
}

pub(crate) struct BlindedStorageNodeProvider<Factory> {
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
hashed_address: B256,
}

impl<Factory> BlindedProvider for BlindedStorageNodeProvider<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
type Error = SparseTrieError;

fn blinded_node(&mut self, path: Nibbles) -> Result<Option<Bytes>, Self::Error> {
let provider = self.view.provider_ro().unwrap();
let targets = HashSet::from_iter([pad_path_to_key(&path)]);
let storage_prefix_set = self
.input
.prefix_sets
.storage_prefix_sets
.get(&self.hashed_address)
.cloned()
.unwrap_or_default();
let proof = StorageProof::new_hashed(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&self.input.nodes.clone().into_sorted(),
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&self.input.state.clone().into_sorted(),
),
self.hashed_address,
)
.with_prefix_set_mut(storage_prefix_set)
.storage_multiproof(targets)
.unwrap();

// The subtree only contains the proof for a single target.
Ok(proof.subtree.into_inner().remove(&path))
}
}

/// Messages used internally by the state root task
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum StateRootMessage {
pub(crate) enum StateRootMessage<P: BlindedProviderFactory> {
/// New state update from transaction execution
StateUpdate(EvmState),
/// Proof calculation completed for a specific state update
Expand All @@ -76,7 +179,7 @@ pub(crate) enum StateRootMessage {
/// State root calculation completed
RootCalculated {
/// The updated sparse trie
trie: Box<SparseStateTrie>,
trie: Box<SparseStateTrie<P>>,
/// Time taken to calculate the root
elapsed: Duration,
},
Expand Down Expand Up @@ -157,20 +260,23 @@ impl ProofSequencer {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub(crate) struct StateRootTask<Factory> {
pub(crate) struct StateRootTask<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone,
{
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
rx: Receiver<StateRootMessage>,
rx: Receiver<StateRootMessage<StateRootConfig<Factory>>>,
/// Sender for state root related messages.
tx: Sender<StateRootMessage>,
tx: Sender<StateRootMessage<StateRootConfig<Factory>>>,
/// Proof targets that have been already fetched.
fetched_proof_targets: HashMap<B256, HashSet<B256>>,
/// Proof sequencing handler.
proof_sequencer: ProofSequencer,
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<Box<SparseStateTrie>>,
sparse_trie: Option<Box<SparseStateTrie<StateRootConfig<Factory>>>>,
}

#[allow(dead_code)]
Expand All @@ -181,16 +287,17 @@ where
/// Creates a new state root task with the unified message channel
pub(crate) fn new(
config: StateRootConfig<Factory>,
tx: Sender<StateRootMessage>,
rx: Receiver<StateRootMessage>,
tx: Sender<StateRootMessage<StateRootConfig<Factory>>>,
rx: Receiver<StateRootMessage<StateRootConfig<Factory>>>,
) -> Self {
let sparse_trie = SparseStateTrie::new(config.clone()).with_updates(true);
Self {
config,
rx,
tx,
fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(),
sparse_trie: Some(Box::new(SparseStateTrie::default().with_updates(true))),
sparse_trie: Some(Box::new(sparse_trie)),
}
}

Expand Down Expand Up @@ -218,7 +325,7 @@ where
update: EvmState,
fetched_proof_targets: &HashMap<B256, HashSet<B256>>,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
state_root_message_sender: Sender<StateRootMessage<StateRootConfig<Factory>>>,
) -> HashMap<B256, HashSet<B256>> {
let mut hashed_state_update = HashedPostState::default();
for (address, account) in update {
Expand Down Expand Up @@ -493,12 +600,15 @@ fn get_proof_targets(

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie(
mut trie: Box<SparseStateTrie>,
fn update_sparse_trie<Factory>(
mut trie: Box<SparseStateTrie<StateRootConfig<Factory>>>,
multiproof: MultiProof,
targets: HashMap<B256, HashSet<B256>>,
state: HashedPostState,
) -> SparseStateTrieResult<(Box<SparseStateTrie>, Duration)> {
) -> SparseStateTrieResult<(Box<SparseStateTrie<StateRootConfig<Factory>>>, Duration)>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone,
{
trace!(target: "engine::root::sparse", "Updating sparse trie");
let started_at = Instant::now();

Expand Down
8 changes: 8 additions & 0 deletions crates/trie/sparse/src/blinded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,11 @@ impl BlindedProvider for DefaultBlindedProvider {
Ok(None)
}
}

/// Right pad the path with 0s and return as [`B256`].
#[inline]
pub fn pad_path_to_key(path: &Nibbles) -> B256 {
let mut padded = path.pack();
padded.resize(32, 0);
B256::from_slice(&padded)
}
30 changes: 27 additions & 3 deletions crates/trie/sparse/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
RevealedSparseTrie, SparseStateTrieError, SparseStateTrieResult, SparseTrie, SparseTrieError,
};
use alloy_primitives::{
hex,
map::{HashMap, HashSet},
Bytes, B256,
};
Expand All @@ -13,10 +14,9 @@ use reth_trie_common::{
updates::{StorageTrieUpdates, TrieUpdates},
MultiProof, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use std::iter::Peekable;
use std::{fmt, iter::Peekable};

/// Sparse state trie representing lazy-loaded Ethereum state trie.
#[derive(Debug)]
pub struct SparseStateTrie<F: BlindedProviderFactory = DefaultBlindedProviderFactory> {
/// Blinded node provider factory.
provider_factory: F,
Expand All @@ -32,13 +32,25 @@ pub struct SparseStateTrie<F: BlindedProviderFactory = DefaultBlindedProviderFac
account_rlp_buf: Vec<u8>,
}

impl<F: BlindedProviderFactory> fmt::Debug for SparseStateTrie<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SparseStateTrie")
.field("state", &self.state)
.field("storages", &self.storages)
.field("revealed", &self.revealed)
.field("retain_updates", &self.retain_updates)
.field("account_rlp_buf", &hex::encode(&self.account_rlp_buf))
.finish_non_exhaustive()
}
}

impl Default for SparseStateTrie {
fn default() -> Self {
Self {
provider_factory: Default::default(),
state: Default::default(),
storages: Default::default(),
revealed: Default::default(),
provider_factory: Default::default(),
retain_updates: false,
account_rlp_buf: Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE),
}
Expand All @@ -53,6 +65,18 @@ impl SparseStateTrie {
}

impl<F: BlindedProviderFactory> SparseStateTrie<F> {
/// Create new [`SparseStateTrie`] with blinded node provider factory.
pub fn new(provider_factory: F) -> Self {
Self {
provider_factory,
state: Default::default(),
storages: Default::default(),
revealed: Default::default(),
retain_updates: false,
account_rlp_buf: Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE),
}
}

/// Set the retention of branch node updates and deletions.
pub const fn with_updates(mut self, retain_updates: bool) -> Self {
self.retain_updates = retain_updates;
Expand Down
11 changes: 10 additions & 1 deletion crates/trie/sparse/src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,23 @@ use std::{borrow::Cow, fmt};

/// Inner representation of the sparse trie.
/// Sparse trie is blind by default until nodes are revealed.
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq)]
pub enum SparseTrie<P = DefaultBlindedProvider> {
/// None of the trie nodes are known.
Blind,
/// The trie nodes have been revealed.
Revealed(Box<RevealedSparseTrie<P>>),
}

impl<P> fmt::Debug for SparseTrie<P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Blind => write!(f, "Blind"),
Self::Revealed(revealed) => write!(f, "Revealed({revealed:?})"),
}
}
}

impl<P> Default for SparseTrie<P> {
fn default() -> Self {
Self::Blind
Expand Down
Loading