diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 40018f7..a399538 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -24,7 +24,7 @@ //! let p = Pollard::new() //! .modify(&hashes, &[]) //! .expect("Simple addition don't fail") -//! .modify(&[], &[0, 3]) +//! .modify(&[], &[hashes[0], hashes[3]]) //! .expect("Nor should simple deletion with known to be in tree elements"); //! // We should get this state after //! assert_eq!(p.get_roots().len(), 1); @@ -70,7 +70,8 @@ //! mutability, there is a simple specialized function to do the job, and the API gives you //! a [Rc] over a node, not the [RefCell], so avoid using the [RefCell] directly. use super::node_hash::NodeHash; -use super::util::{detwin, is_left_niece, is_root_position, tree_rows}; +use super::util::{calc_next_pos, is_left_niece, is_root_position, tree_rows}; +use std::collections::HashMap; use std::{cell::Cell, fmt::Debug}; use std::{cell::RefCell, rc::Rc}; type Node = Rc; @@ -113,7 +114,7 @@ impl Debug for PolNode { let mut next_nodes = vec![]; for node in row_nodes { if let Some(node) = node { - write!(f, "{:?} ", &node.get_data()[0..2])?; + write!(f, "{:02x}{:02x} ", node.get_data()[0], node.get_data()[1])?; let (l, r) = node.get_children(); next_nodes.push(l); @@ -222,10 +223,7 @@ impl PolNode { if let Some(parent) = self.get_parent() { return parent.recompute_parent_hash(); } - // It's impossible to reach this, because if we are a branch node (not a leaf), we - // will get into the first case, otherwise the second if will be taken. - // Leave this unreachable to catch weird bugs - unreachable!(); + return; } /// Returns this node's aunt as [Node] fn get_aunt(&self) -> Option { @@ -275,7 +273,7 @@ impl PolNode { self.r_niece.replace(r_niece); } /// Chops down any subtree this node has - fn chop(&self) { + fn _chop(&self) { self.l_niece.replace(None); self.r_niece.replace(None); } @@ -290,6 +288,9 @@ pub struct Pollard { leaves: u64, /// The actual roots roots: Vec, + /// Holds a map of all nodes in the tree. This is used to lookup specific nodes on deletion + /// and while proving + map: HashMap, } impl Pollard { @@ -335,7 +336,7 @@ impl Pollard { /// .expect("Pollard should not fail"); /// assert_eq!(p.get_roots()[0].get_data().to_string(), String::from("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42")); /// ``` - pub fn modify(self, utxos: &[NodeHash], stxos: &[u64]) -> Result { + pub fn modify(self, utxos: &[NodeHash], stxos: &[NodeHash]) -> Result { let acc = self.delete(stxos)?.add(utxos); Ok(acc) @@ -357,27 +358,49 @@ impl Pollard { // If deleting a root, we just place a default node in it's place if is_root_position(pos, self.leaves, tree_rows(self.leaves)) { + self.map.remove_entry(&self.roots[tree as usize].get_data()); self.roots[tree as usize] = PolNode::default().into_rc(); return Ok(()); } // Grab the node we'll move up // from_node is whomever is moving up, to node is the position it will be // after moved - let (from_node, _, to_node) = self.grab_node(pos)?; + let (from_node, del, to_node) = self.grab_node(pos)?; // If the position I'm moving to has an aunt, I'm not becoming a root. // ancestor is the node right beneath the to_node, this is useful because // either it or it's sibling points to the `to_node`. We need to update this. + // let new_pos = parent(pos, tree_rows(self.leaves)); + self.map.remove_entry(&del.get_data()); + self.update_positions(&from_node, pos)?; if let Some(sibling) = to_node.get_sibling() { // If my new ancestor has a sibling, it means my aunt/parent is not root // and my aunt is pointing to me, *not* my parent. - sibling.chop(); + sibling.set_nieces(del.get_l_niece(), del.get_r_niece()); + if let Some(n) = sibling.get_l_niece() { + n.set_aunt(Some(sibling.clone())); + } + if let Some(n) = sibling.get_r_niece() { + n.set_aunt(Some(sibling.clone())); + } to_node.set_self_hash(from_node.get_data()); - to_node.set_nieces(to_node.get_l_niece(), to_node.get_r_niece()); + if let Some(n) = to_node.get_l_niece() { + n.set_aunt(Some(to_node.clone())); + } + if let Some(n) = to_node.get_r_niece() { + n.set_aunt(Some(to_node.clone())); + } to_node.recompute_parent_hash(); } else { // This means we are a root's sibling. We are becoming a root now to_node.set_self_hash(from_node.get_data()); - to_node.set_nieces(from_node.get_l_niece(), from_node.get_l_niece()); + to_node.set_nieces(del.get_l_niece(), del.get_r_niece()); + if let Some(n) = to_node.get_l_niece() { + n.set_aunt(Some(to_node.clone())); + } + if let Some(n) = to_node.get_r_niece() { + n.set_aunt(Some(to_node.clone())); + } + to_node.recompute_parent_hash(); } Ok(()) } @@ -417,20 +440,56 @@ impl Pollard { fn add(mut self, utxos: &[NodeHash]) -> Self { for utxo in utxos { self.roots = Pollard::add_single(self.roots, *utxo, self.leaves); + self.map.insert(*utxo, self.leaves); self.leaves += 1; } self } /// Deletes nodes from the accumulator - fn delete(mut self, stxos: &[u64]) -> Result { - let stxos = detwin(stxos.to_vec(), tree_rows(self.leaves)); + fn delete(mut self, stxos: &[NodeHash]) -> Result { for stxo in stxos.iter() { - self.delete_single(*stxo)?; + let pos = self + .map + .get(stxo) + .ok_or(format!("UTXO {} is not in the forest", stxo))?; + if let Err(e) = self.delete_single(*pos) { + return Err(e); + } } Ok(self) } + fn update_positions(&mut self, node: &Node, del_pos: u64) -> Result<(), String> { + let mut to_remap = vec![]; + let forest_rows = tree_rows(self.leaves); + to_remap.push(node.to_owned()); + while !to_remap.is_empty() { + let next = to_remap.pop().expect("We checked this is not empty"); + let (l_child, r_child) = next.get_children(); + if let (Some(l_child), Some(r_child)) = (l_child, r_child) { + to_remap.push(l_child); + to_remap.push(r_child); + continue; + } else { + let pos = self.map.get(&next.get_data()).ok_or(format!( + "Node with hash {} is not in the tree", + next.get_data() + ))?; + let new_pos = calc_next_pos(*pos, del_pos, forest_rows)?; + + self.map.entry(next.get_data()).and_modify(|e| *e = new_pos); + } + } + Ok(()) + } + fn lookup_targets(&self, targets: &[NodeHash]) -> Vec { + let mut res = Vec::new(); + for target in targets { + res.push(self.map.get(target)); + } + res.into_iter().flatten().copied().collect::>() + } /// Adds a single node. Addition is a loop over all new nodes, calling add_single for each /// of them fn add_single(mut roots: Vec, node: NodeHash, mut num_leaves: u64) -> Vec { @@ -614,7 +673,7 @@ mod test { let p = Pollard::new() .modify(&hashes, &[]) .expect("Pollard should not fail"); - let p = p.modify(&[], &[0]).expect("msg"); + let p = p.modify(&[], &[hashes[0]]).expect("msg"); let (_, node, _) = p.grab_node(8).unwrap(); assert_eq!( @@ -793,7 +852,7 @@ mod test { let p = Pollard::new() .modify(&hashes, &[]) .expect("Pollard should not fail") - .modify(&[], &[1]) + .modify(&[], &[hashes[1]]) .expect("Still should not fail"); assert_eq!(p.roots.len(), 1); @@ -837,7 +896,10 @@ mod test { let dels = case .target_values .clone() - .expect("Del test must have targets"); + .unwrap() + .iter() + .map(|pos| hashes[*pos as usize].clone()) + .collect::>(); let p = Pollard::new() .modify(&hashes, &[]) .expect("Test pollards are valid") @@ -887,8 +949,7 @@ mod test { // This is a bug found by fuzzing this lib, where calling detect_offset in an non-existing // node would cause a crash. This test makes sure it won't be introduced again let elements = get_hash_vec_of(&[0, 2]); - let targets = [0_u64]; - let pollard = Pollard::new().modify(&elements, &targets); - assert_eq!(pollard.err().unwrap(), "Node not in forest".to_owned()) + let pollard = Pollard::new().modify(&elements, &[elements[0]]); + assert_eq!(pollard.err().unwrap(), "UTXO 6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d is not in the forest".to_owned()) } } diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index a72c299..e474c18 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -568,7 +568,7 @@ impl Proof { Ok((Proof { hashes, targets }, target_hashes)) } - fn calc_next_positions( + pub(crate) fn calc_next_positions( block_targets: &Vec, old_positions: &Vec<(u64, NodeHash)>, num_leaves: u64, diff --git a/src/accumulator/util.rs b/src/accumulator/util.rs index 5d6956f..bed2d2c 100644 --- a/src/accumulator/util.rs +++ b/src/accumulator/util.rs @@ -423,5 +423,8 @@ mod tests { let res = super::calc_next_pos(1, 9, 3); assert_eq!(Ok(9), res); + + let res = super::calc_next_pos(5, 13, tree_rows(8)); + assert_eq!(Ok(9), res); } }