Skip to content

Commit f2f7b75

Browse files
authored
feat: add the new WithOnSubscriptionHook for the Receive method (#846)
* feat: add the new WithOnSubscriptionHook for the Receive method Signed-off-by: Rueian <[email protected]> * feat: apply betteraligment Signed-off-by: Rueian <[email protected]> * feat: add the new WithOnSubscriptionHook for the Receive method Signed-off-by: Rueian <[email protected]> * feat: add the new WithOnSubscriptionHook for the Receive method Signed-off-by: Rueian <[email protected]> --------- Signed-off-by: Rueian <[email protected]>
1 parent 0244b0f commit f2f7b75

File tree

5 files changed

+160
-27
lines changed

5 files changed

+160
-27
lines changed

README.md

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,12 @@ To receive messages from channels, `client.Receive()` should be used. It support
277277

278278
```golang
279279
err = client.Receive(context.Background(), client.B().Subscribe().Channel("ch1", "ch2").Build(), func(msg rueidis.PubSubMessage) {
280-
// Handle the message. Note that if you want to call another `client.Do()` here, you need to do it in another goroutine or the `client` will be blocked.
280+
// Handle the message. If you need to perform heavy processing or issue
281+
// additional commands, do that in a separate goroutine to avoid
282+
// blocking the pipeline, e.g.:
283+
// go func() {
284+
// // long work or client.Do(...)
285+
// }()
281286
})
282287
```
283288

@@ -294,6 +299,28 @@ While the `client.Receive()` call is blocking, the `Client` is still able to acc
294299
and they are sharing the same TCP connection. If your message handler may take some time to complete, it is recommended
295300
to use the `client.Receive()` inside a `client.Dedicated()` for not blocking other concurrent requests.
296301

302+
#### Subscription confirmations
303+
304+
Use `rueidis.WithOnSubscriptionHook` when you need to observe subscribe / unsubscribe confirmations that the server sends during the lifetime of a `client.Receive()`.
305+
306+
The hook can be triggered multiple times because the `client.Receive()` may automatically reconnect and resubscribe.
307+
308+
```go
309+
ctx := rueidis.WithOnSubscriptionHook(context.Background(), func(s rueidis.PubSubSubscription) {
310+
// This hook runs in the pipeline goroutine. If you need to perform
311+
// heavy work or invoke additional commands, do it in another
312+
// goroutine to avoid blocking the pipeline, for example:
313+
// go func() {
314+
// // long work or client.Do(...)
315+
// }()
316+
fmt.Printf("%s %s (count %d)\n", s.Kind, s.Channel, s.Count)
317+
})
318+
319+
err := client.Receive(ctx, client.B().Subscribe().Channel("news").Build(), func(m rueidis.PubSubMessage) {
320+
// ...
321+
})
322+
```
323+
297324
### Alternative PubSub Hooks
298325

299326
The `client.Receive()` requires users to provide a subscription command in advance.
@@ -305,7 +332,12 @@ defer cancel()
305332

306333
wait := c.SetPubSubHooks(rueidis.PubSubHooks{
307334
OnMessage: func(m rueidis.PubSubMessage) {
308-
// Handle the message. Note that if you want to call another `c.Do()` here, you need to do it in another goroutine or the `c` will be blocked.
335+
// Handle the message. If you need to perform heavy processing or issue
336+
// additional commands, do that in a separate goroutine to avoid
337+
// blocking the pipeline, e.g.:
338+
// go func() {
339+
// // long work or client.Do(...)
340+
// }()
309341
}
310342
})
311343
c.Do(ctx, c.B().Subscribe().Channel("ch").Build())

pipe.go

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -720,26 +720,45 @@ func (p *pipe) handlePush(values []RedisMessage) (reply bool, unsubscribe bool)
720720
p.pshks.Load().(*pshks).hooks.OnMessage(m)
721721
}
722722
case "unsubscribe":
723-
p.nsubs.Unsubscribe(values[1].string())
724723
if len(values) >= 3 {
725-
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen})
724+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
725+
p.nsubs.Unsubscribe(s)
726+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
726727
}
727728
return true, true
728729
case "punsubscribe":
729-
p.psubs.Unsubscribe(values[1].string())
730730
if len(values) >= 3 {
731-
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen})
731+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
732+
p.psubs.Unsubscribe(s)
733+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
732734
}
733735
return true, true
734736
case "sunsubscribe":
735-
p.ssubs.Unsubscribe(values[1].string())
736737
if len(values) >= 3 {
737-
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen})
738+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
739+
p.ssubs.Unsubscribe(s)
740+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
738741
}
739742
return true, true
740-
case "subscribe", "psubscribe", "ssubscribe":
743+
case "subscribe":
741744
if len(values) >= 3 {
742-
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen})
745+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
746+
p.nsubs.Confirm(s)
747+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
748+
}
749+
return true, false
750+
case "psubscribe":
751+
if len(values) >= 3 {
752+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
753+
p.psubs.Confirm(s)
754+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
755+
}
756+
return true, false
757+
case "ssubscribe":
758+
if len(values) >= 3 {
759+
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
760+
p.ssubs.Confirm(s)
761+
p.pshks.Load().(*pshks).hooks.OnSubscription(s)
743762
}
744763
return true, false
745764
}
@@ -762,6 +781,25 @@ func (p *pipe) _r2pipe(ctx context.Context) (r2p *pipe) {
762781
return r2p
763782
}
764783

784+
type recvCtxKey int
785+
786+
const hookKey recvCtxKey = 0
787+
788+
// WithOnSubscriptionHook attaches a subscription confirmation hook to the provided
789+
// context and returns a new context for the Receive method.
790+
//
791+
// The hook is invoked each time the server sends a subscribe or
792+
// unsubscribe confirmation, allowing callers to observe the state of a Pub/Sub
793+
// subscription during the lifetime of a Receive invocation.
794+
//
795+
// The hook may be called multiple times because the client can resubscribe after a
796+
// reconnection. Therefore, the hook implementation must be safe to run more than once.
797+
// Also, there should not be any blocking operations or another `client.Do()` in the hook
798+
// since it runs in the same goroutine as the pipeline. Otherwise, the pipeline will be blocked.
799+
func WithOnSubscriptionHook(ctx context.Context, hook func(PubSubSubscription)) context.Context {
800+
return context.WithValue(ctx, hookKey, hook)
801+
}
802+
765803
func (p *pipe) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
766804
if p.nsubs == nil || p.psubs == nil || p.ssubs == nil {
767805
return p.Error()
@@ -787,7 +825,11 @@ func (p *pipe) Receive(ctx context.Context, subscribe Completed, fn func(message
787825
panic(wrongreceive)
788826
}
789827

790-
if ch, cancel := sb.Subscribe(args); ch != nil {
828+
var hook func(PubSubSubscription)
829+
if v := ctx.Value(hookKey); v != nil {
830+
hook = v.(func(PubSubSubscription))
831+
}
832+
if ch, cancel := sb.Subscribe(args, hook); ch != nil {
791833
defer cancel()
792834
if err := p.Do(ctx, subscribe).Error(); err != nil {
793835
return err

pipe_test.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2962,6 +2962,8 @@ func TestPubSub(t *testing.T) {
29622962
).Reply(strmsg('+', "PONG"))
29632963
}()
29642964

2965+
confirms := make(chan PubSubSubscription, 2)
2966+
ctx = WithOnSubscriptionHook(ctx, func(s PubSubSubscription) { confirms <- s })
29652967
if err := p.Receive(ctx, activate, func(msg PubSubMessage) {
29662968
if msg.Channel == "1" && msg.Message == "2" {
29672969
if err := p.Do(ctx, deactivate).Error(); err != nil {
@@ -2971,7 +2973,12 @@ func TestPubSub(t *testing.T) {
29712973
}); err != nil {
29722974
t.Fatalf("unexpected err %v", err)
29732975
}
2974-
2976+
if s := <-confirms; s.Kind != "subscribe" || s.Channel != "1" {
2977+
t.Fatalf("unexpected subscription %v", s)
2978+
}
2979+
if s := <-confirms; s.Kind != "unsubscribe" || s.Channel != "1" {
2980+
t.Fatalf("unexpected subscription %v", s)
2981+
}
29752982
cancel()
29762983
})
29772984

@@ -3003,6 +3010,8 @@ func TestPubSub(t *testing.T) {
30033010
).Reply(strmsg('+', "PONG"))
30043011
}()
30053012

3013+
confirms := make(chan PubSubSubscription, 2)
3014+
ctx = WithOnSubscriptionHook(ctx, func(s PubSubSubscription) { confirms <- s })
30063015
if err := p.Receive(ctx, activate, func(msg PubSubMessage) {
30073016
if msg.Channel == "1" && msg.Message == "2" {
30083017
if err := p.Do(ctx, deactivate).Error(); err != nil {
@@ -3012,7 +3021,12 @@ func TestPubSub(t *testing.T) {
30123021
}); err != nil {
30133022
t.Fatalf("unexpected err %v", err)
30143023
}
3015-
3024+
if s := <-confirms; s.Kind != "ssubscribe" || s.Channel != "1" {
3025+
t.Fatalf("unexpected subscription %v", s)
3026+
}
3027+
if s := <-confirms; s.Kind != "sunsubscribe" || s.Channel != "1" {
3028+
t.Fatalf("unexpected subscription %v", s)
3029+
}
30163030
cancel()
30173031
})
30183032

@@ -3045,6 +3059,8 @@ func TestPubSub(t *testing.T) {
30453059
).Reply(strmsg('+', "PONG"))
30463060
}()
30473061

3062+
confirms := make(chan PubSubSubscription, 2)
3063+
ctx = WithOnSubscriptionHook(ctx, func(s PubSubSubscription) { confirms <- s })
30483064
if err := p.Receive(ctx, activate, func(msg PubSubMessage) {
30493065
if msg.Pattern == "1" && msg.Channel == "2" && msg.Message == "3" {
30503066
if err := p.Do(ctx, deactivate).Error(); err != nil {
@@ -3054,7 +3070,12 @@ func TestPubSub(t *testing.T) {
30543070
}); err != nil {
30553071
t.Fatalf("unexpected err %v", err)
30563072
}
3057-
3073+
if s := <-confirms; s.Kind != "psubscribe" || s.Channel != "1" {
3074+
t.Fatalf("unexpected subscription %v", s)
3075+
}
3076+
if s := <-confirms; s.Kind != "punsubscribe" || s.Channel != "1" {
3077+
t.Fatalf("unexpected subscription %v", s)
3078+
}
30583079
cancel()
30593080
})
30603081

pubsub.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ type PubSubMessage struct {
1515
Message string
1616
}
1717

18-
// PubSubSubscription represent a pubsub "subscribe", "unsubscribe", "psubscribe" or "punsubscribe" event.
18+
// PubSubSubscription represent a pubsub "subscribe", "unsubscribe", "ssubscribe", "sunsubscribe", "psubscribe" or "punsubscribe" event.
1919
type PubSubSubscription struct {
20-
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
20+
// Kind is "subscribe", "unsubscribe", "ssubscribe", "sunsubscribe", "psubscribe" or "punsubscribe"
2121
Kind string
2222
// Channel is the event subject.
2323
Channel string
@@ -54,6 +54,7 @@ type chs struct {
5454

5555
type sub struct {
5656
ch chan PubSubMessage
57+
fn func(PubSubSubscription)
5758
cs []string
5859
}
5960

@@ -67,12 +68,12 @@ func (s *subs) Publish(channel string, msg PubSubMessage) {
6768
}
6869
}
6970

70-
func (s *subs) Subscribe(channels []string) (ch chan PubSubMessage, cancel func()) {
71+
func (s *subs) Subscribe(channels []string, fn func(PubSubSubscription)) (ch chan PubSubMessage, cancel func()) {
7172
id := atomic.AddUint64(&s.cnt, 1)
7273
s.mu.Lock()
7374
if s.chs != nil {
7475
ch = make(chan PubSubMessage, 16)
75-
sb := &sub{cs: channels, ch: ch}
76+
sb := &sub{cs: channels, ch: ch, fn: fn}
7677
s.sub[id] = sb
7778
for _, channel := range channels {
7879
c := s.chs[channel].sub
@@ -110,13 +111,28 @@ func (s *subs) remove(id uint64) {
110111
}
111112
}
112113

113-
func (s *subs) Unsubscribe(channel string) {
114+
func (s *subs) Confirm(sub PubSubSubscription) {
115+
if atomic.LoadUint64(&s.cnt) != 0 {
116+
s.mu.RLock()
117+
for _, sb := range s.chs[sub.Channel].sub {
118+
if sb.fn != nil {
119+
sb.fn(sub)
120+
}
121+
}
122+
s.mu.RUnlock()
123+
}
124+
}
125+
126+
func (s *subs) Unsubscribe(sub PubSubSubscription) {
114127
if atomic.LoadUint64(&s.cnt) != 0 {
115128
s.mu.Lock()
116-
for id := range s.chs[channel].sub {
129+
for id, sb := range s.chs[sub.Channel].sub {
130+
if sb.fn != nil {
131+
sb.fn(sub)
132+
}
117133
s.remove(id)
118134
}
119-
delete(s.chs, channel)
135+
delete(s.chs, sub.Channel)
120136
s.mu.Unlock()
121137
}
122138
}

pubsub_test.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,24 @@ func TestSubs_Publish(t *testing.T) {
1414

1515
t.Run("with multiple subs", func(t *testing.T) {
1616
s := newSubs()
17-
ch1, cancel1 := s.Subscribe([]string{"a"})
18-
ch2, cancel2 := s.Subscribe([]string{"a"})
19-
ch3, cancel3 := s.Subscribe([]string{"b"})
17+
counts := map[string]int{
18+
"a": 0,
19+
"b": 0,
20+
}
21+
subFn := func(s PubSubSubscription) {
22+
counts[s.Channel]++
23+
}
24+
25+
ch1, cancel1 := s.Subscribe([]string{"a"}, subFn)
26+
ch2, cancel2 := s.Subscribe([]string{"a"}, subFn)
27+
ch3, cancel3 := s.Subscribe([]string{"b"}, subFn)
28+
s.Confirm(PubSubSubscription{Channel: "a"})
29+
s.Confirm(PubSubSubscription{Channel: "b"})
30+
31+
if counts["a"] != 2 || counts["b"] != 1 {
32+
t.Fatalf("unexpected counts %v", counts)
33+
}
34+
2035
m1 := PubSubMessage{Pattern: "1", Channel: "2", Message: "3"}
2136
m2 := PubSubMessage{Pattern: "11", Channel: "22", Message: "33"}
2237
go func() {
@@ -45,7 +60,7 @@ func TestSubs_Publish(t *testing.T) {
4560

4661
t.Run("drain ch", func(t *testing.T) {
4762
s := newSubs()
48-
ch, cancel := s.Subscribe([]string{"a"})
63+
ch, cancel := s.Subscribe([]string{"a"}, nil)
4964
s.Publish("a", PubSubMessage{})
5065
if len(ch) != 1 {
5166
t.Fatalf("unexpected ch len %v", len(ch))
@@ -60,15 +75,22 @@ func TestSubs_Publish(t *testing.T) {
6075
func TestSubs_Unsubscribe(t *testing.T) {
6176
defer ShouldNotLeaked(SetupLeakDetection())
6277
s := newSubs()
63-
ch, _ := s.Subscribe([]string{"1", "2"})
78+
counts := map[string]int{"1": 0, "2": 0}
79+
subFn := func(s PubSubSubscription) {
80+
counts[s.Channel]++
81+
}
82+
ch, _ := s.Subscribe([]string{"1", "2"}, subFn)
6483
go func() {
6584
s.Publish("1", PubSubMessage{})
6685
}()
6786
_, ok := <-ch
6887
if !ok {
6988
t.Fatalf("unexpected ch closed")
7089
}
71-
s.Unsubscribe("1")
90+
s.Unsubscribe(PubSubSubscription{Channel: "1"})
91+
if counts["1"] != 1 {
92+
t.Fatalf("unexpected counts %v", counts)
93+
}
7294
_, ok = <-ch
7395
if ok {
7496
t.Fatalf("unexpected ch unclosed")

0 commit comments

Comments
 (0)