1+ use std:: collections:: HashMap ;
12use std:: hash:: { DefaultHasher , Hash , Hasher } ;
2- use std:: sync:: Arc ;
3+ use std:: sync:: { Arc , Mutex } ;
34
45use anyhow:: * ;
56use async_trait:: async_trait;
67use base64:: Engine ;
78use base64:: engine:: general_purpose:: STANDARD_NO_PAD as BASE64 ;
89use deadpool_postgres:: { Config , ManagerConfig , Pool , PoolConfig , RecyclingMethod , Runtime } ;
910use futures_util:: future:: poll_fn;
10- use moka:: future:: Cache ;
1111use tokio_postgres:: { AsyncMessage , NoTls } ;
1212use tracing:: Instrument ;
1313
@@ -18,6 +18,15 @@ use crate::pubsub::DriverOutput;
1818struct Subscription {
1919 // Channel to send messages to this subscription
2020 tx : tokio:: sync:: broadcast:: Sender < Vec < u8 > > ,
21+ // Cancellation token shared by all subscribers of this subject
22+ token : tokio_util:: sync:: CancellationToken ,
23+ }
24+
25+ impl Subscription {
26+ fn new ( tx : tokio:: sync:: broadcast:: Sender < Vec < u8 > > ) -> Self {
27+ let token = tokio_util:: sync:: CancellationToken :: new ( ) ;
28+ Self { tx, token }
29+ }
2130}
2231
2332/// > In the default configuration it must be shorter than 8000 bytes
@@ -40,7 +49,7 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
4049pub struct PostgresDriver {
4150 pool : Arc < Pool > ,
4251 client : Arc < tokio_postgres:: Client > ,
43- subscriptions : Cache < String , Subscription > ,
52+ subscriptions : Arc < Mutex < HashMap < String , Subscription > > > ,
4453}
4554
4655impl PostgresDriver {
@@ -65,8 +74,8 @@ impl PostgresDriver {
6574 . context ( "failed to create postgres pool" ) ?;
6675 tracing:: debug!( "postgres pool created successfully" ) ;
6776
68- let subscriptions: Cache < String , Subscription > =
69- Cache :: builder ( ) . initial_capacity ( 5 ) . build ( ) ;
77+ let subscriptions: Arc < Mutex < HashMap < String , Subscription > > > =
78+ Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
7079 let subscriptions2 = subscriptions. clone ( ) ;
7180
7281 let ( client, mut conn) = tokio_postgres:: connect ( & conn_str, tokio_postgres:: NoTls ) . await ?;
@@ -75,7 +84,9 @@ impl PostgresDriver {
7584 loop {
7685 match poll_fn ( |cx| conn. poll_message ( cx) ) . await {
7786 Some ( std:: result:: Result :: Ok ( AsyncMessage :: Notification ( note) ) ) => {
78- if let Some ( sub) = subscriptions2. get ( note. channel ( ) ) . await {
87+ if let Some ( sub) =
88+ subscriptions2. lock ( ) . unwrap ( ) . get ( note. channel ( ) ) . cloned ( )
89+ {
7990 let bytes = match BASE64 . decode ( note. payload ( ) ) {
8091 std:: result:: Result :: Ok ( b) => b,
8192 std:: result:: Result :: Err ( err) => {
@@ -121,7 +132,7 @@ impl PostgresDriver {
121132#[ async_trait]
122133impl PubSubDriver for PostgresDriver {
123134 async fn subscribe ( & self , subject : & str ) -> Result < SubscriberDriverHandle > {
124- // TODO: To match NATS implementation, LIST must be pipelined (i.e. wait for the command
135+ // TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command
125136 // to reach the server, but not wait for it to respond). However, this has to ensure that
126137 // NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or
127138 // else there will be race conditions where messages might be published before
@@ -135,33 +146,57 @@ impl PubSubDriver for PostgresDriver {
135146 let hashed = self . hash_subject ( subject) ;
136147
137148 // Check if we already have a subscription for this channel
138- let rx = if let Some ( existing_sub) = self . subscriptions . get ( & hashed) . await {
139- // Reuse the existing broadcast channel
140- existing_sub. tx . subscribe ( )
141- } else {
142- // Create a new broadcast channel for this subject
143- let ( tx, rx) = tokio:: sync:: broadcast:: channel ( 1024 ) ;
144- let subscription = Subscription { tx : tx. clone ( ) } ;
145-
146- // Register subscription
147- self . subscriptions
148- . insert ( hashed. clone ( ) , subscription)
149- . await ;
150-
151- // Execute LISTEN command on the async client (for receiving notifications)
152- // This only needs to be done once per channel
153- let span = tracing:: trace_span!( "pg_listen" ) ;
154- self . client
155- . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
156- . instrument ( span)
157- . await ?;
158-
159- rx
160- } ;
149+ let ( rx, drop_guard) =
150+ if let Some ( existing_sub) = self . subscriptions . lock ( ) . unwrap ( ) . get ( & hashed) . cloned ( ) {
151+ // Reuse the existing broadcast channel
152+ let rx = existing_sub. tx . subscribe ( ) ;
153+ let drop_guard = existing_sub. token . clone ( ) . drop_guard ( ) ;
154+ ( rx, drop_guard)
155+ } else {
156+ // Create a new broadcast channel for this subject
157+ let ( tx, rx) = tokio:: sync:: broadcast:: channel ( 1024 ) ;
158+ let subscription = Subscription :: new ( tx. clone ( ) ) ;
159+
160+ // Register subscription
161+ self . subscriptions
162+ . lock ( )
163+ . unwrap ( )
164+ . insert ( hashed. clone ( ) , subscription. clone ( ) ) ;
165+
166+ // Execute LISTEN command on the async client (for receiving notifications)
167+ // This only needs to be done once per channel
168+ let span = tracing:: trace_span!( "pg_listen" ) ;
169+ self . client
170+ . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
171+ . instrument ( span)
172+ . await ?;
173+
174+ // Spawn a single cleanup task for this subscription waiting on its token
175+ let driver = self . clone ( ) ;
176+ let hashed_clone = hashed. clone ( ) ;
177+ let tx_clone = tx. clone ( ) ;
178+ let token_clone = subscription. token . clone ( ) ;
179+ tokio:: spawn ( async move {
180+ token_clone. cancelled ( ) . await ;
181+ if tx_clone. receiver_count ( ) == 0 {
182+ let sql = format ! ( "UNLISTEN \" {}\" " , hashed_clone) ;
183+ if let Err ( err) = driver. client . execute ( sql. as_str ( ) , & [ ] ) . await {
184+ tracing:: warn!( ?err, %hashed_clone, "failed to UNLISTEN channel" ) ;
185+ } else {
186+ tracing:: trace!( %hashed_clone, "unlistened channel" ) ;
187+ }
188+ driver. subscriptions . lock ( ) . unwrap ( ) . remove ( & hashed_clone) ;
189+ }
190+ } ) ;
191+
192+ let drop_guard = subscription. token . clone ( ) . drop_guard ( ) ;
193+ ( rx, drop_guard)
194+ } ;
161195
162196 Ok ( Box :: new ( PostgresSubscriber {
163197 subject : subject. to_string ( ) ,
164- rx,
198+ rx : Some ( rx) ,
199+ _drop_guard : drop_guard,
165200 } ) )
166201 }
167202
@@ -191,13 +226,18 @@ impl PubSubDriver for PostgresDriver {
191226
192227pub struct PostgresSubscriber {
193228 subject : String ,
194- rx : tokio:: sync:: broadcast:: Receiver < Vec < u8 > > ,
229+ rx : Option < tokio:: sync:: broadcast:: Receiver < Vec < u8 > > > ,
230+ _drop_guard : tokio_util:: sync:: DropGuard ,
195231}
196232
197233#[ async_trait]
198234impl SubscriberDriver for PostgresSubscriber {
199235 async fn next ( & mut self ) -> Result < DriverOutput > {
200- match self . rx . recv ( ) . await {
236+ let rx = match self . rx . as_mut ( ) {
237+ Some ( rx) => rx,
238+ None => return Ok ( DriverOutput :: Unsubscribed ) ,
239+ } ;
240+ match rx. recv ( ) . await {
201241 std:: result:: Result :: Ok ( payload) => Ok ( DriverOutput :: Message {
202242 subject : self . subject . clone ( ) ,
203243 payload,
0 commit comments