diff --git a/src/context.rs b/src/context.rs index 7cc1e929ea..7d1c8c774e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -13,7 +13,7 @@ use async_channel::{self as channel, Receiver, Sender}; use pgp::types::PublicKeyTrait; use pgp::SignedPublicKey; use ratelimit::Ratelimit; -use tokio::sync::{Mutex, Notify, OnceCell, RwLock}; +use tokio::sync::{Mutex, Notify, RwLock}; use crate::aheader::EncryptPreference; use crate::chat::{get_chat_cnt, ChatId, ProtectionStatus}; @@ -292,7 +292,7 @@ pub struct InnerContext { pub(crate) push_subscribed: AtomicBool, /// Iroh for realtime peer channels. - pub(crate) iroh: OnceCell, + pub(crate) iroh: Arc>>, } /// The state of ongoing process. @@ -450,7 +450,7 @@ impl Context { debug_logging: std::sync::RwLock::new(None), push_subscriber, push_subscribed: AtomicBool::new(false), - iroh: OnceCell::new(), + iroh: Arc::new(RwLock::new(None)), }; let ctx = Context { @@ -486,6 +486,19 @@ impl Context { /// Stops the IO scheduler. pub async fn stop_io(&self) { self.scheduler.stop(self).await; + if let Some(iroh) = self.iroh.write().await.take() { + // Close all QUIC connections. + + // Spawn into a separate task, + // because Iroh calls `wait_idle()` internally + // and it may take time, especially if the network + // has become unavailable. + tokio::spawn(async move { + // We do not log the error because we do not want the task + // to hold the reference to Context. + let _ = tokio::time::timeout(Duration::from_secs(60), iroh.close()).await; + }); + } } /// Restarts the IO scheduler if it was running before @@ -496,7 +509,7 @@ impl Context { /// Indicate that the network likely has come back. pub async fn maybe_network(&self) { - if let Some(iroh) = self.iroh.get() { + if let Some(ref iroh) = *self.iroh.read().await { iroh.network_change().await; } self.scheduler.maybe_network().await; diff --git a/src/peer_channels.rs b/src/peer_channels.rs index 254c7dd3b6..73e0c68446 100644 --- a/src/peer_channels.rs +++ b/src/peer_channels.rs @@ -78,6 +78,14 @@ impl Iroh { self.endpoint.network_change().await } + /// Closes the QUIC endpoint. + pub(crate) async fn close(self) -> Result<()> { + self.endpoint + .close(0u32.into(), b"") + .await + .context("Closing iroh endpoint failed") + } + /// Join a topic and create the subscriber loop for it. /// /// If there is no gossip, create it. @@ -285,15 +293,36 @@ impl Context { } /// Get or initialize the iroh peer channel. - pub async fn get_or_try_init_peer_channel(&self) -> Result<&Iroh> { + pub async fn get_or_try_init_peer_channel( + &self, + ) -> Result> { if !self.get_config_bool(Config::WebxdcRealtimeEnabled).await? { bail!("Attempt to get Iroh when realtime is disabled"); } - let ctx = self.clone(); - self.iroh - .get_or_try_init(|| async { ctx.init_peer_channels().await }) - .await + if let Ok(lock) = tokio::sync::RwLockReadGuard::<'_, std::option::Option>::try_map( + self.iroh.read().await, + |opt_iroh| opt_iroh.as_ref(), + ) { + return Ok(lock); + } + + let lock = self.iroh.write().await; + match tokio::sync::RwLockWriteGuard::<'_, std::option::Option>::try_downgrade_map( + lock, + |opt_iroh| opt_iroh.as_ref(), + ) { + Ok(lock) => Ok(lock), + Err(mut lock) => { + let iroh = self.init_peer_channels().await?; + *lock = Some(iroh); + tokio::sync::RwLockWriteGuard::<'_, std::option::Option>::try_downgrade_map( + lock, + |opt_iroh| opt_iroh.as_ref(), + ) + .map_err(|_| anyhow!("Downgrade should succeed as we just stored `Some` value")) + } + } } } @@ -626,7 +655,6 @@ mod tests { break; } } - let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap(); // Bob adds alice to gossip peers. let members = get_iroh_gossip_peers(bob, bob_webxdc.id) @@ -636,13 +664,23 @@ mod tests { .map(|addr| addr.node_id) .collect::>(); - let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap(); assert_eq!( members, - vec![alice_iroh.get_node_addr().await.unwrap().node_id] + vec![ + alice + .get_or_try_init_peer_channel() + .await + .unwrap() + .get_node_addr() + .await + .unwrap() + .node_id + ] ); - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .join_and_subscribe_gossip(bob, bob_webxdc.id) .await .unwrap() @@ -651,7 +689,10 @@ mod tests { .unwrap(); // Alice sends ephemeral message - alice_iroh + alice + .get_or_try_init_peer_channel() + .await + .unwrap() .send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec()) .await .unwrap(); @@ -670,7 +711,9 @@ mod tests { } } // Bob sends ephemeral message - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec()) .await .unwrap(); @@ -699,10 +742,20 @@ mod tests { assert_eq!( members, - vec![bob_iroh.get_node_addr().await.unwrap().node_id] + vec![ + bob.get_or_try_init_peer_channel() + .await + .unwrap() + .get_node_addr() + .await + .unwrap() + .node_id + ] ); - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice 2".as_bytes().to_vec()) .await .unwrap(); @@ -720,6 +773,12 @@ mod tests { } } } + + // Calling stop_io() closes iroh endpoint as well, + // even though I/O was not started in this test. + assert!(alice.iroh.read().await.is_some()); + alice.stop_io().await; + assert!(alice.iroh.read().await.is_none()); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -761,7 +820,6 @@ mod tests { .unwrap(); bob.recv_msg_trash(&alice.pop_sent_msg().await).await; - let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap(); // Bob adds alice to gossip peers. let members = get_iroh_gossip_peers(bob, bob_webxdc.id) @@ -771,13 +829,23 @@ mod tests { .map(|addr| addr.node_id) .collect::>(); - let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap(); assert_eq!( members, - vec![alice_iroh.get_node_addr().await.unwrap().node_id] + vec![ + alice + .get_or_try_init_peer_channel() + .await + .unwrap() + .get_node_addr() + .await + .unwrap() + .node_id + ] ); - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .join_and_subscribe_gossip(bob, bob_webxdc.id) .await .unwrap() @@ -786,7 +854,10 @@ mod tests { .unwrap(); // Alice sends ephemeral message - alice_iroh + alice + .get_or_try_init_peer_channel() + .await + .unwrap() .send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec()) .await .unwrap(); @@ -811,7 +882,9 @@ mod tests { .unwrap(); let bob_sequence_number = bob .iroh - .get() + .read() + .await + .as_ref() .unwrap() .sequence_numbers .lock() @@ -820,7 +893,9 @@ mod tests { leave_webxdc_realtime(bob, bob_webxdc.id).await.unwrap(); let bob_sequence_number_after = bob .iroh - .get() + .read() + .await + .as_ref() .unwrap() .sequence_numbers .lock() @@ -829,7 +904,9 @@ mod tests { // Check that sequence number is persisted when leaving the channel. assert_eq!(bob_sequence_number, bob_sequence_number_after); - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .join_and_subscribe_gossip(bob, bob_webxdc.id) .await .unwrap() @@ -837,7 +914,9 @@ mod tests { .await .unwrap(); - bob_iroh + bob.get_or_try_init_peer_channel() + .await + .unwrap() .send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec()) .await .unwrap(); @@ -860,7 +939,16 @@ mod tests { // bob for example does not change the channels because he never sends an // advertisement assert_eq!( - alice.iroh.get().unwrap().iroh_channels.read().await.len(), + alice + .iroh + .read() + .await + .as_ref() + .unwrap() + .iroh_channels + .read() + .await + .len(), 1 ); leave_webxdc_realtime(alice, alice_webxdc.id).await.unwrap(); @@ -870,7 +958,9 @@ mod tests { .unwrap(); assert!(alice .iroh - .get() + .read() + .await + .as_ref() .unwrap() .iroh_channels .read() @@ -963,19 +1053,19 @@ mod tests { .await .unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // creates iroh endpoint as side effect send_webxdc_realtime_data(alice, MsgId::new(1), vec![]) .await .unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // creates iroh endpoint as side effect leave_webxdc_realtime(alice, MsgId::new(1)).await.unwrap(); - assert!(alice.ctx.iroh.get().is_none()); + assert!(alice.ctx.iroh.read().await.is_none()); // This internal function should return error // if accidentally called with the setting disabled.