Skip to content

Commit 915979b

Browse files
committed
change WaitTimeout to WaitWithContext
1 parent b6917b1 commit 915979b

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

watcher.go

+6-12
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
package rpc
55

66
import (
7+
"context"
78
"errors"
89
"sync"
910
"sync/atomic"
10-
"time"
1111
)
1212

1313
// ErrWatcherShutdown is returned when the watcher is shut down.
@@ -35,8 +35,8 @@ func freeEvent(e *event) {
3535
type Watcher interface {
3636
// Wait will return value when the key is triggered.
3737
Wait() ([]byte, error)
38-
// WaitTimeout acts like Wait but takes a timeout.
39-
WaitTimeout(time.Duration) ([]byte, error)
38+
// WaitWithContext acts like Wait but takes a context.
39+
WaitWithContext(context.Context) ([]byte, error)
4040
// Stop stops the watch.
4141
Stop() error
4242
}
@@ -89,22 +89,16 @@ func (w *watcher) Wait() (value []byte, err error) {
8989
return
9090
}
9191

92-
func (w *watcher) WaitTimeout(timeout time.Duration) (value []byte, err error) {
93-
if timeout <= 0 {
94-
return w.Wait()
95-
}
96-
timer := time.NewTimer(timeout)
92+
func (w *watcher) WaitWithContext(ctx context.Context) (value []byte, err error) {
9793
select {
9894
case e := <-w.C:
99-
timer.Stop()
10095
w.triggerNext()
10196
value = e.Value
10297
err = e.Error
10398
freeEvent(e)
104-
case <-timer.C:
105-
err = ErrTimeout
99+
case <-ctx.Done():
100+
err = ctx.Err()
106101
case <-w.done:
107-
timer.Stop()
108102
err = ErrWatcherShutdown
109103
}
110104
return

watcher_test.go

+24-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package rpc
55

66
import (
7+
"context"
78
"testing"
89
"time"
910
)
@@ -16,7 +17,7 @@ func TestWatcherTrigger(t *testing.T) {
1617
watcher.trigger(e)
1718
}
1819
for i := byte(0); i < 255; i++ {
19-
v, err := watcher.WaitTimeout(0)
20+
v, err := watcher.Wait()
2021
if err != nil {
2122
t.Error(err)
2223
} else if len(v) == 0 {
@@ -42,29 +43,44 @@ func TestWatcherTriggerTimeout(t *testing.T) {
4243
watcher.trigger(e)
4344
}
4445
for i := byte(0); i < 255; i++ {
45-
v, err := watcher.WaitTimeout(time.Minute)
46+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
47+
v, err := watcher.WaitWithContext(ctx)
4648
if err != nil {
4749
t.Error(err)
4850
} else if len(v) == 0 {
4951
t.Error("len == 0")
5052
} else if v[0] != i {
5153
t.Error("out of order")
5254
}
55+
cancel()
5356
}
5457
go func() {
55-
close(watcher.done)
58+
watcher.stop()
5659
}()
57-
_, err := watcher.WaitTimeout(time.Minute)
60+
_, err := watcher.Wait()
5861
if err != ErrWatcherShutdown {
5962
t.Error(err)
6063
}
6164
}
6265

6366
func TestWatcherTriggerTimeoutErr(t *testing.T) {
6467
watcher := &watcher{C: make(chan *event, 10), done: make(chan struct{}, 1)}
65-
_, err := watcher.WaitTimeout(time.Millisecond * 1)
66-
if err != ErrTimeout {
67-
t.Error(err)
68+
{
69+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1)
70+
_, err := watcher.WaitWithContext(ctx)
71+
if err == nil {
72+
t.Error()
73+
}
74+
cancel()
6875
}
69-
close(watcher.done)
76+
watcher.stop()
77+
{
78+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1)
79+
_, err := watcher.WaitWithContext(ctx)
80+
if err != ErrWatcherShutdown {
81+
t.Error(err)
82+
}
83+
cancel()
84+
}
85+
7086
}

0 commit comments

Comments
 (0)