@@ -112,21 +112,23 @@ impl PostgresDriver {
112112 client : Arc < Mutex < Option < Arc < tokio_postgres:: Client > > > > ,
113113 ready_tx : tokio:: sync:: watch:: Sender < bool > ,
114114 ) {
115- let mut backoff = Backoff :: new ( 8 , None , 1_000 , 1_000 ) ;
115+ let mut backoff = Backoff :: default ( ) ;
116116
117117 loop {
118118 match tokio_postgres:: connect ( & conn_str, tokio_postgres:: NoTls ) . await {
119119 Result :: Ok ( ( new_client, conn) ) => {
120120 tracing:: info!( "postgres listen connection established" ) ;
121121 // Reset backoff on successful connection
122- backoff = Backoff :: new ( 8 , None , 1_000 , 1_000 ) ;
122+ backoff = Backoff :: default ( ) ;
123123
124124 let new_client = Arc :: new ( new_client) ;
125125
126- // Update the client reference immediately
127- * client. lock ( ) . await = Some ( new_client. clone ( ) ) ;
128- // Notify that client is ready
129- let _ = ready_tx. send ( true ) ;
126+ // Spawn the polling task immediately
127+ // This must be done before any operations on the client, since client queries only make progress while the connection is being polled
128+ let subscriptions_clone = subscriptions. clone ( ) ;
129+ let poll_handle = tokio:: spawn ( async move {
130+ Self :: poll_connection ( conn, subscriptions_clone) . await ;
131+ } ) ;
130132
131133 // Get channels to re-subscribe to
132134 let channels: Vec < String > =
@@ -135,38 +137,41 @@ impl PostgresDriver {
135137
136138 if needs_resubscribe {
137139 tracing:: debug!(
138- ?channels,
140+ channels= ?channels. len ( ) ,
139141 "will re-subscribe to channels after connection starts"
140142 ) ;
141143 }
142144
143- // Spawn a task to re -subscribe after a short delay
145+ // Re-subscribe to channels
144146 if needs_resubscribe {
145- let client_for_resub = new_client. clone ( ) ;
146- let channels_clone = channels. clone ( ) ;
147- tokio:: spawn ( async move {
148- tracing:: debug!(
149- ?channels_clone,
150- "re-subscribing to channels after reconnection"
151- ) ;
152- for channel in & channels_clone {
153- if let Result :: Err ( e) = client_for_resub
154- . execute ( & format ! ( "LISTEN \" {}\" " , channel) , & [ ] )
155- . await
156- {
157- tracing:: error!( ?e, %channel, "failed to re-subscribe to channel" ) ;
158- } else {
159- tracing:: debug!( %channel, "successfully re-subscribed to channel" ) ;
160- }
147+ tracing:: debug!(
148+ channels=?channels. len( ) ,
149+ "re-subscribing to channels after reconnection"
150+ ) ;
151+ for channel in & channels {
152+ tracing:: info!( ?channel, "re-subscribing to channel" ) ;
153+ if let Result :: Err ( e) = new_client
154+ . execute ( & format ! ( "LISTEN \" {}\" " , channel) , & [ ] )
155+ . await
156+ {
157+ tracing:: error!( ?e, %channel, "failed to re-subscribe to channel" ) ;
158+ } else {
159+ tracing:: debug!( %channel, "successfully re-subscribed to channel" ) ;
161160 }
162- } ) ;
161+ }
163162 }
164163
165- // Poll the connection until it closes
166- Self :: poll_connection ( conn, subscriptions. clone ( ) ) . await ;
164+ // Update the client reference and signal ready
165+ // Do this AFTER re-subscribing to ensure LISTEN is complete
166+ * client. lock ( ) . await = Some ( new_client. clone ( ) ) ;
167+ let _ = ready_tx. send ( true ) ;
168+
169+ // Wait for the polling task to complete (when the connection closes)
170+ let _ = poll_handle. await ;
167171
168172 // Clear the client reference on disconnect
169173 * client. lock ( ) . await = None ;
174+
170175 // Notify that client is disconnected
171176 let _ = ready_tx. send ( false ) ;
172177 }
@@ -208,12 +213,12 @@ impl PostgresDriver {
208213 // Ignore other async messages
209214 }
210215 Some ( std:: result:: Result :: Err ( err) ) => {
211- tracing:: error!( ?err, "postgres connection error, reconnecting " ) ;
212- break ; // Exit loop to reconnect
216+ tracing:: error!( ?err, "postgres connection error" ) ;
217+ break ;
213218 }
214219 None => {
215- tracing:: warn!( "postgres connection closed, reconnecting " ) ;
216- break ; // Exit loop to reconnect
220+ tracing:: warn!( "postgres connection closed" ) ;
221+ break ;
217222 }
218223 }
219224 }
@@ -224,19 +229,16 @@ impl PostgresDriver {
224229 let mut ready_rx = self . client_ready . clone ( ) ;
225230 tokio:: time:: timeout ( tokio:: time:: Duration :: from_secs ( 5 ) , async {
226231 loop {
227- // Subscribe to changed before attempting to access client
228- let changed_fut = ready_rx. changed ( ) ;
229-
230232 // Check if client is already available
231233 if let Some ( client) = self . client . lock ( ) . await . clone ( ) {
232234 return Ok ( client) ;
233235 }
234236
235- // Wait for change, will return client if exists on next iteration
236- changed_fut
237+ // Wait for the ready signal to change
238+ ready_rx
239+ . changed ( )
237240 . await
238241 . map_err ( |_| anyhow ! ( "connection lifecycle task ended" ) ) ?;
239- tracing:: debug!( "client does not exist immediately after receive ready" ) ;
240242 }
241243 } )
242244 . await
@@ -288,13 +290,25 @@ impl PubSubDriver for PostgresDriver {
288290
289291 // Execute LISTEN command on the async client (for receiving notifications)
290292 // This only needs to be done once per channel
291- // Wait for client to be connected with retry logic
292- let client = self . wait_for_client ( ) . await ?;
293- let span = tracing:: trace_span!( "pg_listen" ) ;
294- client
295- . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
296- . instrument ( span)
297- . await ?;
293+ // Try to LISTEN if client is available, but don't fail if disconnected
294+ // The reconnection logic will handle re-subscribing
295+ if let Some ( client) = self . client . lock ( ) . await . clone ( ) {
296+ let span = tracing:: trace_span!( "pg_listen" ) ;
297+ match client
298+ . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
299+ . instrument ( span)
300+ . await
301+ {
302+ Result :: Ok ( _) => {
303+ tracing:: debug!( %hashed, "successfully subscribed to channel" ) ;
304+ }
305+ Result :: Err ( e) => {
306+ tracing:: warn!( ?e, %hashed, "failed to LISTEN, will retry on reconnection" ) ;
307+ }
308+ }
309+ } else {
310+ tracing:: debug!( %hashed, "client not connected, will LISTEN on reconnection" ) ;
311+ }
298312
299313 // Spawn a single cleanup task for this subscription waiting on its token
300314 let driver = self . clone ( ) ;
@@ -333,14 +347,66 @@ impl PubSubDriver for PostgresDriver {
333347
334348 // Encode payload to base64 and send NOTIFY
335349 let encoded = BASE64 . encode ( payload) ;
336- let conn = self . pool . get ( ) . await ?;
337350 let hashed = self . hash_subject ( subject) ;
338- let span = tracing:: trace_span!( "pg_notify" ) ;
339- conn. execute ( & format ! ( "NOTIFY \" {hashed}\" , '{encoded}'" ) , & [ ] )
340- . instrument ( span)
341- . await ?;
342351
343- Ok ( ( ) )
352+ tracing:: debug!( "attempting to get connection for publish" ) ;
353+
354+ // Wait for listen connection to be ready first if this channel has subscribers
355+ // This ensures that if we're reconnecting, the LISTEN is re-registered before NOTIFY
356+ if self . subscriptions . lock ( ) . await . contains_key ( & hashed) {
357+ self . wait_for_client ( ) . await ?;
358+ }
359+
360+ // Retry getting a connection from the pool with backoff in case the connection is
361+ // currently disconnected
362+ let mut backoff = Backoff :: default ( ) ;
363+ let mut last_error = None ;
364+
365+ loop {
366+ match self . pool . get ( ) . await {
367+ Result :: Ok ( conn) => {
368+ // Test the connection with a simple query before using it
369+ match conn. execute ( "SELECT 1" , & [ ] ) . await {
370+ Result :: Ok ( _) => {
371+ // Connection is good, use it for NOTIFY
372+ let span = tracing:: trace_span!( "pg_notify" ) ;
373+ match conn
374+ . execute ( & format ! ( "NOTIFY \" {hashed}\" , '{encoded}'" ) , & [ ] )
375+ . instrument ( span)
376+ . await
377+ {
378+ Result :: Ok ( _) => return Ok ( ( ) ) ,
379+ Result :: Err ( e) => {
380+ tracing:: debug!(
381+ ?e,
382+ "NOTIFY failed, retrying with new connection"
383+ ) ;
384+ last_error = Some ( e. into ( ) ) ;
385+ }
386+ }
387+ }
388+ Result :: Err ( e) => {
389+ tracing:: debug!(
390+ ?e,
391+ "connection test failed, retrying with new connection"
392+ ) ;
393+ last_error = Some ( e. into ( ) ) ;
394+ }
395+ }
396+ }
397+ Result :: Err ( e) => {
398+ tracing:: debug!( ?e, "failed to get connection from pool, retrying" ) ;
399+ last_error = Some ( e. into ( ) ) ;
400+ }
401+ }
402+
403+ // Check if we should continue retrying
404+ if !backoff. tick ( ) . await {
405+ return Err (
406+ last_error. unwrap_or_else ( || anyhow ! ( "failed to publish after retries" ) )
407+ ) ;
408+ }
409+ }
344410 }
345411
346412 async fn flush ( & self ) -> Result < ( ) > {
0 commit comments