Skip to content

Commit

Permalink
fix: close iroh endpoint when I/O is stopped
Browse files Browse the repository at this point in the history
  • Loading branch information
link2xt committed Nov 27, 2024
1 parent 4026c82 commit f918ecb
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 31 deletions.
14 changes: 10 additions & 4 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use async_channel::{self as channel, Receiver, Sender};
use pgp::types::PublicKeyTrait;
use pgp::SignedPublicKey;
use ratelimit::Ratelimit;
use tokio::sync::{Mutex, Notify, OnceCell, RwLock};
use tokio::sync::{Mutex, Notify, RwLock};

use crate::aheader::EncryptPreference;
use crate::chat::{get_chat_cnt, ChatId, ProtectionStatus};
Expand Down Expand Up @@ -292,7 +292,7 @@ pub struct InnerContext {
pub(crate) push_subscribed: AtomicBool,

/// Iroh for realtime peer channels.
pub(crate) iroh: OnceCell<Iroh>,
pub(crate) iroh: Arc<RwLock<Option<Iroh>>>,
}

/// The state of ongoing process.
Expand Down Expand Up @@ -450,7 +450,7 @@ impl Context {
debug_logging: std::sync::RwLock::new(None),
push_subscriber,
push_subscribed: AtomicBool::new(false),
iroh: OnceCell::new(),
iroh: Arc::new(RwLock::new(None)),
};

let ctx = Context {
Expand Down Expand Up @@ -485,6 +485,12 @@ impl Context {

/// Stops the IO scheduler.
pub async fn stop_io(&self) {
if let Some(iroh) = self.iroh.write().await.take() {
// Close all QUIC connections.
if let Err(err) = iroh.close().await {
warn!(self, "Failed to close Iroh: {err:#}.");
}
}
self.scheduler.stop(self).await;
}

Expand All @@ -496,7 +502,7 @@ impl Context {

/// Indicate that the network likely has come back.
pub async fn maybe_network(&self) {
if let Some(iroh) = self.iroh.get() {
if let Some(ref iroh) = *self.iroh.read().await {
iroh.network_change().await;
}
self.scheduler.maybe_network().await;
Expand Down
146 changes: 119 additions & 27 deletions src/peer_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ impl Iroh {
self.endpoint.network_change().await
}

/// Closes the QUIC endpoint.
pub(crate) async fn close(self) -> Result<()> {
self.endpoint
.close(0u32.into(), b"")
.await
.context("Closing iroh endpoint failed")
}

/// Join a topic and create the subscriber loop for it.
///
/// If there is no gossip, create it.
Expand Down Expand Up @@ -285,15 +293,38 @@ impl Context {
}

/// Get or initialize the iroh peer channel.
pub async fn get_or_try_init_peer_channel(&self) -> Result<&Iroh> {
pub async fn get_or_try_init_peer_channel(
&self,
) -> Result<tokio::sync::RwLockReadGuard<'_, Iroh>> {
if !self.get_config_bool(Config::WebxdcRealtimeEnabled).await? {
bail!("Attempt to get Iroh when realtime is disabled");
}

let ctx = self.clone();
self.iroh
.get_or_try_init(|| async { ctx.init_peer_channels().await })
.await
{
if let Ok(lock) = tokio::sync::RwLockReadGuard::<'_, std::option::Option<Iroh>>::try_map(
self.iroh.read().await,
|opt_iroh| opt_iroh.as_ref(),
) {
return Ok(lock);
}
}

let lock = self.iroh.write().await;
match tokio::sync::RwLockWriteGuard::<'_, std::option::Option<Iroh>>::try_downgrade_map(
lock,
|opt_iroh| opt_iroh.as_ref(),
) {
Ok(lock) => Ok(lock),
Err(mut lock) => {
let iroh = self.init_peer_channels().await?;
*lock = Some(iroh);
tokio::sync::RwLockWriteGuard::<'_, std::option::Option<Iroh>>::try_downgrade_map(
lock,
|opt_iroh| opt_iroh.as_ref(),
)
.map_err(|_| anyhow!("Downgrade should succeed as we just stored `Some` value"))
}
}
}
}

Expand Down Expand Up @@ -626,7 +657,6 @@ mod tests {
break;
}
}
let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap();

// Bob adds alice to gossip peers.
let members = get_iroh_gossip_peers(bob, bob_webxdc.id)
Expand All @@ -636,13 +666,23 @@ mod tests {
.map(|addr| addr.node_id)
.collect::<Vec<_>>();

let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap();
assert_eq!(
members,
vec![alice_iroh.get_node_addr().await.unwrap().node_id]
vec![
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
Expand All @@ -651,7 +691,10 @@ mod tests {
.unwrap();

// Alice sends ephemeral message
alice_iroh
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -670,7 +713,9 @@ mod tests {
}
}
// Bob sends ephemeral message
bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec())
.await
.unwrap();
Expand Down Expand Up @@ -699,10 +744,20 @@ mod tests {

assert_eq!(
members,
vec![bob_iroh.get_node_addr().await.unwrap().node_id]
vec![
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice 2".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -720,6 +775,12 @@ mod tests {
}
}
}

// Calling stop_io() closes iroh endpoint as well,
// even though I/O was not started in this test.
assert!(alice.iroh.read().await.is_some());
alice.stop_io().await;
assert!(alice.iroh.read().await.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -761,7 +822,6 @@ mod tests {
.unwrap();

bob.recv_msg_trash(&alice.pop_sent_msg().await).await;
let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap();

// Bob adds alice to gossip peers.
let members = get_iroh_gossip_peers(bob, bob_webxdc.id)
Expand All @@ -771,13 +831,23 @@ mod tests {
.map(|addr| addr.node_id)
.collect::<Vec<_>>();

let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap();
assert_eq!(
members,
vec![alice_iroh.get_node_addr().await.unwrap().node_id]
vec![
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
Expand All @@ -786,7 +856,10 @@ mod tests {
.unwrap();

// Alice sends ephemeral message
alice_iroh
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -811,7 +884,9 @@ mod tests {
.unwrap();
let bob_sequence_number = bob
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.sequence_numbers
.lock()
Expand All @@ -820,7 +895,9 @@ mod tests {
leave_webxdc_realtime(bob, bob_webxdc.id).await.unwrap();
let bob_sequence_number_after = bob
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.sequence_numbers
.lock()
Expand All @@ -829,15 +906,19 @@ mod tests {
// Check that sequence number is persisted when leaving the channel.
assert_eq!(bob_sequence_number, bob_sequence_number_after);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
.unwrap()
.await
.unwrap();

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -860,7 +941,16 @@ mod tests {
// bob for example does not change the channels because he never sends an
// advertisement
assert_eq!(
alice.iroh.get().unwrap().iroh_channels.read().await.len(),
alice
.iroh
.read()
.await
.as_ref()
.unwrap()
.iroh_channels
.read()
.await
.len(),
1
);
leave_webxdc_realtime(alice, alice_webxdc.id).await.unwrap();
Expand All @@ -870,7 +960,9 @@ mod tests {
.unwrap();
assert!(alice
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.iroh_channels
.read()
Expand Down Expand Up @@ -963,19 +1055,19 @@ mod tests {
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
send_webxdc_realtime_data(alice, MsgId::new(1), vec![])
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
leave_webxdc_realtime(alice, MsgId::new(1)).await.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// This internal function should return error
// if accidentally called with the setting disabled.
Expand Down

0 comments on commit f918ecb

Please sign in to comment.