Skip to content

Commit

Permalink
Make storage manager create and manage the Arc
Browse files Browse the repository at this point in the history
  • Loading branch information
slawlor committed Jan 4, 2023
1 parent c0b4db2 commit 0084ec8
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 185 deletions.
43 changes: 21 additions & 22 deletions akd/src/append_only_zks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,14 +940,13 @@ mod tests {
EMPTY_VALUE,
};
use rand::{rngs::OsRng, seq::SliceRandom, RngCore};
use std::sync::Arc;
use std::time::Duration;

#[tokio::test]
async fn test_batch_insert_basic() -> Result<(), AkdError> {
let mut rng = OsRng;
let num_nodes = 10;
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks1 = Azks::new::<_>(&db).await?;
azks1.increment_epoch();
Expand All @@ -972,7 +971,7 @@ mod tests {
root_node.write_to_storage(&db).await?;
}

let database2 = Arc::new(AsyncInMemoryDatabase::new());
let database2 = AsyncInMemoryDatabase::new();
let db2 = StorageManager::new_no_cache(database2);
let mut azks2 = Azks::new::<_>(&db2).await?;

Expand All @@ -991,7 +990,7 @@ mod tests {

#[tokio::test]
async fn test_batch_insert_root_hash() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);

// manually construct a 3-layer tree and compute the root hash
Expand Down Expand Up @@ -1056,7 +1055,7 @@ mod tests {
async fn test_insert_permuted() -> Result<(), AkdError> {
let num_nodes = 10;
let mut rng = OsRng;
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks1 = Azks::new::<_>(&db).await?;
azks1.increment_epoch();
Expand All @@ -1083,7 +1082,7 @@ mod tests {
// Try randomly permuting
node_set.shuffle(&mut rng);

let database2 = Arc::new(AsyncInMemoryDatabase::new());
let database2 = AsyncInMemoryDatabase::new();
let db2 = StorageManager::new_no_cache(database2);
let mut azks2 = Azks::new(&db2).await?;

Expand All @@ -1102,7 +1101,7 @@ mod tests {

#[tokio::test]
async fn test_insert_num_nodes() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database.clone());
let mut azks = Azks::new::<_>(&db).await?;

Expand Down Expand Up @@ -1184,7 +1183,7 @@ mod tests {

#[tokio::test]
async fn test_preload_nodes_accuracy() {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let storage_manager =
StorageManager::new(database, Some(Duration::from_secs(180u64)), None, None);
let mut azks = Azks::new::<_>(&storage_manager)
Expand Down Expand Up @@ -1264,7 +1263,7 @@ mod tests {
#[tokio::test]
async fn test_node_set_partition() -> Result<(), AkdError> {
let num_nodes = 5;
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks1 = Azks::new::<_>(&db).await?;
azks1.increment_epoch();
Expand Down Expand Up @@ -1311,7 +1310,7 @@ mod tests {
#[tokio::test]
async fn test_node_set_get_longest_common_prefix() -> Result<(), AkdError> {
let num_nodes = 10;
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks1 = Azks::new::<_>(&db).await?;
azks1.increment_epoch();
Expand Down Expand Up @@ -1351,7 +1350,7 @@ mod tests {

// Try randomly permuting
node_set.shuffle(&mut rng);
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
azks.batch_insert_nodes::<_>(&db, node_set.clone(), InsertMode::Directory)
Expand Down Expand Up @@ -1401,7 +1400,7 @@ mod tests {

// Try randomly permuting
node_set.shuffle(&mut rng);
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
azks.batch_insert_nodes::<_>(&db, node_set.clone(), InsertMode::Directory)
Expand Down Expand Up @@ -1430,7 +1429,7 @@ mod tests {
node_set.push(node);
}

let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
azks.batch_insert_nodes::<_>(&db, node_set.clone(), InsertMode::Directory)
Expand All @@ -1452,7 +1451,7 @@ mod tests {

// Try randomly permuting
node_set.shuffle(&mut rng);
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
azks.batch_insert_nodes::<_>(&db, node_set.clone(), InsertMode::Directory)
Expand All @@ -1475,7 +1474,7 @@ mod tests {

#[tokio::test]
async fn test_membership_proof_intermediate() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);

let node_set: Vec<Node> = vec![
Expand Down Expand Up @@ -1529,7 +1528,7 @@ mod tests {
let node = Node { label, hash };
node_set.push(node);
}
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
let search_label = node_set[0].label;
Expand All @@ -1549,7 +1548,7 @@ mod tests {
let num_nodes = 3;

let node_set = gen_nodes(num_nodes);
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
let search_label = node_set[num_nodes - 1].label;
Expand All @@ -1571,7 +1570,7 @@ mod tests {
let num_nodes = 10;

let node_set = gen_nodes(num_nodes);
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
let search_label = node_set[num_nodes - 1].label;
Expand All @@ -1590,7 +1589,7 @@ mod tests {

#[tokio::test]
async fn test_append_only_proof_very_tiny() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;

Expand Down Expand Up @@ -1619,7 +1618,7 @@ mod tests {

#[tokio::test]
async fn test_append_only_proof_tiny() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;

Expand Down Expand Up @@ -1664,7 +1663,7 @@ mod tests {

let node_set_1 = gen_nodes(num_nodes);

let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();
let db = StorageManager::new_no_cache(database);
let mut azks = Azks::new::<_>(&db).await?;
azks.batch_insert_nodes::<_>(&db, node_set_1.clone(), InsertMode::Directory)
Expand Down Expand Up @@ -1693,7 +1692,7 @@ mod tests {

#[tokio::test]
async fn future_epoch_throws_error() -> Result<(), AkdError> {
let database = Arc::new(AsyncInMemoryDatabase::new());
let database = AsyncInMemoryDatabase::new();

let db = StorageManager::new_no_cache(database);
let azks = Azks::new::<_>(&db).await?;
Expand Down
2 changes: 1 addition & 1 deletion akd/src/auditor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub async fn verify_consecutive_append_only(
let inserted = proof.inserted.clone();

let db = AsyncInMemoryDatabase::new();
let manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
let manager = StorageManager::new_no_cache(db);

let mut azks = Azks::new::<_>(&manager).await?;
azks.batch_insert_nodes::<_>(&manager, unchanged_nodes, InsertMode::Auditor)
Expand Down
30 changes: 15 additions & 15 deletions akd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
//! use akd::directory::Directory;
//!
//! let db = AsyncInMemoryDatabase::new();
//! let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! let storage_manager = StorageManager::new_no_cache(db);
//! let vrf = HardCodedAkdVRF{};
//!
//! # tokio_test::block_on(async {
Expand All @@ -75,7 +75,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! use akd::EpochHash;
//! use akd::{AkdLabel, AkdValue};
Expand All @@ -86,7 +86,7 @@
//! (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//!
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand All @@ -109,7 +109,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::{AkdLabel, AkdValue};
Expand All @@ -120,7 +120,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand All @@ -142,7 +142,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::{AkdLabel, AkdValue};
Expand All @@ -153,7 +153,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand Down Expand Up @@ -196,7 +196,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::{AkdLabel, AkdValue};
Expand All @@ -207,7 +207,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand Down Expand Up @@ -236,7 +236,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::HistoryParams;
Expand All @@ -248,7 +248,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand Down Expand Up @@ -300,7 +300,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::{AkdLabel, AkdValue};
Expand All @@ -311,7 +311,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand Down Expand Up @@ -339,7 +339,7 @@
//! # use akd::directory::Directory;
//! #
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! # let vrf = HardCodedAkdVRF{};
//! # use akd::EpochHash;
//! # use akd::{AkdLabel, AkdValue};
Expand All @@ -350,7 +350,7 @@
//! # (AkdLabel::from_utf8_str("second entry"), AkdValue::from_utf8_str("second value")),
//! # ];
//! # let db = AsyncInMemoryDatabase::new();
//! # let storage_manager = StorageManager::new_no_cache(std::sync::Arc::new(db));
//! # let storage_manager = StorageManager::new_no_cache(db);
//! #
//! # tokio_test::block_on(async {
//! # let vrf = HardCodedAkdVRF{};
Expand Down
16 changes: 11 additions & 5 deletions akd/src/storage/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub struct StorageManager<Db: Database> {
cache: Option<TimedCache>,
transaction: Transaction,
/// The underlying database managed by this storage manager
pub db: Arc<Db>,
db: Arc<Db>,

metrics: [Arc<AtomicU64>; NUM_METRICS],
}
Expand All @@ -76,18 +76,18 @@ unsafe impl<Db: Database> Send for StorageManager<Db> {}

impl<Db: Database> StorageManager<Db> {
/// Create a new storage manager with NO CACHE
pub fn new_no_cache(db: Arc<Db>) -> Self {
pub fn new_no_cache(db: Db) -> Self {
Self {
cache: None,
transaction: Transaction::new(),
db,
db: Arc::new(db),
metrics: [0; NUM_METRICS].map(|_| Arc::new(AtomicU64::new(0))),
}
}

/// Create a new storage manager with a cache utilizing the options provided (or defaults)
pub fn new(
db: Arc<Db>,
db: Db,
cache_item_lifetime: Option<Duration>,
cache_limit_bytes: Option<usize>,
cache_clean_frequency: Option<Duration>,
Expand All @@ -99,11 +99,17 @@ impl<Db: Database> StorageManager<Db> {
cache_clean_frequency,
)),
transaction: Transaction::new(),
db,
db: Arc::new(db),
metrics: [0; NUM_METRICS].map(|_| Arc::new(AtomicU64::new(0))),
}
}

/// Retrieve a reference to the database implementation
#[cfg(any(test, feature = "public-tests"))]
pub fn get_db(&self) -> Arc<Db> {
self.db.clone()
}

/// Returns whether the storage manager has a cache
pub fn has_cache(&self) -> bool {
self.cache.is_some()
Expand Down
Loading

0 comments on commit 0084ec8

Please sign in to comment.