diff --git a/src/context.rs b/src/context.rs index 93f43c21a4..c838a82e6e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -13,7 +13,7 @@ use async_channel::{self as channel, Receiver, Sender}; use pgp::types::PublicKeyTrait; use pgp::SignedPublicKey; use ratelimit::Ratelimit; -use tokio::sync::{Mutex, Notify, OnceCell, RwLock}; +use tokio::sync::{Mutex, Notify, RwLock}; use crate::aheader::EncryptPreference; use crate::chat::{get_chat_cnt, ChatId, ProtectionStatus}; @@ -292,7 +292,7 @@ pub struct InnerContext { pub(crate) push_subscribed: AtomicBool, /// Iroh for realtime peer channels. - pub(crate) iroh: OnceCell, + pub(crate) iroh: RwLock>>, } /// The state of ongoing process. @@ -450,7 +450,7 @@ impl Context { debug_logging: std::sync::RwLock::new(None), push_subscriber, push_subscribed: AtomicBool::new(false), - iroh: OnceCell::new(), + iroh: RwLock::new(None), }; let ctx = Context { @@ -485,6 +485,7 @@ impl Context { /// Stops the IO scheduler. pub async fn stop_io(&self) { + *self.iroh.write().await = None; self.scheduler.stop(self).await; } @@ -496,7 +497,7 @@ impl Context { /// Indicate that the network likely has come back. pub async fn maybe_network(&self) { - if let Some(iroh) = self.iroh.get() { + if let Some(ref iroh) = *self.iroh.read().await { iroh.network_change().await; } self.scheduler.maybe_network().await; diff --git a/src/peer_channels.rs b/src/peer_channels.rs index 254c7dd3b6..556634f9e1 100644 --- a/src/peer_channels.rs +++ b/src/peer_channels.rs @@ -35,6 +35,7 @@ use iroh_net::{NodeAddr, NodeId}; use parking_lot::Mutex; use std::collections::{BTreeSet, HashMap}; use std::env; +use std::sync::Arc; use tokio::sync::{oneshot, RwLock}; use tokio::task::JoinHandle; use url::Url; @@ -285,15 +286,22 @@ impl Context { } /// Get or initialize the iroh peer channel. - pub async fn get_or_try_init_peer_channel(&self) -> Result<&Iroh> { + pub async fn get_or_try_init_peer_channel(&self) -> Result> { if !self.get_config_bool(Config::WebxdcRealtimeEnabled).await? { bail!("Attempt to get Iroh when realtime is disabled"); } - let ctx = self.clone(); - self.iroh - .get_or_try_init(|| async { ctx.init_peer_channels().await }) - .await + { + let lock = self.iroh.read().await; + if let Some(ref iroh) = *lock { + return Ok(Arc::clone(iroh)); + } + } + + let mut lock = self.iroh.write().await; + let iroh = Arc::new(self.init_peer_channels().await?); + *lock = Some(iroh.clone()); + Ok(iroh) } } @@ -720,6 +728,12 @@ mod tests { } } } + + // Calling stop_io() closes iroh endpoint as well, + // even though I/O was not started in this test. + assert!(alice.iroh.read().await.is_some()); + alice.stop_io().await; + assert!(alice.iroh.read().await.is_none()); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -811,7 +825,9 @@ mod tests { .unwrap(); let bob_sequence_number = bob .iroh - .get() + .read() + .await + .clone() .unwrap() .sequence_numbers .lock() @@ -820,7 +836,9 @@ mod tests { leave_webxdc_realtime(bob, bob_webxdc.id).await.unwrap(); let bob_sequence_number_after = bob .iroh - .get() + .read() + .await + .clone() .unwrap() .sequence_numbers .lock() @@ -860,7 +878,16 @@ mod tests { // bob for example does not change the channels because he never sends an // advertisement assert_eq!( - alice.iroh.get().unwrap().iroh_channels.read().await.len(), + alice + .iroh + .read() + .await + .clone() + .unwrap() + .iroh_channels + .read() + .await + .len(), 1 ); leave_webxdc_realtime(alice, alice_webxdc.id).await.unwrap(); @@ -870,7 +897,9 @@ mod tests { .unwrap(); assert!(alice .iroh - .get() + .read() + .await + .as_ref() .unwrap() .iroh_channels .read() @@ -963,19 +992,19 @@ mod tests { .await .unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // creates iroh endpoint as side effect send_webxdc_realtime_data(alice, MsgId::new(1), vec![]) .await .unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // creates iroh endpoint as side effect leave_webxdc_realtime(alice, MsgId::new(1)).await.unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // This internal function should return error // if accidentally called with the setting disabled.