Skip to content

Commit 0eb6bfd

Browse files
bennetboThorstenas-cii
authored
ssh remoting: Treat other message as heartbeat (zed-industries#19219)
This improves the heartbeat detection logic. We now treat any other incoming message from the ssh remote server as a heartbeat message, meaning that we can detect re-connects earlier. It also changes the connection handling to await futures detached. Co-Authored-by: Thorsten <[email protected]> Release Notes: - N/A --------- Co-authored-by: Thorsten <[email protected]> Co-authored-by: Antonio <[email protected]>
1 parent 4fa75a7 commit 0eb6bfd

File tree

1 file changed

+68
-39
lines changed

1 file changed

+68
-39
lines changed

crates/remote/src/ssh_session.rs

+68-39
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result};
99
use collections::HashMap;
1010
use futures::{
1111
channel::{
12-
mpsc::{self, UnboundedReceiver, UnboundedSender},
12+
mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
1313
oneshot,
1414
},
1515
future::BoxFuture,
@@ -28,7 +28,6 @@ use rpc::{
2828
use smol::{
2929
fs,
3030
process::{self, Child, Stdio},
31-
Timer,
3231
};
3332
use std::{
3433
any::TypeId,
@@ -441,6 +440,7 @@ impl SshRemoteClient {
441440
cx.spawn(|mut cx| async move {
442441
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
443442
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
443+
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
444444

445445
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
446446
let this = cx.new_model(|_| Self {
@@ -467,6 +467,7 @@ impl SshRemoteClient {
467467
ssh_proxy_process,
468468
proxy_incoming_tx,
469469
proxy_outgoing_rx,
470+
connection_activity_tx,
470471
&mut cx,
471472
);
472473

@@ -476,7 +477,7 @@ impl SshRemoteClient {
476477
return Err(error);
477478
}
478479

479-
let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
480+
let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
480481

481482
this.update(&mut cx, |this, _| {
482483
*this.state.lock() = Some(State::Connected {
@@ -518,7 +519,7 @@ impl SshRemoteClient {
518519
// We wait 50ms instead of waiting for a response, because
519520
// waiting for a response would require us to wait on the main thread
520521
// which we want to avoid in an `on_app_quit` callback.
521-
Timer::after(Duration::from_millis(50)).await;
522+
smol::Timer::after(Duration::from_millis(50)).await;
522523
}
523524

524525
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
@@ -632,6 +633,7 @@ impl SshRemoteClient {
632633
let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
633634
let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
634635
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
636+
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
635637

636638
let (ssh_connection, ssh_process) = match Self::establish_connection(
637639
identifier,
@@ -653,6 +655,7 @@ impl SshRemoteClient {
653655
ssh_process,
654656
proxy_incoming_tx,
655657
proxy_outgoing_rx,
658+
connection_activity_tx,
656659
&mut cx,
657660
);
658661

@@ -665,7 +668,7 @@ impl SshRemoteClient {
665668
delegate,
666669
forwarder,
667670
multiplex_task,
668-
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
671+
heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
669672
}
670673
});
671674

@@ -717,41 +720,60 @@ impl SshRemoteClient {
717720
Ok(())
718721
}
719722

720-
fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
723+
fn heartbeat(
724+
this: WeakModel<Self>,
725+
mut connection_activity_rx: mpsc::Receiver<()>,
726+
cx: &mut AsyncAppContext,
727+
) -> Task<Result<()>> {
721728
let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
722729
return Task::ready(Err(anyhow!("SshRemoteClient lost")));
723730
};
731+
724732
cx.spawn(|mut cx| {
725733
let this = this.clone();
726734
async move {
727735
let mut missed_heartbeats = 0;
728736

729-
let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
737+
let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
738+
futures::pin_mut!(keepalive_timer);
739+
730740
loop {
731-
timer.next().await;
732-
733-
log::debug!("Sending heartbeat to server...");
734-
735-
let result = client.ping(HEARTBEAT_TIMEOUT).await;
736-
if result.is_err() {
737-
missed_heartbeats += 1;
738-
log::warn!(
739-
"No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
740-
HEARTBEAT_TIMEOUT,
741-
missed_heartbeats,
742-
MAX_MISSED_HEARTBEATS
743-
);
744-
} else if missed_heartbeats != 0 {
745-
missed_heartbeats = 0;
746-
} else {
747-
continue;
748-
}
741+
select_biased! {
742+
_ = connection_activity_rx.next().fuse() => {
743+
keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
744+
}
745+
_ = keepalive_timer => {
746+
log::debug!("Sending heartbeat to server...");
749747

750-
let result = this.update(&mut cx, |this, mut cx| {
751-
this.handle_heartbeat_result(missed_heartbeats, &mut cx)
752-
})?;
753-
if result.is_break() {
754-
return Ok(());
748+
let result = select_biased! {
749+
_ = connection_activity_rx.next().fuse() => {
750+
Ok(())
751+
}
752+
ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
753+
ping_result
754+
}
755+
};
756+
if result.is_err() {
757+
missed_heartbeats += 1;
758+
log::warn!(
759+
"No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
760+
HEARTBEAT_TIMEOUT,
761+
missed_heartbeats,
762+
MAX_MISSED_HEARTBEATS
763+
);
764+
} else if missed_heartbeats != 0 {
765+
missed_heartbeats = 0;
766+
} else {
767+
continue;
768+
}
769+
770+
let result = this.update(&mut cx, |this, mut cx| {
771+
this.handle_heartbeat_result(missed_heartbeats, &mut cx)
772+
})?;
773+
if result.is_break() {
774+
return Ok(());
775+
}
776+
}
755777
}
756778
}
757779
}
@@ -792,6 +814,7 @@ impl SshRemoteClient {
792814
mut ssh_proxy_process: Child,
793815
incoming_tx: UnboundedSender<Envelope>,
794816
mut outgoing_rx: UnboundedReceiver<Envelope>,
817+
mut connection_activity_tx: Sender<()>,
795818
cx: &AsyncAppContext,
796819
) -> Task<Result<()>> {
797820
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
@@ -833,6 +856,7 @@ impl SshRemoteClient {
833856
let message_len = message_len_from_buffer(&stdout_buffer);
834857
match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
835858
Ok(envelope) => {
859+
connection_activity_tx.try_send(()).ok();
836860
incoming_tx.unbounded_send(envelope).ok();
837861
}
838862
Err(error) => {
@@ -863,6 +887,8 @@ impl SshRemoteClient {
863887
}
864888
stderr_buffer.drain(0..start_ix);
865889
stderr_offset -= start_ix;
890+
891+
connection_activity_tx.try_send(()).ok();
866892
}
867893
Err(error) => {
868894
Err(anyhow!("error reading stderr: {error:?}"))?;
@@ -1392,16 +1418,19 @@ impl ChannelClient {
13921418
cx.clone(),
13931419
) {
13941420
log::debug!("ssh message received. name:{type_name}");
1395-
match future.await {
1396-
Ok(_) => {
1397-
log::debug!("ssh message handled. name:{type_name}");
1398-
}
1399-
Err(error) => {
1400-
log::error!(
1401-
"error handling message. type:{type_name}, error:{error}",
1402-
);
1421+
cx.foreground_executor().spawn(async move {
1422+
match future.await {
1423+
Ok(_) => {
1424+
log::debug!("ssh message handled. name:{type_name}");
1425+
}
1426+
Err(error) => {
1427+
log::error!(
1428+
"error handling message. type:{type_name}, error:{error}",
1429+
);
1430+
}
14031431
}
1404-
}
1432+
}).detach();
1433+
14051434
} else {
14061435
log::error!("unhandled ssh message name:{type_name}");
14071436
}

0 commit comments

Comments
 (0)