@@ -113,8 +113,7 @@ impl PubSubDriver for PostgresDriver {
113113 tracing:: debug!( %subject, ?lock_id, "calculated advisory lock id" ) ;
114114
115115 // Create a single connection for both subscription and lock holding
116- let ( client, mut connection) =
117- tokio_postgres:: connect ( & self . conn_str , tokio_postgres:: NoTls ) . await ?;
116+ let ( client, mut connection) = pg_connect ( & self . conn_str ) . await ?;
118117
119118 // Set up message forwarding
120119 let ( tx, rx) = tokio:: sync:: mpsc:: unbounded_channel :: < String > ( ) ;
@@ -155,7 +154,7 @@ impl PubSubDriver for PostgresDriver {
155154 let listen_subject = subject_owned. clone ( ) ;
156155
157156 // Spawn task to handle connection, lock acquisition, and LISTEN
158- tokio:: spawn ( async move {
157+ let poll_handle = tokio:: spawn ( async move {
159158 // First acquire the lock while polling the connection
160159 let lock_sql = format ! ( "SELECT pg_try_advisory_lock_shared({})" , lock_id) ;
161160 let lock_future = client_clone. query_one ( & lock_sql, & [ ] ) ;
@@ -265,6 +264,7 @@ impl PubSubDriver for PostgresDriver {
265264 lock_id,
266265 client,
267266 subject : subject. to_string ( ) ,
267+ poll_handle,
268268 } ) )
269269 }
270270
@@ -419,16 +419,15 @@ impl PubSubDriver for PostgresDriver {
419419 // Create a temporary reply subject and a dedicated listener connection
420420 let reply_subject = format ! ( "_INBOX.{}" , uuid:: Uuid :: new_v4( ) ) ;
421421
422- let ( client, mut connection) =
423- tokio_postgres:: connect ( & self . conn_str , tokio_postgres:: NoTls ) . await ?;
422+ let ( client, mut connection) = pg_connect ( & self . conn_str ) . await ?;
424423
425424 // Setup connection and LISTEN in a task
426425 let ( listen_done_tx, listen_done_rx) = tokio:: sync:: oneshot:: channel ( ) ;
427426 let reply_subject_clone = reply_subject. clone ( ) ;
428427
429428 // Spawn task to handle connection and LISTEN
430429 let ( response_tx, mut response_rx) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
431- tokio:: spawn ( async move {
430+ let poll_handle = tokio:: spawn ( async move {
432431 // Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes
433432 let mut hasher = DefaultHasher :: new ( ) ;
434433 reply_subject_clone. hash ( & mut hasher) ;
@@ -513,14 +512,19 @@ impl PubSubDriver for PostgresDriver {
513512 } ;
514513
515514 // Apply timeout if specified
516- if let Some ( dur) = timeout {
515+ let res = if let Some ( dur) = timeout {
517516 match tokio:: time:: timeout ( dur, response_future) . await {
518517 std:: result:: Result :: Ok ( resp) => resp,
519518 std:: result:: Result :: Err ( _) => Err ( errors:: Ups :: RequestTimeout . build ( ) . into ( ) ) ,
520519 }
521520 } else {
522521 response_future. await
523- }
522+ } ;
523+
524+ // Stop poll loop
525+ poll_handle. abort ( ) ;
526+
527+ res
524528 }
525529
526530 async fn send_request_reply ( & self , reply : & str , payload : & [ u8 ] ) -> Result < ( ) > {
@@ -599,6 +603,7 @@ pub struct PostgresSubscriber {
599603 lock_id : i64 ,
600604 client : Arc < tokio_postgres:: Client > ,
601605 subject : String ,
606+ poll_handle : tokio:: task:: JoinHandle < ( ) > ,
602607}
603608
604609#[ async_trait]
@@ -716,5 +721,22 @@ impl Drop for PostgresSubscriber {
716721 . execute ( "SELECT pg_advisory_unlock_shared($1)" , & [ & lock_id] )
717722 . await ;
718723 } ) ;
724+
725+ // Stop polling task
726+ self . poll_handle . abort ( ) ;
719727 }
720728}
729+
730+ async fn pg_connect (
731+ conn_str : & str ,
732+ ) -> Result < (
733+ tokio_postgres:: Client ,
734+ tokio_postgres:: Connection <
735+ tokio_postgres:: Socket ,
736+ <tokio_postgres:: NoTls as tokio_postgres:: tls:: TlsConnect < tokio_postgres:: Socket > >:: Stream ,
737+ > ,
738+ ) > {
739+ let ( client, conn) = tokio_postgres:: connect ( conn_str, tokio_postgres:: NoTls ) . await ?;
740+
741+ Ok ( ( client, conn) )
742+ }
0 commit comments