From 0cd1626a4dca3769ddbd4879182db6af34f9848c Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Tue, 19 Jan 2021 08:23:58 -0700 Subject: [PATCH] WIP: Create an in-memory wallet backend. --- Cargo.toml | 2 +- rust-toolchain.toml | 2 +- zcash_client_backend/src/data_api.rs | 17 +- .../src/data_api/mem_wallet.rs | 444 ++++++++++++++++++ 4 files changed, 456 insertions(+), 9 deletions(-) create mode 100644 zcash_client_backend/src/data_api/mem_wallet.rs diff --git a/Cargo.toml b/Cargo.toml index c129eb377d..216d819fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ [workspace.package] edition = "2021" -rust-version = "1.65" +rust-version = "1.69" repository = "https://github.com/zcash/librustzcash" license = "MIT OR Apache-2.0" categories = ["cryptography::cryptocurrencies"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 5ecda6e495..190bd174ae 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.65.0" +channel = "1.69.0" components = [ "clippy", "rustfmt" ] diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index 16e0f757f2..9d92a904ac 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -44,6 +44,9 @@ pub mod error; pub mod scanning; pub mod wallet; +#[cfg(any(test, feature = "test-dependencies"))] +pub mod mem_wallet; + /// The height of subtree roots in the Sapling note commitment tree. /// /// This conforms to the structure of subtree data returned by @@ -1247,13 +1250,6 @@ pub mod testing { Ok(None) } - fn get_target_and_anchor_heights( - &self, - _min_confirmations: NonZeroU32, - ) -> Result, Self::Error> { - Ok(None) - } - fn block_metadata( &self, _height: BlockHeight, @@ -1273,6 +1269,13 @@ pub mod testing { Ok(vec![]) } + fn get_target_and_anchor_heights( + &self, + _min_confirmations: NonZeroU32, + ) -> Result, Self::Error> { + Ok(None) + } + fn get_min_unspent_height(&self) -> Result, Self::Error> { Ok(None) } diff --git a/zcash_client_backend/src/data_api/mem_wallet.rs b/zcash_client_backend/src/data_api/mem_wallet.rs new file mode 100644 index 0000000000..7678ca9740 --- /dev/null +++ b/zcash_client_backend/src/data_api/mem_wallet.rs @@ -0,0 +1,444 @@ +use incrementalmerkletree::Address; +use secrecy::{ExposeSecret, SecretVec}; +use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree}; +use zip32::DiversifierIndex; +use std::{ + cmp::Ordering, + collections::{BTreeMap, HashMap}, + convert::Infallible, + num::NonZeroU32, +}; + +use zcash_primitives::{ + block::BlockHash, + consensus::{BlockHeight, Network}, + memo::Memo, + transaction::{components::Amount, Transaction, TxId}, + zip32::{AccountId, Scope}, +}; + +use crate::{ + address::UnifiedAddress, + keys::{UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey}, + wallet::{Note, NoteId, ReceivedNote, WalletTransparentOutput, WalletTx}, + ShieldedProtocol, +}; + +use super::{ + chain::CommitmentTreeRoot, error::Error, scanning::ScanRange, AccountBirthday, BlockMetadata, + DecryptedTransaction, InputSource, NullifierQuery, ScannedBlock, SentTransaction, + WalletCommitmentTrees, WalletRead, WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT, +}; + +#[cfg(feature = "transparent-inputs")] +use {crate::wallet::TransparentAddressMetadata, zcash_primitives::legacy::TransparentAddress}; + +#[cfg(feature = "orchard")] +use super::ORCHARD_SHARD_HEIGHT; + +struct MemoryWalletBlock { + height: BlockHeight, + hash: BlockHash, + block_time: u32, + // Just the transactions that involve an account in this wallet + transactions: HashMap, +} + +impl PartialEq for MemoryWalletBlock { + fn eq(&self, other: &Self) -> bool { + (self.height, self.block_time) == (other.height, other.block_time) + } +} + +impl Eq for MemoryWalletBlock {} + +impl PartialOrd for MemoryWalletBlock { + fn partial_cmp(&self, other: &Self) -> Option { + Some((self.height, self.block_time).cmp(&(other.height, other.block_time))) + } +} + +impl Ord for MemoryWalletBlock { + fn cmp(&self, other: &Self) -> Ordering { + (self.height, self.block_time).cmp(&(other.height, other.block_time)) + } +} + +pub struct MemoryWalletAccount { + account_id: AccountId, + ufvk: UnifiedFullViewingKey, + birthday: AccountBirthday, + addresses: BTreeMap, +} + +pub struct MemoryWalletDb { + network: Network, + blocks: BTreeMap, + tx_idx: HashMap, + accounts: BTreeMap, + sapling_spends: HashMap, + #[cfg(feature = "orchard")] + orchard_spends: HashMap, + sapling_tree: ShardTree< + MemoryShardStore, + { SAPLING_SHARD_HEIGHT * 2 }, + SAPLING_SHARD_HEIGHT, + >, + #[cfg(feature = "orchard")] + orchard_tree: ShardTree< + MemoryShardStore, + { ORCHARD_SHARD_HEIGHT * 2 }, + ORCHARD_SHARD_HEIGHT, + >, +} + +pub enum MemoryWalletError { +} + +impl WalletRead for MemoryWalletDb { + type Error = MemoryWalletError; + + fn chain_height(&self) -> Result, Self::Error> { + Ok(None) + } + + fn block_metadata(&self, _height: BlockHeight) -> Result, Self::Error> { + Ok(None) + } + + fn block_fully_scanned(&self) -> Result, Self::Error> { + Ok(None) + } + + fn block_max_scanned(&self) -> Result, Self::Error> { + Ok(None) + } + + fn suggest_scan_ranges(&self) -> Result, Self::Error> { + Ok(vec![]) + } + + fn get_target_and_anchor_heights( + &self, + _min_confirmations: NonZeroU32, + ) -> Result, Self::Error> { + Ok(None) + } + + fn get_min_unspent_height(&self) -> Result, Self::Error> { + Ok(None) + } + + fn get_block_hash(&self, block_height: BlockHeight) -> Result, Self::Error> { + Ok(self.blocks.iter().find_map(|b| { + if b.height == block_height { + Some(b.hash) + } else { + None + } + })) + } + + fn get_max_height_hash(&self) -> Result, Self::Error> { + Ok(None) + } + + fn get_tx_height(&self, _txid: TxId) -> Result, Self::Error> { + Ok(None) + } + + fn get_wallet_birthday(&self) -> Result, Self::Error> { + Ok(None) + } + + fn get_account_birthday(&self, _account: AccountId) -> Result { + Err(()) + } + + fn get_current_address( + &self, + account: AccountId, + ) -> Result, Self::Error> { + Ok(self + .accounts + .get(&account) + .map(|ufvk| { + ufvk.default_address(UnifiedAddressRequest::unsafe_new(true, true, true)) + .map(|(_, a)| a) + })) + } + + fn get_unified_full_viewing_keys( + &self, + ) -> Result, Self::Error> { + Ok(HashMap::new()) + } + + fn get_account_for_ufvk( + &self, + ufvk: &UnifiedFullViewingKey, + ) -> Result, Self::Error> { + Ok(self + .accounts + .iter() + .filter_map(|(id, ufvk0)| if ufvk0 == ufvk { Some(id) } else { None }) + .next()) + } + + fn get_wallet_summary( + &self, + _min_confirmations: u32, + ) -> Result, Self::Error> { + Ok(None) + } + // fn get_balance_at( + // &self, + // account: AccountId, + // height: BlockHeight, + // ) -> Result { + // let mut received_amounts: HashMap = HashMap::new(); + // Ok(self.blocks.iter().filter(|b| b.height <= height).fold( + // Amount::zero(), + // |acc, block| { + // block.transactions.values().fold(acc, |acc, wallet_tx| { + // // add to our balance when we receive an output + // let total_received = wallet_tx + // .shielded_outputs + // .iter() + // .filter(|s| s.account == account) + // .fold(acc, |acc, o| { + // let nf = o.note.nf( + // &self.accounts.get(&account).unwrap().fvk.vk, + // o.witness.position() as u64, + // ); + // let amount = Amount::from_u64(o.note.value).unwrap(); + // + // // cache received amounts + // received_amounts.insert(nf, amount); + // acc + amount + // }); + // + // // subtract the previously cached received amount when we observe + // // a spend of its nullifier + // wallet_tx + // .shielded_spends + // .iter() + // .filter(|s| { + // self.spentness + // .get(&s.nf) + // .filter(|(_, spent)| *spent) + // .is_some() + // }) + // .fold(total_received, |acc, s| { + // received_amounts.get(&s.nf).map_or(acc, |amt| acc - *amt) + // }) + // }) + // }, + // )) + // } + + fn get_memo(&self, id_note: NoteId) -> Result { + self.blocks + .iter() + .find_map(|b| { + b.transactions.iter().find_map(|(txid, tx)| { + if *txid == id_note.0 { + tx.shielded_outputs.iter().find_map(|wso| { + if wso.index == id_note.1 { + wso.memo.clone().and_then(|m| m.to_utf8()) + } else { + None + } + }) + } else { + None + } + }) + }) + .transpose() + .map_err(MemoryWalletError::MemoDecryptionError) + } + + fn get_transaction(&self, _id_tx: TxId) -> Result { + Err(Error::ScanRequired) // wrong error but we'll fix it later. + } + + fn get_sapling_nullifiers( + &self, + _query: NullifierQuery, + ) -> Result, Self::Error> { + Ok(Vec::new()) + } + + #[cfg(feature = "orchard")] + fn get_orchard_nullifiers( + &self, + _query: NullifierQuery, + ) -> Result, Self::Error> { + Ok(Vec::new()) + } + + #[cfg(feature = "transparent-inputs")] + fn get_transparent_receivers( + &self, + _account: AccountId, + ) -> Result>, Self::Error> { + Ok(HashMap::new()) + } + + #[cfg(feature = "transparent-inputs")] + fn get_transparent_balances( + &self, + _account: AccountId, + _max_height: BlockHeight, + ) -> Result, Self::Error> { + Ok(HashMap::new()) + } + + fn get_account_ids(&self) -> Result, Self::Error> { + Ok(Vec::new()) + } +} + +impl WalletWrite for MemoryWalletDb { + type UtxoRef = u32; + + fn create_account( + &mut self, + seed: &SecretVec, + birthday: AccountBirthday, + ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> { + let account_id = self + .accounts + .last_key_value() + .map_or(AccountId::ZERO, |(id, _)| { + AccountId::from(u32::from(id) + 1) + }); + let usk = UnifiedSpendingKey::from_seed(&self.network, seed.expose_secret(), account_id)?; + let ufvk = usk.to_unified_full_viewing_key(); + self.accounts.insert(account_id, MemoryWalletAccount { + account_id, + ufvk, + birthday, + addresses: BTreeMap::new() + }); + + Ok((account_id, usk)) + } + + fn get_next_available_address( + &mut self, + account: AccountId, + request: UnifiedAddressRequest, + ) -> Result, Self::Error> { + self.accounts.get(account).map(|acct| + acct.addresses.last_key_value() + ) + } + + #[allow(clippy::type_complexity)] + fn put_blocks( + &mut self, + _blocks: Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn update_chain_tip(&mut self, _tip_height: BlockHeight) -> Result<(), Self::Error> { + Ok(()) + } + + fn store_decrypted_tx( + &mut self, + _received_tx: DecryptedTransaction, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn store_sent_tx(&mut self, _sent_tx: &SentTransaction) -> Result<(), Self::Error> { + Ok(()) + } + + fn truncate_to_height(&mut self, _block_height: BlockHeight) -> Result<(), Self::Error> { + Ok(()) + } + + /// Adds a transparent UTXO received by the wallet to the data store. + fn put_received_transparent_utxo( + &mut self, + _output: &WalletTransparentOutput, + ) -> Result { + Ok(0) + } +} + +impl WalletCommitmentTrees for MemoryWalletDb { + type Error = Infallible; + type SaplingShardStore<'a> = MemoryShardStore; + + fn with_sapling_tree_mut(&mut self, mut callback: F) -> Result + where + for<'a> F: FnMut( + &'a mut ShardTree< + Self::SaplingShardStore<'a>, + { sapling::NOTE_COMMITMENT_TREE_DEPTH }, + SAPLING_SHARD_HEIGHT, + >, + ) -> Result, + E: From>, + { + callback(&mut self.sapling_tree) + } + + fn put_sapling_subtree_roots( + &mut self, + start_index: u64, + roots: &[CommitmentTreeRoot], + ) -> Result<(), ShardTreeError> { + self.with_sapling_tree_mut(|t| { + for (root, i) in roots.iter().zip(0u64..) { + let root_addr = Address::from_parts(SAPLING_SHARD_HEIGHT.into(), start_index + i); + t.insert(root_addr, *root.root_hash())?; + } + Ok::<_, ShardTreeError>(()) + })?; + + Ok(()) + } + + #[cfg(feature = "orchard")] + type OrchardShardStore<'a> = MemoryShardStore; + + #[cfg(feature = "orchard")] + fn with_orchard_tree_mut(&mut self, mut callback: F) -> Result + where + for<'a> F: FnMut( + &'a mut ShardTree< + Self::OrchardShardStore<'a>, + { ORCHARD_SHARD_HEIGHT * 2 }, + ORCHARD_SHARD_HEIGHT, + >, + ) -> Result, + E: From>, + { + callback(&mut self.orchard_tree) + } + + /// Adds a sequence of note commitment tree subtree roots to the data store. + #[cfg(feature = "orchard")] + fn put_orchard_subtree_roots( + &mut self, + start_index: u64, + roots: &[CommitmentTreeRoot], + ) -> Result<(), ShardTreeError> { + self.with_orchard_tree_mut(|t| { + for (root, i) in roots.iter().zip(0u64..) { + let root_addr = Address::from_parts(ORCHARD_SHARD_HEIGHT.into(), start_index + i); + t.insert(root_addr, *root.root_hash())?; + } + Ok::<_, ShardTreeError>(()) + })?; + + Ok(()) + } +}