@@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result};
9
9
use collections:: HashMap ;
10
10
use futures:: {
11
11
channel:: {
12
- mpsc:: { self , UnboundedReceiver , UnboundedSender } ,
12
+ mpsc:: { self , Sender , UnboundedReceiver , UnboundedSender } ,
13
13
oneshot,
14
14
} ,
15
15
future:: BoxFuture ,
@@ -28,7 +28,6 @@ use rpc::{
28
28
use smol:: {
29
29
fs,
30
30
process:: { self , Child , Stdio } ,
31
- Timer ,
32
31
} ;
33
32
use std:: {
34
33
any:: TypeId ,
@@ -441,6 +440,7 @@ impl SshRemoteClient {
441
440
cx. spawn ( |mut cx| async move {
442
441
let ( outgoing_tx, outgoing_rx) = mpsc:: unbounded :: < Envelope > ( ) ;
443
442
let ( incoming_tx, incoming_rx) = mpsc:: unbounded :: < Envelope > ( ) ;
443
+ let ( connection_activity_tx, connection_activity_rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
444
444
445
445
let client = cx. update ( |cx| ChannelClient :: new ( incoming_rx, outgoing_tx, cx) ) ?;
446
446
let this = cx. new_model ( |_| Self {
@@ -467,6 +467,7 @@ impl SshRemoteClient {
467
467
ssh_proxy_process,
468
468
proxy_incoming_tx,
469
469
proxy_outgoing_rx,
470
+ connection_activity_tx,
470
471
& mut cx,
471
472
) ;
472
473
@@ -476,7 +477,7 @@ impl SshRemoteClient {
476
477
return Err ( error) ;
477
478
}
478
479
479
- let heartbeat_task = Self :: heartbeat ( this. downgrade ( ) , & mut cx) ;
480
+ let heartbeat_task = Self :: heartbeat ( this. downgrade ( ) , connection_activity_rx , & mut cx) ;
480
481
481
482
this. update ( & mut cx, |this, _| {
482
483
* this. state . lock ( ) = Some ( State :: Connected {
@@ -518,7 +519,7 @@ impl SshRemoteClient {
518
519
// We wait 50ms instead of waiting for a response, because
519
520
// waiting for a response would require us to wait on the main thread
520
521
// which we want to avoid in an `on_app_quit` callback.
521
- Timer :: after ( Duration :: from_millis ( 50 ) ) . await ;
522
+ smol :: Timer :: after ( Duration :: from_millis ( 50 ) ) . await ;
522
523
}
523
524
524
525
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
@@ -632,6 +633,7 @@ impl SshRemoteClient {
632
633
let ( incoming_tx, outgoing_rx) = forwarder. into_channels ( ) . await ;
633
634
let ( forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
634
635
ChannelForwarder :: new ( incoming_tx, outgoing_rx, & mut cx) ;
636
+ let ( connection_activity_tx, connection_activity_rx) = mpsc:: channel :: < ( ) > ( 1 ) ;
635
637
636
638
let ( ssh_connection, ssh_process) = match Self :: establish_connection (
637
639
identifier,
@@ -653,6 +655,7 @@ impl SshRemoteClient {
653
655
ssh_process,
654
656
proxy_incoming_tx,
655
657
proxy_outgoing_rx,
658
+ connection_activity_tx,
656
659
& mut cx,
657
660
) ;
658
661
@@ -665,7 +668,7 @@ impl SshRemoteClient {
665
668
delegate,
666
669
forwarder,
667
670
multiplex_task,
668
- heartbeat_task : Self :: heartbeat ( this. clone ( ) , & mut cx) ,
671
+ heartbeat_task : Self :: heartbeat ( this. clone ( ) , connection_activity_rx , & mut cx) ,
669
672
}
670
673
} ) ;
671
674
@@ -717,41 +720,60 @@ impl SshRemoteClient {
717
720
Ok ( ( ) )
718
721
}
719
722
720
- fn heartbeat ( this : WeakModel < Self > , cx : & mut AsyncAppContext ) -> Task < Result < ( ) > > {
723
+ fn heartbeat (
724
+ this : WeakModel < Self > ,
725
+ mut connection_activity_rx : mpsc:: Receiver < ( ) > ,
726
+ cx : & mut AsyncAppContext ,
727
+ ) -> Task < Result < ( ) > > {
721
728
let Ok ( client) = this. update ( cx, |this, _| this. client . clone ( ) ) else {
722
729
return Task :: ready ( Err ( anyhow ! ( "SshRemoteClient lost" ) ) ) ;
723
730
} ;
731
+
724
732
cx. spawn ( |mut cx| {
725
733
let this = this. clone ( ) ;
726
734
async move {
727
735
let mut missed_heartbeats = 0 ;
728
736
729
- let mut timer = Timer :: interval ( HEARTBEAT_INTERVAL ) ;
737
+ let keepalive_timer = cx. background_executor ( ) . timer ( HEARTBEAT_INTERVAL ) . fuse ( ) ;
738
+ futures:: pin_mut!( keepalive_timer) ;
739
+
730
740
loop {
731
- timer. next ( ) . await ;
732
-
733
- log:: debug!( "Sending heartbeat to server..." ) ;
734
-
735
- let result = client. ping ( HEARTBEAT_TIMEOUT ) . await ;
736
- if result. is_err ( ) {
737
- missed_heartbeats += 1 ;
738
- log:: warn!(
739
- "No heartbeat from server after {:?}. Missed heartbeat {} out of {}." ,
740
- HEARTBEAT_TIMEOUT ,
741
- missed_heartbeats,
742
- MAX_MISSED_HEARTBEATS
743
- ) ;
744
- } else if missed_heartbeats != 0 {
745
- missed_heartbeats = 0 ;
746
- } else {
747
- continue ;
748
- }
741
+ select_biased ! {
742
+ _ = connection_activity_rx. next( ) . fuse( ) => {
743
+ keepalive_timer. set( cx. background_executor( ) . timer( HEARTBEAT_INTERVAL ) . fuse( ) ) ;
744
+ }
745
+ _ = keepalive_timer => {
746
+ log:: debug!( "Sending heartbeat to server..." ) ;
749
747
750
- let result = this. update ( & mut cx, |this, mut cx| {
751
- this. handle_heartbeat_result ( missed_heartbeats, & mut cx)
752
- } ) ?;
753
- if result. is_break ( ) {
754
- return Ok ( ( ) ) ;
748
+ let result = select_biased! {
749
+ _ = connection_activity_rx. next( ) . fuse( ) => {
750
+ Ok ( ( ) )
751
+ }
752
+ ping_result = client. ping( HEARTBEAT_TIMEOUT ) . fuse( ) => {
753
+ ping_result
754
+ }
755
+ } ;
756
+ if result. is_err( ) {
757
+ missed_heartbeats += 1 ;
758
+ log:: warn!(
759
+ "No heartbeat from server after {:?}. Missed heartbeat {} out of {}." ,
760
+ HEARTBEAT_TIMEOUT ,
761
+ missed_heartbeats,
762
+ MAX_MISSED_HEARTBEATS
763
+ ) ;
764
+ } else if missed_heartbeats != 0 {
765
+ missed_heartbeats = 0 ;
766
+ } else {
767
+ continue ;
768
+ }
769
+
770
+ let result = this. update( & mut cx, |this, mut cx| {
771
+ this. handle_heartbeat_result( missed_heartbeats, & mut cx)
772
+ } ) ?;
773
+ if result. is_break( ) {
774
+ return Ok ( ( ) ) ;
775
+ }
776
+ }
755
777
}
756
778
}
757
779
}
@@ -792,6 +814,7 @@ impl SshRemoteClient {
792
814
mut ssh_proxy_process : Child ,
793
815
incoming_tx : UnboundedSender < Envelope > ,
794
816
mut outgoing_rx : UnboundedReceiver < Envelope > ,
817
+ mut connection_activity_tx : Sender < ( ) > ,
795
818
cx : & AsyncAppContext ,
796
819
) -> Task < Result < ( ) > > {
797
820
let mut child_stderr = ssh_proxy_process. stderr . take ( ) . unwrap ( ) ;
@@ -833,6 +856,7 @@ impl SshRemoteClient {
833
856
let message_len = message_len_from_buffer( & stdout_buffer) ;
834
857
match read_message_with_len( & mut child_stdout, & mut stdout_buffer, message_len) . await {
835
858
Ok ( envelope) => {
859
+ connection_activity_tx. try_send( ( ) ) . ok( ) ;
836
860
incoming_tx. unbounded_send( envelope) . ok( ) ;
837
861
}
838
862
Err ( error) => {
@@ -863,6 +887,8 @@ impl SshRemoteClient {
863
887
}
864
888
stderr_buffer. drain( 0 ..start_ix) ;
865
889
stderr_offset -= start_ix;
890
+
891
+ connection_activity_tx. try_send( ( ) ) . ok( ) ;
866
892
}
867
893
Err ( error) => {
868
894
Err ( anyhow!( "error reading stderr: {error:?}" ) ) ?;
@@ -1392,16 +1418,19 @@ impl ChannelClient {
1392
1418
cx. clone ( ) ,
1393
1419
) {
1394
1420
log:: debug!( "ssh message received. name:{type_name}" ) ;
1395
- match future. await {
1396
- Ok ( _) => {
1397
- log:: debug!( "ssh message handled. name:{type_name}" ) ;
1398
- }
1399
- Err ( error) => {
1400
- log:: error!(
1401
- "error handling message. type:{type_name}, error:{error}" ,
1402
- ) ;
1421
+ cx. foreground_executor ( ) . spawn ( async move {
1422
+ match future. await {
1423
+ Ok ( _) => {
1424
+ log:: debug!( "ssh message handled. name:{type_name}" ) ;
1425
+ }
1426
+ Err ( error) => {
1427
+ log:: error!(
1428
+ "error handling message. type:{type_name}, error:{error}" ,
1429
+ ) ;
1430
+ }
1403
1431
}
1404
- }
1432
+ } ) . detach ( ) ;
1433
+
1405
1434
} else {
1406
1435
log:: error!( "unhandled ssh message name:{type_name}" ) ;
1407
1436
}
0 commit comments