Skip to content

Commit

Permalink
sn-provider: add a wrapper around starknet provider (#1366)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoratger authored Sep 9, 2024
1 parent 70857a5 commit 5d941cd
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 73 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod providers {
pub mod debug_provider;
pub mod eth_provider;
pub mod pool_provider;
pub mod sn_provider;
}
pub mod client;
pub mod config;
Expand Down
43 changes: 7 additions & 36 deletions src/pool/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
use super::validate::KakarotTransactionValidator;
use crate::{
client::EthClient,
into_via_wrapper,
models::felt::Felt252Wrapper,
providers::eth_provider::{
error::ExecutionError,
starknet::{ERC20Reader, STARKNET_NATIVE_TOKEN},
utils::{class_hash_not_declared, contract_not_found},
},
};
use crate::{client::EthClient, providers::sn_provider::StarknetProvider};
use reth_primitives::{BlockId, U256};
use reth_transaction_pool::{
blobstore::NoopBlobStore, CoinbaseTipOrdering, EthPooledTransaction, Pool, TransactionPool,
Expand All @@ -17,7 +8,6 @@ use serde_json::Value;
use starknet::core::types::Felt;
use std::{collections::HashMap, fs::File, io::Read, sync::Arc, time::Duration};
use tokio::{runtime::Handle, sync::Mutex};
use tracing::Instrument;

/// A type alias for the Kakarot Transaction Validator.
/// Uses the Reth implementation [`TransactionValidationTaskExecutor`].
Expand Down Expand Up @@ -106,7 +96,7 @@ impl<SP: starknet::providers::Provider + Send + Sync + Clone + 'static> AccountM
for account_address in account_addresses {
// Fetch the balance and handle errors functionally
let balance = self
.get_balance(&account_address)
.get_balance(account_address)
.await
.inspect_err(|err| {
tracing::error!(
Expand All @@ -132,32 +122,13 @@ impl<SP: starknet::providers::Provider + Send + Sync + Clone + 'static> AccountM
}

/// Retrieves the balance of the specified account address.
async fn get_balance(&self, account_address: &Felt) -> eyre::Result<U256> {
async fn get_balance(&self, account_address: Felt) -> eyre::Result<U256> {
// Convert the optional Ethereum block ID to a Starknet block ID.
let starknet_block_id = self.eth_client.eth_provider().to_starknet_block_id(Some(BlockId::default())).await?;

// Create a new `ERC20Reader` instance for the Starknet native token
let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, self.eth_client.eth_provider().starknet_provider());

// Call the `balanceOf` method on the contract for the given account_address and block ID, awaiting the result
let span = tracing::span!(tracing::Level::INFO, "sn::balance");
let res = eth_contract.balanceOf(account_address).block_id(starknet_block_id).call().instrument(span).await;

if contract_not_found(&res) || class_hash_not_declared(&res) {
return Err(eyre::eyre!("Contract not found or class hash not declared"));
}

// Otherwise, extract the balance from the result, converting any errors to ExecutionError
let balance = res.map_err(ExecutionError::from)?.balance;

// Convert the low and high parts of the balance to U256
let low: U256 = into_via_wrapper!(balance.low);
let high: U256 = into_via_wrapper!(balance.high);

// Combine the low and high parts to form the final balance and return it
let balance = low + (high << 128);

Ok(balance)
// Create a new Starknet provider wrapper.
let starknet_provider = StarknetProvider::new(Arc::new(self.eth_client.eth_provider().starknet_provider()));
// Get the balance of the address at the given block ID.
starknet_provider.balance_at(account_address, starknet_block_id).await.map_err(Into::into)
}

/// Processes a transaction for the given account if the balance is sufficient.
Expand Down
51 changes: 14 additions & 37 deletions src/providers/eth_provider/state.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use std::sync::Arc;

use super::{
database::state::{EthCacheDatabase, EthDatabase},
error::{EthApiError, ExecutionError, TransactionError},
starknet::{
kakarot_core::{account_contract::AccountContractReader, starknet_address},
ERC20Reader, STARKNET_NATIVE_TOKEN,
},
utils::{class_hash_not_declared, contract_not_found, entrypoint_not_found, split_u256},
starknet::kakarot_core::{account_contract::AccountContractReader, starknet_address},
utils::{contract_not_found, entrypoint_not_found, split_u256},
};
use crate::{
into_via_wrapper,
models::felt::Felt252Wrapper,
providers::eth_provider::{
provider::{EthApiResult, EthDataProvider},
BlockProvider, ChainProvider,
providers::{
eth_provider::{
provider::{EthApiResult, EthDataProvider},
BlockProvider, ChainProvider,
},
sn_provider::StarknetProvider,
},
};
use async_trait::async_trait;
Expand Down Expand Up @@ -70,35 +72,10 @@ where
async fn balance(&self, address: Address, block_id: Option<BlockId>) -> EthApiResult<U256> {
// Convert the optional Ethereum block ID to a Starknet block ID.
let starknet_block_id = self.to_starknet_block_id(block_id).await?;

// Create a new `ERC20Reader` instance for the Starknet native token
let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, self.starknet_provider());

// Call the `balanceOf` method on the contract for the given address and block ID, awaiting the result
let span = tracing::span!(tracing::Level::INFO, "sn::balance");
let res = eth_contract
.balanceOf(&starknet_address(address))
.block_id(starknet_block_id)
.call()
.instrument(span)
.await;

// Check if the contract was not found or the class hash not declared,
// returning a default balance of 0 if true.
// The native token contract should be deployed on Kakarot, so this should not happen
// We want to avoid errors in this case and return a default balance of 0
if contract_not_found(&res) || class_hash_not_declared(&res) {
return Ok(Default::default());
}
// Otherwise, extract the balance from the result, converting any errors to ExecutionError
let balance = res.map_err(ExecutionError::from)?.balance;

// Convert the low and high parts of the balance to U256
let low: U256 = into_via_wrapper!(balance.low);
let high: U256 = into_via_wrapper!(balance.high);

// Combine the low and high parts to form the final balance and return it
Ok(low + (high << 128))
// Create a new Starknet provider wrapper.
let starknet_provider = StarknetProvider::new(Arc::new(self.starknet_provider()));
// Get the balance of the address at the given block ID.
starknet_provider.balance_at(starknet_address(address), starknet_block_id).await.map_err(Into::into)
}

async fn storage_at(
Expand Down
3 changes: 3 additions & 0 deletions src/providers/sn_provider/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod starknet_provider;

pub use starknet_provider::StarknetProvider;
63 changes: 63 additions & 0 deletions src/providers/sn_provider/starknet_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::{
into_via_wrapper,
models::felt::Felt252Wrapper,
providers::eth_provider::{
error::ExecutionError,
starknet::{ERC20Reader, STARKNET_NATIVE_TOKEN},
utils::{class_hash_not_declared, contract_not_found},
},
};
use reth_primitives::U256;
use starknet::core::types::{BlockId, Felt};
use std::sync::Arc;
use tracing::Instrument;

/// A provider wrapper around the Starknet provider to expose utility methods.
#[derive(Debug, Clone)]
pub struct StarknetProvider<SP: starknet::providers::Provider + Send + Sync> {
/// The underlying Starknet provider wrapped in an [`Arc`] for shared ownership across threads.
provider: Arc<SP>,
}

impl<SP> StarknetProvider<SP>
where
SP: starknet::providers::Provider + Send + Sync,
{
/// Creates a new [`StarknetProvider`] instance from an [`Arc`]-wrapped Starknet provider.
pub const fn new(provider: Arc<SP>) -> Self {
Self { provider }
}

/// Retrieves the balance of a Starknet address for a specified block.
///
/// This method interacts with the Starknet native token contract to query the balance of the given
/// address at a specific block.
///
/// If the contract is not deployed or the class hash is not declared, a balance of 0 is returned
/// instead of an error.
pub async fn balance_at(&self, address: Felt, block_id: BlockId) -> Result<U256, ExecutionError> {
// Create a new `ERC20Reader` instance for the Starknet native token
let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, &self.provider);

// Call the `balanceOf` method on the contract for the given address and block ID, awaiting the result
let span = tracing::span!(tracing::Level::INFO, "sn::balance");
let res = eth_contract.balanceOf(&address).block_id(block_id).call().instrument(span).await;

// Check if the contract was not found or the class hash not declared,
// returning a default balance of 0 if true.
// The native token contract should be deployed on Kakarot, so this should not happen
// We want to avoid errors in this case and return a default balance of 0
if contract_not_found(&res) || class_hash_not_declared(&res) {
return Ok(Default::default());
}
// Otherwise, extract the balance from the result, converting any errors to ExecutionError
let balance = res.map_err(ExecutionError::from)?.balance;

// Convert the low and high parts of the balance to U256
let low: U256 = into_via_wrapper!(balance.low);
let high: U256 = into_via_wrapper!(balance.high);

// Combine the low and high parts to form the final balance and return it
Ok(low + (high << 128))
}
}

0 comments on commit 5d941cd

Please sign in to comment.