Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions v2/delivery/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package delivery
import (
"net/http"
"net/url"
"sync/atomic"
"time"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -60,19 +61,19 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
// operation.
silent := false
var silent int32
go func() {
select {
case <-stopC:
silent = true
atomic.StoreInt32(&silent, 1)
case <-doneC:
}
c.Close()
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
if !silent {
if atomic.LoadInt32(&silent) == 0 {
errHandler(err)
}
return
Expand All @@ -86,7 +87,8 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

lastResponse := time.Now()
var lastResponse int64
atomic.StoreInt64(&lastResponse, time.Now().Unix())

c.SetPingHandler(func(pingData string) error {
// Respond with Pong using the server's PING payload
Expand All @@ -99,7 +101,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
return err
}

lastResponse = time.Now()
atomic.StoreInt64(&lastResponse, time.Now().Unix())

return nil
})
Expand All @@ -108,7 +110,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
defer ticker.Stop()
for {
<-ticker.C
if time.Since(lastResponse) > timeout {
if time.Since(time.Unix(atomic.LoadInt64(&lastResponse), 0)) > timeout {
c.Close()
return
}
Expand Down
14 changes: 8 additions & 6 deletions v2/futures/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package futures
import (
"net/http"
"net/url"
"sync/atomic"
"time"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -60,19 +61,19 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
// operation.
silent := false
var silent int32
go func() {
select {
case <-stopC:
silent = true
atomic.StoreInt32(&silent, 1)
case <-doneC:
}
c.Close()
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
if !silent {
if atomic.LoadInt32(&silent) == 0 {
errHandler(err)
}
return
Expand All @@ -86,7 +87,8 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

lastResponse := time.Now()
var lastResponse int64
atomic.StoreInt64(&lastResponse, time.Now().Unix())

c.SetPingHandler(func(pingData string) error {
// Respond with Pong using the server's PING payload
Expand All @@ -99,7 +101,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
return err
}

lastResponse = time.Now()
atomic.StoreInt64(&lastResponse, time.Now().Unix())

return nil
})
Expand All @@ -108,7 +110,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
defer ticker.Stop()
for {
<-ticker.C
if time.Since(lastResponse) > timeout {
if time.Since(time.Unix(atomic.LoadInt64(&lastResponse), 0)) > timeout {
c.Close()
return
}
Expand Down
120 changes: 120 additions & 0 deletions v2/futures/websocket_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package futures

import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"

"github.com/gorilla/websocket"
)

var upgrader = websocket.Upgrader{}

// TestKeepAliveNoRace verifies that keepAlive's lastResponse variable
// doesn't trigger the race detector when pings arrive concurrently
// with the ticker reads.
func TestKeepAliveNoRace(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()

// Send pings rapidly to trigger the ping handler on the client side
for i := 0; i < 50; i++ {
if err := c.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(time.Second)); err != nil {
return
}
time.Sleep(5 * time.Millisecond)
}
}))
defer srv.Close()

wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()

// keepAlive with a short timeout so the ticker fires frequently
keepAlive(c, 100*time.Millisecond)

// Read messages to drive the ping handler (ReadMessage dispatches control frames)
done := make(chan struct{})
go func() {
defer close(done)
for {
_, _, err := c.ReadMessage()
if err != nil {
return
}
}
}()

// Let it run for enough time that ticker and pings overlap
time.Sleep(300 * time.Millisecond)
c.Close()
<-done
}

// TestWsSilentNoRace verifies that the silent variable in wsServe
// doesn't trigger the race detector when stopC is closed during ReadMessage.
func TestWsSilentNoRace(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()

// Send a few messages then hang so the client blocks on ReadMessage
for i := 0; i < 5; i++ {
c.WriteMessage(websocket.TextMessage, []byte("hello"))
time.Sleep(10 * time.Millisecond)
}
// Hold connection open
time.Sleep(500 * time.Millisecond)
}))
defer srv.Close()

wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")

origKeepalive := WebsocketKeepalive
WebsocketKeepalive = false
defer func() { WebsocketKeepalive = origKeepalive }()

cfg := &WsConfig{Endpoint: wsURL}

var received int
var mu sync.Mutex
handler := func(msg []byte) {
mu.Lock()
received++
mu.Unlock()
}
errHandler := func(err error) {}

doneC, stopC, err := wsServe(cfg, handler, errHandler)
if err != nil {
t.Fatalf("wsServe: %v", err)
}

// Let some messages arrive
time.Sleep(80 * time.Millisecond)

// Close stopC which sets silent=true in one goroutine
// while ReadMessage loop checks it in another — this is the race we fixed
close(stopC)
<-doneC

mu.Lock()
defer mu.Unlock()
if received == 0 {
t.Error("expected to receive at least one message")
}
}
14 changes: 8 additions & 6 deletions v2/options/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package options
import (
"net/http"
"net/url"
"sync/atomic"
"time"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -61,19 +62,19 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
// operation.
silent := false
var silent int32
go func() {
select {
case <-stopC:
silent = true
atomic.StoreInt32(&silent, 1)
case <-doneC:
}
c.Close()
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
if !silent {
if atomic.LoadInt32(&silent) == 0 {
errHandler(err)
}
return
Expand All @@ -87,7 +88,8 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

lastResponse := time.Now()
var lastResponse int64
atomic.StoreInt64(&lastResponse, time.Now().Unix())

c.SetPingHandler(func(pingData string) error {
// Respond with Pong using the server's PING payload
Expand All @@ -100,7 +102,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
return err
}

lastResponse = time.Now()
atomic.StoreInt64(&lastResponse, time.Now().Unix())

return nil
})
Expand All @@ -109,7 +111,7 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
defer ticker.Stop()
for {
<-ticker.C
if time.Since(lastResponse) > timeout {
if time.Since(time.Unix(atomic.LoadInt64(&lastResponse), 0)) > timeout {
c.Close()
return
}
Expand Down
6 changes: 3 additions & 3 deletions v2/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,19 @@ var wsServeWithConnHandler = func(cfg *WsConfig, handler WsHandler, errHandler E
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
// operation.
silent := false
var silent int32
go func() {
select {
case <-stopC:
silent = true
atomic.StoreInt32(&silent, 1)
case <-doneC:
}
c.Close()
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
if !silent {
if atomic.LoadInt32(&silent) == 0 {
errHandler(err)
}
return
Expand Down