Skip to content

Commit

Permalink
fix: close iroh endpoint when I/O is stopped
Browse files Browse the repository at this point in the history
  • Loading branch information
link2xt committed Nov 26, 2024
1 parent 4026c82 commit 6087bb2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
9 changes: 5 additions & 4 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use async_channel::{self as channel, Receiver, Sender};
use pgp::types::PublicKeyTrait;
use pgp::SignedPublicKey;
use ratelimit::Ratelimit;
use tokio::sync::{Mutex, Notify, OnceCell, RwLock};
use tokio::sync::{Mutex, Notify, RwLock};

use crate::aheader::EncryptPreference;
use crate::chat::{get_chat_cnt, ChatId, ProtectionStatus};
Expand Down Expand Up @@ -292,7 +292,7 @@ pub struct InnerContext {
pub(crate) push_subscribed: AtomicBool,

/// Iroh for realtime peer channels.
pub(crate) iroh: OnceCell<Iroh>,
pub(crate) iroh: RwLock<Option<Arc<Iroh>>>,
}

/// The state of ongoing process.
Expand Down Expand Up @@ -450,7 +450,7 @@ impl Context {
debug_logging: std::sync::RwLock::new(None),
push_subscriber,
push_subscribed: AtomicBool::new(false),
iroh: OnceCell::new(),
iroh: RwLock::new(None),
};

let ctx = Context {
Expand Down Expand Up @@ -485,6 +485,7 @@ impl Context {

/// Stops the IO scheduler.
pub async fn stop_io(&self) {
*self.iroh.write().await = None;
self.scheduler.stop(self).await;
}

Expand All @@ -496,7 +497,7 @@ impl Context {

/// Indicate that the network likely has come back.
pub async fn maybe_network(&self) {
if let Some(iroh) = self.iroh.get() {
if let Some(ref iroh) = *self.iroh.read().await {
iroh.network_change().await;
}
self.scheduler.maybe_network().await;
Expand Down
56 changes: 44 additions & 12 deletions src/peer_channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use iroh_net::{NodeAddr, NodeId};
use parking_lot::Mutex;
use std::collections::{BTreeSet, HashMap};
use std::env;
use std::sync::Arc;
use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;
use url::Url;
Expand Down Expand Up @@ -285,15 +286,25 @@ impl Context {
}

/// Get or initialize the iroh peer channel.
pub async fn get_or_try_init_peer_channel(&self) -> Result<&Iroh> {
pub async fn get_or_try_init_peer_channel(&self) -> Result<Arc<Iroh>> {
if !self.get_config_bool(Config::WebxdcRealtimeEnabled).await? {
bail!("Attempt to get Iroh when realtime is disabled");
}

let ctx = self.clone();
self.iroh
.get_or_try_init(|| async { ctx.init_peer_channels().await })
.await
{
let lock = self.iroh.read().await;
if let Some(ref iroh) = *lock {
return Ok(Arc::clone(iroh));
}
}

let mut lock = self.iroh.write().await;
if let Some(ref iroh) = *lock {
return Ok(Arc::clone(iroh));
}
let iroh = Arc::new(self.init_peer_channels().await?);
*lock = Some(iroh.clone());
Ok(iroh)
}
}

Expand Down Expand Up @@ -720,6 +731,12 @@ mod tests {
}
}
}

// Calling stop_io() closes iroh endpoint as well,
// even though I/O was not started in this test.
assert!(alice.iroh.read().await.is_some());
alice.stop_io().await;
assert!(alice.iroh.read().await.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down Expand Up @@ -811,7 +828,9 @@ mod tests {
.unwrap();
let bob_sequence_number = bob
.iroh
.get()
.read()
.await
.clone()
.unwrap()
.sequence_numbers
.lock()
Expand All @@ -820,7 +839,9 @@ mod tests {
leave_webxdc_realtime(bob, bob_webxdc.id).await.unwrap();
let bob_sequence_number_after = bob
.iroh
.get()
.read()
.await
.clone()
.unwrap()
.sequence_numbers
.lock()
Expand Down Expand Up @@ -860,7 +881,16 @@ mod tests {
// bob for example does not change the channels because he never sends an
// advertisement
assert_eq!(
alice.iroh.get().unwrap().iroh_channels.read().await.len(),
alice
.iroh
.read()
.await
.clone()
.unwrap()
.iroh_channels
.read()
.await
.len(),
1
);
leave_webxdc_realtime(alice, alice_webxdc.id).await.unwrap();
Expand All @@ -870,7 +900,9 @@ mod tests {
.unwrap();
assert!(alice
.iroh
.get()
.read()
.await
.as_ref()
.unwrap()
.iroh_channels
.read()
Expand Down Expand Up @@ -963,19 +995,19 @@ mod tests {
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
send_webxdc_realtime_data(alice, MsgId::new(1), vec![])
.await
.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// creates iroh endpoint as side effect
leave_webxdc_realtime(alice, MsgId::new(1)).await.unwrap();

assert!(alice.ctx.iroh.get().is_none());
assert!(alice.ctx.iroh.read().await.is_none());

// This internal function should return error
// if accidentally called with the setting disabled.
Expand Down

0 comments on commit 6087bb2

Please sign in to comment.