Skip to content

Commit 9bf3fd6

Browse files
committed
Use compare and swap.
1 parent b82a816 commit 9bf3fd6

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

server/session_ws.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ type sessionWS struct {
6767
runtime *Runtime
6868

6969
stopped bool
70-
closeSent *atomic.Bool
70+
closeSentCAS *atomic.Uint32
7171
closeWaitCh chan struct{}
7272
conn *websocket.Conn
7373
receivedMessageCounter int
@@ -122,6 +122,8 @@ func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, sessi
122122
runtime: runtime,
123123

124124
stopped: false,
125+
closeSentCAS: atomic.NewUint32(0),
126+
closeWaitCh: make(chan struct{}),
125127
conn: conn,
126128
receivedMessageCounter: config.GetSocket().PingBackoffThreshold,
127129
pingTimer: time.NewTimer(time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond),
@@ -190,15 +192,12 @@ func (s *sessionWS) Consume() {
190192
s.maybeResetPingTimer()
191193
return nil
192194
})
193-
s.closeSent = atomic.NewBool(false)
194-
s.closeWaitCh = make(chan struct{}, 1)
195-
// Disable the close handler so that the server can handle the close message itself.
196195
s.conn.SetCloseHandler(func(code int, text string) error {
197-
if !s.closeSent.Load() {
196+
if s.closeSentCAS.CompareAndSwap(0, 1) {
198197
message := websocket.FormatCloseMessage(code, "")
199198
_ = s.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(s.pongWaitDuration))
200199
} else {
201-
s.closeWaitCh <- struct{}{}
200+
close(s.closeWaitCh)
202201
}
203202
return nil
204203
})
@@ -525,20 +524,21 @@ func (s *sessionWS) Close(msg string, reason runtime.PresenceReason, envelopes .
525524
}
526525

527526
if msg != "" {
528-
// Server initiated close, attempt to send a close control message.
529-
reasonMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, msg)
530-
if err := s.conn.WriteControl(websocket.CloseMessage, reasonMsg, time.Now().Add(s.writeWaitDuration)); err != nil {
531-
// This may not be possible if the socket was already fully closed by an error.
532-
s.logger.Debug("Could not send close message", zap.Error(err))
533-
} else {
534-
s.closeSent.Store(true)
535-
t := time.NewTimer(10 * time.Second)
536-
defer t.Stop()
537-
select {
538-
case <-s.closeWaitCh:
539-
s.logger.Debug("socket close ack received")
540-
case <-t.C:
541-
s.logger.Debug("socket close ack not received within 10 seconds")
527+
if s.closeSentCAS.CompareAndSwap(0, 1) {
528+
// Server initiated close, attempt to send a close control message.
529+
reasonMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, msg)
530+
if err := s.conn.WriteControl(websocket.CloseMessage, reasonMsg, time.Now().Add(s.writeWaitDuration)); err != nil {
531+
// This may not be possible if the socket was already fully closed by an error.
532+
s.logger.Debug("Could not send close message", zap.Error(err))
533+
} else {
534+
t := time.NewTimer(10 * time.Second)
535+
defer t.Stop()
536+
select {
537+
case <-s.closeWaitCh:
538+
s.logger.Debug("socket close ack received")
539+
case <-t.C:
540+
s.logger.Debug("socket close ack not received within 10 seconds")
541+
}
542542
}
543543
}
544544
}

0 commit comments

Comments
 (0)