Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: close iroh endpoint when I/O is stopped #6264

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use async_channel::{self as channel, Receiver, Sender};
use pgp::types::PublicKeyTrait;
use pgp::SignedPublicKey;
use ratelimit::Ratelimit;
use tokio::sync::{Mutex, Notify, OnceCell, RwLock};
use tokio::sync::{Mutex, Notify, RwLock};

use crate::aheader::EncryptPreference;
use crate::chat::{get_chat_cnt, ChatId, ProtectionStatus};
Expand Down Expand Up @@ -292,7 +292,7 @@ pub struct InnerContext {
pub(crate) push_subscribed: AtomicBool,

/// Iroh for realtime peer channels.
pub(crate) iroh: OnceCell<Iroh>,
pub(crate) iroh: Arc<RwLock<Option<Iroh>>>,
}

/// The state of ongoing process.
Expand Down Expand Up @@ -450,7 +450,7 @@ impl Context {
debug_logging: std::sync::RwLock::new(None),
push_subscriber,
push_subscribed: AtomicBool::new(false),
iroh: OnceCell::new(),
iroh: Arc::new(RwLock::new(None)),
};

let ctx = Context {
Expand Down Expand Up @@ -486,6 +486,19 @@ impl Context {
/// Stops the IO scheduler.
pub async fn stop_io(&self) {
self.scheduler.stop(self).await;
if let Some(iroh) = self.iroh.write().await.take() {
// Close all QUIC connections.

// Spawn into a separate task,
// because Iroh calls `wait_idle()` internally
// and it may take time, especially if the network
// has become unavailable.
tokio::spawn(async move {
// We do not log the error because we do not want the task
// to hold the reference to Context.
let _ = tokio::time::timeout(Duration::from_secs(60), iroh.close()).await;
});
}
}

/// Restarts the IO scheduler if it was running before
Expand All @@ -496,7 +509,7 @@ impl Context {

/// Indicate that the network likely has come back.
pub async fn maybe_network(&self) {
if let Some(iroh) = self.iroh.get() {
if let Some(ref iroh) = *self.iroh.read().await {
iroh.network_change().await;
}
self.scheduler.maybe_network().await;
Expand Down
144 changes: 117 additions & 27 deletions src/peer_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ impl Iroh {
self.endpoint.network_change().await
}

/// Closes the QUIC endpoint.
pub(crate) async fn close(self) -> Result<()> {
self.endpoint
.close(0u32.into(), b"")
.await
.context("Closing iroh endpoint failed")
}

/// Join a topic and create the subscriber loop for it.
///
/// If there is no gossip, create it.
Expand Down Expand Up @@ -285,15 +293,36 @@ impl Context {
}

/// Get or initialize the iroh peer channel.
pub async fn get_or_try_init_peer_channel(&self) -> Result<&Iroh> {
pub async fn get_or_try_init_peer_channel(
&self,
) -> Result<tokio::sync::RwLockReadGuard<'_, Iroh>> {
if !self.get_config_bool(Config::WebxdcRealtimeEnabled).await? {
bail!("Attempt to get Iroh when realtime is disabled");
}

let ctx = self.clone();
self.iroh
.get_or_try_init(|| async { ctx.init_peer_channels().await })
.await
if let Ok(lock) = tokio::sync::RwLockReadGuard::<'_, std::option::Option<Iroh>>::try_map(
self.iroh.read().await,
|opt_iroh| opt_iroh.as_ref(),
) {
return Ok(lock);
}

let lock = self.iroh.write().await;
match tokio::sync::RwLockWriteGuard::<'_, std::option::Option<Iroh>>::try_downgrade_map(
lock,
|opt_iroh| opt_iroh.as_ref(),
) {
Ok(lock) => Ok(lock),
Err(mut lock) => {
let iroh = self.init_peer_channels().await?;
*lock = Some(iroh);
tokio::sync::RwLockWriteGuard::<'_, std::option::Option<Iroh>>::try_downgrade_map(
lock,
|opt_iroh| opt_iroh.as_ref(),
)
.map_err(|_| anyhow!("Downgrade should succeed as we just stored `Some` value"))
}
}
}
}

Expand Down Expand Up @@ -626,7 +655,6 @@ mod tests {
break;
}
}
let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap();

// Bob adds alice to gossip peers.
let members = get_iroh_gossip_peers(bob, bob_webxdc.id)
Expand All @@ -636,13 +664,23 @@ mod tests {
.map(|addr| addr.node_id)
.collect::<Vec<_>>();

let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap();
assert_eq!(
members,
vec![alice_iroh.get_node_addr().await.unwrap().node_id]
vec![
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor, but i'd return Result from tests to get rid of unwrap() everywhere, it really eats much screen space :)

.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
Expand All @@ -651,7 +689,10 @@ mod tests {
.unwrap();

// Alice sends ephemeral message
alice_iroh
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -670,7 +711,9 @@ mod tests {
}
}
// Bob sends ephemeral message
bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec())
.await
.unwrap();
Expand Down Expand Up @@ -699,10 +742,20 @@ mod tests {

assert_eq!(
members,
vec![bob_iroh.get_node_addr().await.unwrap().node_id]
vec![
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice 2".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -720,6 +773,12 @@ mod tests {
}
}
}

// Calling stop_io() closes iroh endpoint as well,
// even though I/O was not started in this test.
assert!(alice.iroh.read().await.is_some());
alice.stop_io().await;
assert!(alice.iroh.read().await.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -761,7 +820,6 @@ mod tests {
.unwrap();

bob.recv_msg_trash(&alice.pop_sent_msg().await).await;
let bob_iroh = bob.get_or_try_init_peer_channel().await.unwrap();

// Bob adds alice to gossip peers.
let members = get_iroh_gossip_peers(bob, bob_webxdc.id)
Expand All @@ -771,13 +829,23 @@ mod tests {
.map(|addr| addr.node_id)
.collect::<Vec<_>>();

let alice_iroh = alice.get_or_try_init_peer_channel().await.unwrap();
assert_eq!(
members,
vec![alice_iroh.get_node_addr().await.unwrap().node_id]
vec![
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.get_node_addr()
.await
.unwrap()
.node_id
]
);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
Expand All @@ -786,7 +854,10 @@ mod tests {
.unwrap();

// Alice sends ephemeral message
alice_iroh
alice
.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(alice, alice_webxdc.id, "alice -> bob".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -811,7 +882,9 @@ mod tests {
.unwrap();
let bob_sequence_number = bob
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.sequence_numbers
.lock()
Expand All @@ -820,7 +893,9 @@ mod tests {
leave_webxdc_realtime(bob, bob_webxdc.id).await.unwrap();
let bob_sequence_number_after = bob
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.sequence_numbers
.lock()
Expand All @@ -829,15 +904,19 @@ mod tests {
// Check that sequence number is persisted when leaving the channel.
assert_eq!(bob_sequence_number, bob_sequence_number_after);

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.join_and_subscribe_gossip(bob, bob_webxdc.id)
.await
.unwrap()
.unwrap()
.await
.unwrap();

bob_iroh
bob.get_or_try_init_peer_channel()
.await
.unwrap()
.send_webxdc_realtime_data(bob, bob_webxdc.id, "bob -> alice".as_bytes().to_vec())
.await
.unwrap();
Expand All @@ -860,7 +939,16 @@ mod tests {
// bob for example does not change the channels because he never sends an
// advertisement
assert_eq!(
alice.iroh.get().unwrap().iroh_channels.read().await.len(),
alice
.iroh
.read()
.await
.as_ref()
.unwrap()
.iroh_channels
.read()
.await
.len(),
1
);
leave_webxdc_realtime(alice, alice_webxdc.id).await.unwrap();
Expand All @@ -870,7 +958,9 @@ mod tests {
.unwrap();
assert!(alice
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.iroh_channels
.read()
Expand Down Expand Up @@ -963,19 +1053,19 @@ mod tests {
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
send_webxdc_realtime_data(alice, MsgId::new(1), vec![])
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
leave_webxdc_realtime(alice, MsgId::new(1)).await.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// This internal function should return error
// if accidentally called with the setting disabled.
Expand Down
Loading