From 81ddee55c64f34621e2b25f6d37ded4bac0e778e Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 20 Nov 2024 14:03:01 +0200 Subject: [PATCH 01/15] Added Websocket connection configurating --- engine/access/rest/websockets/controller.go | 40 ++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fe873f5f61c..f2f9e761a25 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -16,6 +17,14 @@ import ( "github.com/onflow/flow-go/utils/concurrentmap" ) +const ( + // Time allowed to read the next pong message from the peer. + pongWait = 10 * time.Second + + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second +) + type Controller struct { logger zerolog.Logger config Config @@ -44,12 +53,41 @@ func NewWebSocketController( // HandleConnection manages the WebSocket connection, adding context and error handling. func (c *Controller) HandleConnection(ctx context.Context) { - //TODO: configure the connection with ping-pong and deadlines + // configuring the connection with appropriate read/write deadlines and handlers. + err := c.configureConnection() + if err != nil { + // TODO: add error handling here + c.logger.Error().Err(err).Msg("error configuring connection") + c.shutdownConnection() + return + } + //TODO: spin up a response limit tracker routine go c.readMessagesFromClient(ctx) c.writeMessagesToClient(ctx) } +// configureConnection used to set read and write deadlines for WebSocket connections and establishes a Pong handler to +// manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket +// connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. +func (c *Controller) configureConnection() error { + // Set the initial write deadline for the first ping message + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + return fmt.Errorf("failed to set the initial write deadline: %w", err) + } + // Set the initial read deadline for the first pong message + if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + return fmt.Errorf("failed to set the initial read deadline: %w", err) + } + + // Establish a Pong handler + c.conn.SetPongHandler(func(string) error { + return c.conn.SetReadDeadline(time.Now().Add(pongWait)) + }) + + return nil +} + // writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation From 808b54ba6c0e8d85f53081a02cc2785669debf37 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 21 Nov 2024 12:05:47 +0200 Subject: [PATCH 02/15] Updated configureConnection and godoc --- engine/access/rest/websockets/controller.go | 24 ++++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index f2f9e761a25..51423fe2606 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -67,20 +67,18 @@ func (c *Controller) HandleConnection(ctx context.Context) { c.writeMessagesToClient(ctx) } -// configureConnection used to set read and write deadlines for WebSocket connections and establishes a Pong handler to -// manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket -// connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. +// configureConnection configures the WebSocket connection by setting up a Pong handler +// to handle incoming Pong messages and update the read deadline accordingly. +// +// The Pong handler resets the read deadline whenever a Pong message is received from the peer. +// This mechanism ensures the connection remains active as long as the peer responds to periodic pings. +// +// Note: The default value for the read deadline in Gorilla WebSockets is 0, which means +// no deadline is set unless explicitly configured. Without a read deadline, the connection +// will remain open indefinitely if the client keeps the connection open without sending any messages unless explicitly +// closed by either the server or the client. func (c *Controller) configureConnection() error { - // Set the initial write deadline for the first ping message - if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { - return fmt.Errorf("failed to set the initial write deadline: %w", err) - } - // Set the initial read deadline for the first pong message - if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { - return fmt.Errorf("failed to set the initial read deadline: %w", err) - } - - // Establish a Pong handler + // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(pongWait)) }) From 6c5ab5dd1be8a2b066812421bb2f7c2fa297dad0 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 21 Nov 2024 12:42:58 +0200 Subject: [PATCH 03/15] Adedd SetWriteDeadline before write operation --- engine/access/rest/websockets/controller.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 51423fe2606..8fd0406a6c6 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -98,6 +98,14 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { case msg := <-c.communicationChannel: // TODO: handle 'response per second' limits + // Specifies a timeout for the write operation. If the write + // isn't completed within this duration, it fails with a timeout error. + // SetWriteDeadline ensures the write operation does not block indefinitely + // if the client is slow or unresponsive. This prevents resource exhaustion + // and allows the server to gracefully handle timeouts for delayed writes. + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + c.logger.Error().Err(err).Msg("failed to set the write deadline") + } err := c.conn.WriteJSON(msg) if err != nil { c.logger.Error().Err(err).Msg("error writing to connection") From eec15e5e356406eac9a343dfc1cb134a3111cb0e Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 21 Nov 2024 12:58:10 +0200 Subject: [PATCH 04/15] Set initital read deadline, updated godoc --- engine/access/rest/websockets/controller.go | 25 +++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 8fd0406a6c6..5be5a9af318 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -67,17 +67,24 @@ func (c *Controller) HandleConnection(ctx context.Context) { c.writeMessagesToClient(ctx) } -// configureConnection configures the WebSocket connection by setting up a Pong handler -// to handle incoming Pong messages and update the read deadline accordingly. +// configureConnection sets up the WebSocket connection with a read deadline +// and a handler for receiving pong messages from the client. // -// The Pong handler resets the read deadline whenever a Pong message is received from the peer. -// This mechanism ensures the connection remains active as long as the peer responds to periodic pings. -// -// Note: The default value for the read deadline in Gorilla WebSockets is 0, which means -// no deadline is set unless explicitly configured. Without a read deadline, the connection -// will remain open indefinitely if the client keeps the connection open without sending any messages unless explicitly -// closed by either the server or the client. +// The function does the following: +// 1. Sets an initial read deadline to ensure the server doesn't wait indefinitely +// for a pong message from the client. If no message is received within the +// specified `pongWait` duration, the connection will be closed. +// 2. Establishes a Pong handler that resets the read deadline every time a pong +// message is received from the client, allowing the server to continue waiting +// for further pong messages within the new deadline. func (c *Controller) configureConnection() error { + // Set the initial read deadline for the first pong message + // The Pong handler itself only resets the read deadline after receiving a Pong. + // It doesn't set an initial deadline. The initial read deadline is crucial to prevent the server from waiting + // forever if the client doesn't send Pongs. + if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + return fmt.Errorf("failed to set the initial read deadline: %w", err) + } // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(pongWait)) From 917bbde350c46f7e5ddee8cd2186ad3f896a3ea8 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 22 Nov 2024 18:47:03 +0200 Subject: [PATCH 05/15] Implemented ping-pong ws routine, refactored shutdownConnection --- engine/access/rest/websockets/controller.go | 152 +++++++++++++++--- .../websockets/legacy/websocket_handler.go | 6 +- 2 files changed, 131 insertions(+), 27 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 5be5a9af318..4d206b871cd 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "time" "github.com/google/uuid" @@ -18,6 +19,10 @@ import ( ) const ( + // PingPeriod defines the interval at which ping messages are sent to the client. + // This value must be less than pongWait. + PingPeriod = (pongWait * 9) / 10 + // Time allowed to read the next pong message from the peer. pongWait = 10 * time.Second @@ -30,8 +35,12 @@ type Controller struct { config Config conn *websocket.Conn communicationChannel chan interface{} + errorChannel chan error dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProvidersFactory *dp.Factory + + shutdownOnce sync.Once // Ensures shutdown is only called once + shutdown bool } func NewWebSocketController( @@ -46,13 +55,19 @@ func NewWebSocketController( config: config, conn: conn, communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? + errorChannel: make(chan error, 1), dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), } } -// HandleConnection manages the WebSocket connection, adding context and error handling. +// HandleConnection manages the lifecycle of a WebSocket connection, +// including setup, message processing, and graceful shutdown. +// +// Parameters: +// - ctx: The context for controlling cancellation and timeouts. func (c *Controller) HandleConnection(ctx context.Context) { + defer close(c.errorChannel) // configuring the connection with appropriate read/write deadlines and handlers. err := c.configureConnection() if err != nil { @@ -63,8 +78,54 @@ func (c *Controller) HandleConnection(ctx context.Context) { } //TODO: spin up a response limit tracker routine - go c.readMessagesFromClient(ctx) - c.writeMessagesToClient(ctx) + + // for track all goroutines and error handling + var wg sync.WaitGroup + + c.startProcess(&wg, ctx, c.readMessagesFromClient) + c.startProcess(&wg, ctx, c.keepalive) + c.startProcess(&wg, ctx, c.writeMessagesToClient) + + select { + case err := <-c.errorChannel: + c.logger.Error().Err(err).Msg("error detected in one of the goroutines") + //TODO: add error handling here + c.shutdownConnection() + case <-ctx.Done(): + // Context canceled, shut down gracefully + c.shutdownConnection() + } + + // Wait for all goroutines + wg.Wait() +} + +// startProcess is a helper function to start a goroutine for a given process +// and ensure it is tracked via a sync.WaitGroup. +// +// Parameters: +// - wg: The wait group to track goroutines. +// - ctx: The context for cancellation. +// - process: The function to run in a new goroutine. +// +// No errors are expected during normal operation. +func (c *Controller) startProcess(wg *sync.WaitGroup, ctx context.Context, process func(context.Context) error) { + wg.Add(1) + + go func() { + defer wg.Done() + + err := process(ctx) + if err != nil { + // Check if shutdown has already been called, to avoid multiple shutdowns + if c.shutdown { + c.logger.Warn().Err(err).Msg("error detected after shutdown initiated, ignoring") + return + } + + c.errorChannel <- err + } + }() } // configureConnection sets up the WebSocket connection with a read deadline @@ -96,12 +157,14 @@ func (c *Controller) configureConnection() error { // writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation -func (c *Controller) writeMessagesToClient(ctx context.Context) { +// +// No errors are expected during normal operation. +func (c *Controller) writeMessagesToClient(ctx context.Context) error { //TODO: can it run forever? maybe we should cancel the ctx in the reader routine for { select { case <-ctx.Done(): - return + return nil case msg := <-c.communicationChannel: // TODO: handle 'response per second' limits @@ -112,10 +175,12 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { // and allows the server to gracefully handle timeouts for delayed writes. if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { c.logger.Error().Err(err).Msg("failed to set the write deadline") + return err } err := c.conn.WriteJSON(msg) if err != nil { c.logger.Error().Err(err).Msg("error writing to connection") + return err } } } @@ -123,32 +188,33 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { // readMessagesFromClient continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. -func (c *Controller) readMessagesFromClient(ctx context.Context) { - defer c.shutdownConnection() - +// +// No errors are expected during normal operation. +func (c *Controller) readMessagesFromClient(ctx context.Context) error { for { select { case <-ctx.Done(): c.logger.Info().Msg("context canceled, stopping read message loop") - return + return nil default: msg, err := c.readMessage() if err != nil { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { - return + return nil } c.logger.Warn().Err(err).Msg("error reading message from client") - return + return err } baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) if err != nil { c.logger.Debug().Err(err).Msg("error parsing and validating client message") - return + return err } if err := c.handleAction(ctx, validatedMsg); err != nil { c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") + return err } } } @@ -244,20 +310,60 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis } func (c *Controller) shutdownConnection() { - defer close(c.communicationChannel) - defer func(conn *websocket.Conn) { - if err := c.conn.Close(); err != nil { - c.logger.Error().Err(err).Msg("error closing connection") + c.shutdownOnce.Do(func() { + c.shutdown = true + + defer close(c.communicationChannel) + defer func(conn *websocket.Conn) { + if err := c.conn.Close(); err != nil { + c.logger.Error().Err(err).Msg("error closing connection") + } + }(c.conn) + + err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { + dp.Close() + return nil + }) + if err != nil { + c.logger.Error().Err(err).Msg("error closing data provider") } - }(c.conn) - err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - dp.Close() - return nil + c.dataProviders.Clear() }) - if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") +} + +// keepalive sends a ping message periodically to keep the WebSocket connection alive +// and avoid timeouts. +// +// No errors are expected during normal operation. +func (c *Controller) keepalive(ctx context.Context) error { + pingTicker := time.NewTicker(PingPeriod) + defer pingTicker.Stop() + + for { + select { + case <-ctx.Done(): + // return ctx.Err() + return nil + case <-pingTicker.C: + if err := c.sendPing(); err != nil { + // Log error and exit the loop on failure + c.logger.Error().Err(err).Msg("failed to send ping") + return err + } + } + } +} + +// sendPing sends a periodic ping message to the WebSocket client to keep the connection alive. +func (c *Controller) sendPing() error { + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + return fmt.Errorf("failed to set the write deadline for ping: %w", err) } - c.dataProviders.Clear() + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return fmt.Errorf("failed to write ping message: %w", err) + } + + return nil } diff --git a/engine/access/rest/websockets/legacy/websocket_handler.go b/engine/access/rest/websockets/legacy/websocket_handler.go index 7132314b16c..a464cb29cb5 100644 --- a/engine/access/rest/websockets/legacy/websocket_handler.go +++ b/engine/access/rest/websockets/legacy/websocket_handler.go @@ -12,6 +12,7 @@ import ( "go.uber.org/atomic" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/engine/access/subscription" @@ -23,9 +24,6 @@ const ( // Time allowed to read the next pong message from the peer. pongWait = 10 * time.Second - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - // Time allowed to write a message to the peer. writeWait = 10 * time.Second ) @@ -111,7 +109,7 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. // The function uses a ticker to periodically send ping messages to the client to maintain the connection. func (wsController *WebsocketController) writeEvents(sub subscription.Subscription) { - ticker := time.NewTicker(pingPeriod) + ticker := time.NewTicker(websockets.PingPeriod) defer ticker.Stop() blocksSinceLastMessage := uint64(0) From ec4e2473b9ee8e096afc9f10ac9d73ee61d198ed Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 25 Nov 2024 13:00:00 +0200 Subject: [PATCH 06/15] Added more comments and updated godoc --- engine/access/rest/websockets/controller.go | 23 ++++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 4d206b871cd..e0d93a93f34 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -31,16 +31,18 @@ const ( ) type Controller struct { - logger zerolog.Logger - config Config - conn *websocket.Conn - communicationChannel chan interface{} - errorChannel chan error + logger zerolog.Logger + config Config + conn *websocket.Conn + + communicationChannel chan interface{} // Channel for sending messages to the client. + errorChannel chan error // Channel for reporting errors. + dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProvidersFactory *dp.Factory shutdownOnce sync.Once // Ensures shutdown is only called once - shutdown bool + shutdown bool // Indicates if the controller is shutting down. } func NewWebSocketController( @@ -55,7 +57,7 @@ func NewWebSocketController( config: config, conn: conn, communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - errorChannel: make(chan error, 1), + errorChannel: make(chan error, 1), // Buffered error channel to hold one error. dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), } @@ -86,6 +88,7 @@ func (c *Controller) HandleConnection(ctx context.Context) { c.startProcess(&wg, ctx, c.keepalive) c.startProcess(&wg, ctx, c.writeMessagesToClient) + // Wait for context cancellation or errors from goroutines. select { case err := <-c.errorChannel: c.logger.Error().Err(err).Msg("error detected in one of the goroutines") @@ -96,7 +99,7 @@ func (c *Controller) HandleConnection(ctx context.Context) { c.shutdownConnection() } - // Wait for all goroutines + // Ensure all goroutines finish execution. wg.Wait() } @@ -160,7 +163,6 @@ func (c *Controller) configureConnection() error { // // No errors are expected during normal operation. func (c *Controller) writeMessagesToClient(ctx context.Context) error { - //TODO: can it run forever? maybe we should cancel the ctx in the reader routine for { select { case <-ctx.Done(): @@ -343,7 +345,6 @@ func (c *Controller) keepalive(ctx context.Context) error { for { select { case <-ctx.Done(): - // return ctx.Err() return nil case <-pingTicker.C: if err := c.sendPing(); err != nil { @@ -356,6 +357,8 @@ func (c *Controller) keepalive(ctx context.Context) error { } // sendPing sends a periodic ping message to the WebSocket client to keep the connection alive. +// +// No errors are expected during normal operation. func (c *Controller) sendPing() error { if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { return fmt.Errorf("failed to set the write deadline for ping: %w", err) From eae6bbf1d52b63c94ea003eb1a43f8369fea5ba7 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 25 Nov 2024 14:17:53 +0200 Subject: [PATCH 07/15] Moved constants to new websockets package according to comment --- engine/access/rest/websockets/controller.go | 15 ++++++++---- .../websockets/legacy/websocket_handler.go | 24 ++++++------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 5be5a9af318..2e617d88f24 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -18,11 +18,16 @@ import ( ) const ( + + // PingPeriod defines the interval at which ping messages are sent to the client. + // This value must be less than pongWait. + PingPeriod = (PongWait * 9) / 10 + // Time allowed to read the next pong message from the peer. - pongWait = 10 * time.Second + PongWait = 10 * time.Second // Time allowed to write a message to the peer. - writeWait = 10 * time.Second + WriteWait = 10 * time.Second ) type Controller struct { @@ -82,12 +87,12 @@ func (c *Controller) configureConnection() error { // The Pong handler itself only resets the read deadline after receiving a Pong. // It doesn't set an initial deadline. The initial read deadline is crucial to prevent the server from waiting // forever if the client doesn't send Pongs. - if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + if err := c.conn.SetReadDeadline(time.Now().Add(PongWait)); err != nil { return fmt.Errorf("failed to set the initial read deadline: %w", err) } // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { - return c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return c.conn.SetReadDeadline(time.Now().Add(PongWait)) }) return nil @@ -110,7 +115,7 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) { // SetWriteDeadline ensures the write operation does not block indefinitely // if the client is slow or unresponsive. This prevents resource exhaustion // and allows the server to gracefully handle timeouts for delayed writes. - if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { c.logger.Error().Err(err).Msg("failed to set the write deadline") } err := c.conn.WriteJSON(msg) diff --git a/engine/access/rest/websockets/legacy/websocket_handler.go b/engine/access/rest/websockets/legacy/websocket_handler.go index 7132314b16c..06aa8323de4 100644 --- a/engine/access/rest/websockets/legacy/websocket_handler.go +++ b/engine/access/rest/websockets/legacy/websocket_handler.go @@ -12,6 +12,7 @@ import ( "go.uber.org/atomic" "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/engine/access/subscription" @@ -19,17 +20,6 @@ import ( "github.com/onflow/flow-go/model/flow" ) -const ( - // Time allowed to read the next pong message from the peer. - pongWait = 10 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second -) - // WebsocketController holds the necessary components and parameters for handling a WebSocket subscription. // It manages the communication between the server and the WebSocket client for subscribing. type WebsocketController struct { @@ -47,17 +37,17 @@ type WebsocketController struct { // manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket // connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. func (wsController *WebsocketController) SetWebsocketConf() error { - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) // Set the initial write deadline for the first ping message if err != nil { return common.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) } - err = wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + err = wsController.conn.SetReadDeadline(time.Now().Add(websockets.PongWait)) // Set the initial read deadline for the first pong message if err != nil { return common.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) } // Establish a Pong handler wsController.conn.SetPongHandler(func(string) error { - err := wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) + err := wsController.conn.SetReadDeadline(time.Now().Add(websockets.PongWait)) if err != nil { return err } @@ -111,7 +101,7 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. // The function uses a ticker to periodically send ping messages to the client to maintain the connection. func (wsController *WebsocketController) writeEvents(sub subscription.Subscription) { - ticker := time.NewTicker(pingPeriod) + ticker := time.NewTicker(websockets.PingPeriod) defer ticker.Stop() blocksSinceLastMessage := uint64(0) @@ -137,7 +127,7 @@ func (wsController *WebsocketController) writeEvents(sub subscription.Subscripti wsController.wsErrorHandler(common.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) return } - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) if err != nil { wsController.wsErrorHandler(common.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return @@ -178,7 +168,7 @@ func (wsController *WebsocketController) writeEvents(sub subscription.Subscripti return } case <-ticker.C: - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + err := wsController.conn.SetWriteDeadline(time.Now().Add(websockets.WriteWait)) if err != nil { wsController.wsErrorHandler(common.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return From c90d75f30ce29172f2eb69ebf372ddfe245bd480 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 26 Nov 2024 16:16:36 +0200 Subject: [PATCH 08/15] Updated according to comments, added unit tests for ping-pong functionality --- Makefile | 1 + engine/access/rest/websockets/connections.go | 57 +++++++ engine/access/rest/websockets/controller.go | 21 ++- .../access/rest/websockets/controller_test.go | 106 +++++++++++++ engine/access/rest/websockets/handler.go | 2 +- .../websockets/mock/websocket_connection.go | 141 ++++++++++++++++++ 6 files changed, 319 insertions(+), 9 deletions(-) create mode 100644 engine/access/rest/websockets/connections.go create mode 100644 engine/access/rest/websockets/controller_test.go create mode 100644 engine/access/rest/websockets/mock/websocket_connection.go diff --git a/Makefile b/Makefile index 2578fffe4b6..53ea58fc52d 100644 --- a/Makefile +++ b/Makefile @@ -204,6 +204,7 @@ generate-mocks: install-mock-generators mockery --name '.*' --dir="./engine/access/state_stream" --case=underscore --output="./engine/access/state_stream/mock" --outpkg="mock" mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" + mockery --name 'WebsocketConnection' --dir="./engine/access/rest/websockets" --case=underscore --output="./engine/access/rest/websockets/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" diff --git a/engine/access/rest/websockets/connections.go b/engine/access/rest/websockets/connections.go new file mode 100644 index 00000000000..7421d3bd8ec --- /dev/null +++ b/engine/access/rest/websockets/connections.go @@ -0,0 +1,57 @@ +package websockets + +import ( + "time" + + "github.com/gorilla/websocket" +) + +type WebsocketConnection interface { + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error + WriteMessage(int, []byte) error + Close() error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + SetPongHandler(func(string) error) +} + +type WebsocketConnectionImpl struct { + conn *websocket.Conn +} + +func NewWebsocketConnection(conn *websocket.Conn) *WebsocketConnectionImpl { + return &WebsocketConnectionImpl{ + conn: conn, + } +} + +var _ WebsocketConnection = (*WebsocketConnectionImpl)(nil) + +func (c *WebsocketConnectionImpl) ReadJSON(v interface{}) error { + return c.conn.ReadJSON(v) +} + +func (c *WebsocketConnectionImpl) WriteJSON(v interface{}) error { + return c.conn.WriteJSON(v) +} + +func (c *WebsocketConnectionImpl) WriteMessage(messageType int, data []byte) error { + return c.conn.WriteMessage(messageType, data) +} + +func (c *WebsocketConnectionImpl) Close() error { + return c.conn.Close() +} + +func (c *WebsocketConnectionImpl) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *WebsocketConnectionImpl) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *WebsocketConnectionImpl) SetPongHandler(h func(string) error) { + c.conn.SetPongHandler(h) +} diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 5a6c4494062..4edff1ab282 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -33,7 +33,7 @@ const ( type Controller struct { logger zerolog.Logger config Config - conn *websocket.Conn + conn WebsocketConnection communicationChannel chan interface{} // Channel for sending messages to the client. errorChannel chan error // Channel for reporting errors. @@ -50,7 +50,7 @@ func NewWebSocketController( config Config, streamApi state_stream.API, streamConfig backend.Config, - conn *websocket.Conn, + conn WebsocketConnection, ) *Controller { return &Controller{ logger: logger.With().Str("component", "websocket-controller").Logger(), @@ -90,7 +90,12 @@ func (c *Controller) HandleConnection(ctx context.Context) { // Wait for context cancellation or errors from goroutines. select { - case err := <-c.errorChannel: + case err, ok := <-c.errorChannel: + if !ok { + c.logger.Error().Msg("error channel closed") + //TODO: add error handling here + return + } c.logger.Error().Err(err).Msg("error detected in one of the goroutines") //TODO: add error handling here c.shutdownConnection() @@ -166,7 +171,7 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) error { for { select { case <-ctx.Done(): - return nil + return ctx.Err() case msg := <-c.communicationChannel: // TODO: handle 'response per second' limits @@ -197,7 +202,7 @@ func (c *Controller) readMessagesFromClient(ctx context.Context) error { select { case <-ctx.Done(): c.logger.Info().Msg("context canceled, stopping read message loop") - return nil + return ctx.Err() default: msg, err := c.readMessage() if err != nil { @@ -315,12 +320,12 @@ func (c *Controller) shutdownConnection() { c.shutdownOnce.Do(func() { c.shutdown = true - defer close(c.communicationChannel) - defer func(conn *websocket.Conn) { + defer func(conn WebsocketConnection) { if err := c.conn.Close(); err != nil { c.logger.Error().Err(err).Msg("error closing connection") } }(c.conn) + close(c.communicationChannel) err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { dp.Close() @@ -345,7 +350,7 @@ func (c *Controller) keepalive(ctx context.Context) error { for { select { case <-ctx.Done(): - return nil + return ctx.Err() case <-pingTicker.C: if err := c.sendPing(); err != nil { // Log error and exit the loop on failure diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go new file mode 100644 index 00000000000..0b04ba1eaa7 --- /dev/null +++ b/engine/access/rest/websockets/controller_test.go @@ -0,0 +1,106 @@ +package websockets + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + connectionmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" +) + +type ControllerSuite struct { + suite.Suite + + connection *connectionmock.WebsocketConnection + controller *Controller +} + +func TestControllerSuite(t *testing.T) { + suite.Run(t, new(ControllerSuite)) +} + +// SetupTest initializes the test suite with required dependencies. +func (s *ControllerSuite) SetupTest() { + s.connection = connectionmock.NewWebsocketConnection(s.T()) + + // Create the controller + log := zerolog.New(zerolog.NewConsoleWriter()) + config := Config{} + s.controller = &Controller{ + logger: log, + config: config, + conn: s.connection, + communicationChannel: make(chan interface{}), + errorChannel: make(chan error, 1), + } +} + +// Helper function to start the keepalive process. +func (s *ControllerSuite) startKeepalive(ctx context.Context, expectedError error) { + go func() { + err := s.controller.keepalive(ctx) + if expectedError != nil { + s.Require().Error(err) + s.Require().Equal(expectedError, err) + } else { + s.Require().NoError(err) + } + }() +} + +// Helper function to setup mock behavior for SetWriteDeadline and WriteMessage. +func (s *ControllerSuite) setupMockConnection(writeMessageError error) { + s.connection.On("SetWriteDeadline", mock.Anything).Return(nil).Once() + s.connection.On("WriteMessage", websocket.PingMessage, mock.Anything).Return(writeMessageError).Once() +} + +// Helper function to wait for expected mock calls. +func (s *ControllerSuite) waitForMockCalls(expectedCalls int, timeout time.Duration, interval time.Duration, errorMessage string) { + require.Eventually(s.T(), func() bool { + return len(s.connection.Calls) == expectedCalls + }, timeout, interval, errorMessage) +} + +// TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. +func (s *ControllerSuite) TestKeepaliveError() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup the mock connection with an error + expectedError := fmt.Errorf("failed to write ping message: %w", assert.AnError) + s.setupMockConnection(assert.AnError) + + // Start the keepalive process + s.startKeepalive(ctx, expectedError) + + // Wait for the ping message or timeout + expectedCalls := 2 + s.waitForMockCalls(expectedCalls, PongWait*3/2, 1*time.Second, "ping message was not sent") + + // Assert expectations + s.connection.AssertExpectations(s.T()) +} + +// TestKeepaliveContextCancel tests the behavior of keepalive when the context is canceled before a ping is sent and +// no ping message is sent after the context is canceled. +func (s *ControllerSuite) TestKeepaliveContextCancel() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Immediately cancel the context + + // Start the keepalive process with the context canceled + s.startKeepalive(ctx, context.Canceled) + + // Wait for the timeout to ensure no ping is sent + time.Sleep(PongWait) + + // Assert expectations + s.connection.AssertExpectations(s.T()) // Should not invoke WriteMessage after context cancellation +} diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 247890c2a62..c7acb46e506 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -65,6 +65,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) + controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, NewWebsocketConnection(conn)) controller.HandleConnection(context.TODO()) } diff --git a/engine/access/rest/websockets/mock/websocket_connection.go b/engine/access/rest/websockets/mock/websocket_connection.go new file mode 100644 index 00000000000..cafa0999278 --- /dev/null +++ b/engine/access/rest/websockets/mock/websocket_connection.go @@ -0,0 +1,141 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" +) + +// WebsocketConnection is an autogenerated mock type for the WebsocketConnection type +type WebsocketConnection struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *WebsocketConnection) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ReadJSON provides a mock function with given fields: v +func (_m *WebsocketConnection) ReadJSON(v interface{}) error { + ret := _m.Called(v) + + if len(ret) == 0 { + panic("no return value specified for ReadJSON") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(v) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetPongHandler provides a mock function with given fields: _a0 +func (_m *WebsocketConnection) SetPongHandler(_a0 func(string) error) { + _m.Called(_a0) +} + +// SetReadDeadline provides a mock function with given fields: _a0 +func (_m *WebsocketConnection) SetReadDeadline(_a0 time.Time) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetReadDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetWriteDeadline provides a mock function with given fields: _a0 +func (_m *WebsocketConnection) SetWriteDeadline(_a0 time.Time) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetWriteDeadline") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// WriteJSON provides a mock function with given fields: v +func (_m *WebsocketConnection) WriteJSON(v interface{}) error { + ret := _m.Called(v) + + if len(ret) == 0 { + panic("no return value specified for WriteJSON") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(v) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// WriteMessage provides a mock function with given fields: _a0, _a1 +func (_m *WebsocketConnection) WriteMessage(_a0 int, _a1 []byte) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for WriteMessage") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int, []byte) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewWebsocketConnection creates a new instance of WebsocketConnection. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewWebsocketConnection(t interface { + mock.TestingT + Cleanup(func()) +}) *WebsocketConnection { + mock := &WebsocketConnection{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From 040a949fc29bb925abb789eb480318b595d67539 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 27 Nov 2024 13:26:30 +0200 Subject: [PATCH 09/15] Updated WriteMessage to WriteControl for Ping messages, updated mocks and tests --- engine/access/rest/websockets/connections.go | 20 ++++---- engine/access/rest/websockets/controller.go | 6 +-- .../access/rest/websockets/controller_test.go | 5 +- .../websockets/mock/websocket_connection.go | 46 +++++++++---------- 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/engine/access/rest/websockets/connections.go b/engine/access/rest/websockets/connections.go index 7421d3bd8ec..5170e917e9f 100644 --- a/engine/access/rest/websockets/connections.go +++ b/engine/access/rest/websockets/connections.go @@ -9,11 +9,11 @@ import ( type WebsocketConnection interface { ReadJSON(v interface{}) error WriteJSON(v interface{}) error - WriteMessage(int, []byte) error + WriteControl(messageType int, deadline time.Time) error Close() error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error - SetPongHandler(func(string) error) + SetReadDeadline(deadline time.Time) error + SetWriteDeadline(deadline time.Time) error + SetPongHandler(h func(string) error) } type WebsocketConnectionImpl struct { @@ -36,20 +36,20 @@ func (c *WebsocketConnectionImpl) WriteJSON(v interface{}) error { return c.conn.WriteJSON(v) } -func (c *WebsocketConnectionImpl) WriteMessage(messageType int, data []byte) error { - return c.conn.WriteMessage(messageType, data) +func (c *WebsocketConnectionImpl) WriteControl(messageType int, deadline time.Time) error { + return c.conn.WriteControl(messageType, nil, deadline) } func (c *WebsocketConnectionImpl) Close() error { return c.conn.Close() } -func (c *WebsocketConnectionImpl) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) +func (c *WebsocketConnectionImpl) SetReadDeadline(deadline time.Time) error { + return c.conn.SetReadDeadline(deadline) } -func (c *WebsocketConnectionImpl) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) +func (c *WebsocketConnectionImpl) SetWriteDeadline(deadline time.Time) error { + return c.conn.SetWriteDeadline(deadline) } func (c *WebsocketConnectionImpl) SetPongHandler(h func(string) error) { diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 4edff1ab282..44ee72a86a3 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -365,11 +365,7 @@ func (c *Controller) keepalive(ctx context.Context) error { // // No errors are expected during normal operation. func (c *Controller) sendPing() error { - if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { - return fmt.Errorf("failed to set the write deadline for ping: %w", err) - } - - if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + if err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)); err != nil { return fmt.Errorf("failed to write ping message: %w", err) } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 0b04ba1eaa7..9fa6f99df58 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -58,8 +58,7 @@ func (s *ControllerSuite) startKeepalive(ctx context.Context, expectedError erro // Helper function to setup mock behavior for SetWriteDeadline and WriteMessage. func (s *ControllerSuite) setupMockConnection(writeMessageError error) { - s.connection.On("SetWriteDeadline", mock.Anything).Return(nil).Once() - s.connection.On("WriteMessage", websocket.PingMessage, mock.Anything).Return(writeMessageError).Once() + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(writeMessageError).Once() } // Helper function to wait for expected mock calls. @@ -82,7 +81,7 @@ func (s *ControllerSuite) TestKeepaliveError() { s.startKeepalive(ctx, expectedError) // Wait for the ping message or timeout - expectedCalls := 2 + expectedCalls := 1 s.waitForMockCalls(expectedCalls, PongWait*3/2, 1*time.Second, "ping message was not sent") // Assert expectations diff --git a/engine/access/rest/websockets/mock/websocket_connection.go b/engine/access/rest/websockets/mock/websocket_connection.go index cafa0999278..02a60fd0a3c 100644 --- a/engine/access/rest/websockets/mock/websocket_connection.go +++ b/engine/access/rest/websockets/mock/websocket_connection.go @@ -49,14 +49,14 @@ func (_m *WebsocketConnection) ReadJSON(v interface{}) error { return r0 } -// SetPongHandler provides a mock function with given fields: _a0 -func (_m *WebsocketConnection) SetPongHandler(_a0 func(string) error) { - _m.Called(_a0) +// SetPongHandler provides a mock function with given fields: h +func (_m *WebsocketConnection) SetPongHandler(h func(string) error) { + _m.Called(h) } -// SetReadDeadline provides a mock function with given fields: _a0 -func (_m *WebsocketConnection) SetReadDeadline(_a0 time.Time) error { - ret := _m.Called(_a0) +// SetReadDeadline provides a mock function with given fields: deadline +func (_m *WebsocketConnection) SetReadDeadline(deadline time.Time) error { + ret := _m.Called(deadline) if len(ret) == 0 { panic("no return value specified for SetReadDeadline") @@ -64,7 +64,7 @@ func (_m *WebsocketConnection) SetReadDeadline(_a0 time.Time) error { var r0 error if rf, ok := ret.Get(0).(func(time.Time) error); ok { - r0 = rf(_a0) + r0 = rf(deadline) } else { r0 = ret.Error(0) } @@ -72,9 +72,9 @@ func (_m *WebsocketConnection) SetReadDeadline(_a0 time.Time) error { return r0 } -// SetWriteDeadline provides a mock function with given fields: _a0 -func (_m *WebsocketConnection) SetWriteDeadline(_a0 time.Time) error { - ret := _m.Called(_a0) +// SetWriteDeadline provides a mock function with given fields: deadline +func (_m *WebsocketConnection) SetWriteDeadline(deadline time.Time) error { + ret := _m.Called(deadline) if len(ret) == 0 { panic("no return value specified for SetWriteDeadline") @@ -82,7 +82,7 @@ func (_m *WebsocketConnection) SetWriteDeadline(_a0 time.Time) error { var r0 error if rf, ok := ret.Get(0).(func(time.Time) error); ok { - r0 = rf(_a0) + r0 = rf(deadline) } else { r0 = ret.Error(0) } @@ -90,17 +90,17 @@ func (_m *WebsocketConnection) SetWriteDeadline(_a0 time.Time) error { return r0 } -// WriteJSON provides a mock function with given fields: v -func (_m *WebsocketConnection) WriteJSON(v interface{}) error { - ret := _m.Called(v) +// WriteControl provides a mock function with given fields: messageType, deadline +func (_m *WebsocketConnection) WriteControl(messageType int, deadline time.Time) error { + ret := _m.Called(messageType, deadline) if len(ret) == 0 { - panic("no return value specified for WriteJSON") + panic("no return value specified for WriteControl") } var r0 error - if rf, ok := ret.Get(0).(func(interface{}) error); ok { - r0 = rf(v) + if rf, ok := ret.Get(0).(func(int, time.Time) error); ok { + r0 = rf(messageType, deadline) } else { r0 = ret.Error(0) } @@ -108,17 +108,17 @@ func (_m *WebsocketConnection) WriteJSON(v interface{}) error { return r0 } -// WriteMessage provides a mock function with given fields: _a0, _a1 -func (_m *WebsocketConnection) WriteMessage(_a0 int, _a1 []byte) error { - ret := _m.Called(_a0, _a1) +// WriteJSON provides a mock function with given fields: v +func (_m *WebsocketConnection) WriteJSON(v interface{}) error { + ret := _m.Called(v) if len(ret) == 0 { - panic("no return value specified for WriteMessage") + panic("no return value specified for WriteJSON") } var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(v) } else { r0 = ret.Error(0) } From 276ea7ed6d2e037e9b39f00e26307c2186d6c923 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 28 Nov 2024 15:22:33 +0200 Subject: [PATCH 10/15] Added tests for keepalive, configure connection, graceful shutdown, added some refactoring, added godoc --- Makefile | 1 + engine/access/rest/router/router.go | 4 +- engine/access/rest/websockets/controller.go | 24 +- .../access/rest/websockets/controller_test.go | 241 ++++++++++++++---- .../rest/websockets/data_provider/factory.go | 45 +++- .../mock/data_provider_factory.go | 47 ++++ engine/access/rest/websockets/handler.go | 24 +- engine/access/rest/websockets/handler_test.go | 86 ------- 8 files changed, 308 insertions(+), 164 deletions(-) create mode 100644 engine/access/rest/websockets/data_provider/mock/data_provider_factory.go delete mode 100644 engine/access/rest/websockets/handler_test.go diff --git a/Makefile b/Makefile index 53ea58fc52d..84fadce74a0 100644 --- a/Makefile +++ b/Makefile @@ -204,6 +204,7 @@ generate-mocks: install-mock-generators mockery --name '.*' --dir="./engine/access/state_stream" --case=underscore --output="./engine/access/state_stream/mock" --outpkg="mock" mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" + mockery --name 'DataProviderFactory' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" mockery --name 'WebsocketConnection' --dir="./engine/access/rest/websockets" --case=underscore --output="./engine/access/rest/websockets/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index a2d81cb0a58..14487ef57df 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -14,6 +14,7 @@ import ( flowhttp "github.com/onflow/flow-go/engine/access/rest/http" "github.com/onflow/flow-go/engine/access/rest/http/models" "github.com/onflow/flow-go/engine/access/rest/websockets" + "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -93,7 +94,8 @@ func (b *RouterBuilder) AddWebsocketsRoute( streamConfig backend.Config, maxRequestSize int64, ) *RouterBuilder { - handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig, maxRequestSize) + dataProviderFactory := data_provider.NewDataProviderFactory(b.logger, streamApi, streamConfig) + handler := websockets.NewWebSocketHandler(b.logger, config, chain, dataProviderFactory, maxRequestSize) b.v1SubRouter. Methods(http.MethodGet). Path("/ws"). diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 44ee72a86a3..b423aaecf3b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -13,8 +13,6 @@ import ( dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/utils/concurrentmap" ) @@ -39,7 +37,7 @@ type Controller struct { errorChannel chan error // Channel for reporting errors. dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] - dataProvidersFactory *dp.Factory + dataProvidersFactory dp.DataProviderFactory shutdownOnce sync.Once // Ensures shutdown is only called once shutdown bool // Indicates if the controller is shutting down. @@ -48,8 +46,7 @@ type Controller struct { func NewWebSocketController( logger zerolog.Logger, config Config, - streamApi state_stream.API, - streamConfig backend.Config, + dataProviderFactory dp.DataProviderFactory, conn WebsocketConnection, ) *Controller { return &Controller{ @@ -59,7 +56,7 @@ func NewWebSocketController( communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? errorChannel: make(chan error, 1), // Buffered error channel to hold one error. dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), + dataProvidersFactory: dataProviderFactory, } } @@ -172,7 +169,11 @@ func (c *Controller) writeMessagesToClient(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case msg := <-c.communicationChannel: + case msg, ok := <-c.communicationChannel: + if !ok { + err := fmt.Errorf("communication channel closed, no error occurred") + return err + } // TODO: handle 'response per second' limits // Specifies a timeout for the write operation. If the write @@ -290,10 +291,11 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic) c.dataProviders.Add(dp.ID(), dp) - dp.Run(ctx) //TODO: return OK response to client c.communicationChannel <- msg + + dp.Run(ctx) } func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { @@ -320,12 +322,12 @@ func (c *Controller) shutdownConnection() { c.shutdownOnce.Do(func() { c.shutdown = true - defer func(conn WebsocketConnection) { + defer func() { if err := c.conn.Close(); err != nil { c.logger.Error().Err(err).Msg("error closing connection") } - }(c.conn) - close(c.communicationChannel) + close(c.communicationChannel) + }() err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { dp.Close() diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 9fa6f99df58..f31c87f5923 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -2,25 +2,32 @@ package websockets import ( "context" + "encoding/json" "fmt" "testing" "time" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider/mock" connectionmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/utils/unittest" ) type ControllerSuite struct { suite.Suite - connection *connectionmock.WebsocketConnection - controller *Controller + logger zerolog.Logger + config Config + + connection *connectionmock.WebsocketConnection + dataProviderFactory *dpmock.DataProviderFactory } func TestControllerSuite(t *testing.T) { @@ -29,60 +36,179 @@ func TestControllerSuite(t *testing.T) { // SetupTest initializes the test suite with required dependencies. func (s *ControllerSuite) SetupTest() { - s.connection = connectionmock.NewWebsocketConnection(s.T()) + s.logger = unittest.Logger() + s.config = Config{} - // Create the controller - log := zerolog.New(zerolog.NewConsoleWriter()) - config := Config{} - s.controller = &Controller{ - logger: log, - config: config, - conn: s.connection, - communicationChannel: make(chan interface{}), - errorChannel: make(chan error, 1), - } + s.connection = connectionmock.NewWebsocketConnection(s.T()) + s.dataProviderFactory = dpmock.NewDataProviderFactory(s.T()) } -// Helper function to start the keepalive process. -func (s *ControllerSuite) startKeepalive(ctx context.Context, expectedError error) { - go func() { - err := s.controller.keepalive(ctx) - if expectedError != nil { - s.Require().Error(err) - s.Require().Equal(expectedError, err) - } else { - s.Require().NoError(err) - } - }() -} +// TestConfigureConnection ensures that the WebSocket connection is configured correctly. +func (s *ControllerSuite) TestConfigureConnection() { + controller := s.initializeController() + + // Mock configureConnection to succeed + s.mockConnectionSetup() + + // Call configureConnection + err := controller.configureConnection() + s.Require().NoError(err, "configureConnection should not return an error") -// Helper function to setup mock behavior for SetWriteDeadline and WriteMessage. -func (s *ControllerSuite) setupMockConnection(writeMessageError error) { - s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(writeMessageError).Once() + // Assert expectations + s.connection.AssertExpectations(s.T()) } -// Helper function to wait for expected mock calls. -func (s *ControllerSuite) waitForMockCalls(expectedCalls int, timeout time.Duration, interval time.Duration, errorMessage string) { - require.Eventually(s.T(), func() bool { - return len(s.connection.Calls) == expectedCalls - }, timeout, interval, errorMessage) +// TestControllerShutdown ensures that HandleConnection shuts down gracefully when an error occurs. +func (s *ControllerSuite) TestControllerShutdown() { + s.T().Run("keepalive routine failed", func(*testing.T) { + controller := s.initializeController() + + // Mock configureConnection to succeed + s.mockConnectionSetup() + + // Mock keepalive to return an error + done := make(chan struct{}, 1) + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(func(int, time.Time) error { + close(done) + return websocket.ErrCloseSent + }).Once() + + s.connection. + On("ReadJSON", mock.Anything). + Return(func(interface{}) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }). + Once() + + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + s.Require().True(controller.shutdown) + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + }) + + s.T().Run("read routine failed", func(*testing.T) { + controller := s.initializeController() + // Mock configureConnection to succeed + s.mockConnectionSetup() + + // Mock keepalive to return an error + done := make(chan struct{}, 1) + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(func(int, time.Time) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }).Once() + s.connection. + On("ReadJSON", mock.Anything). + Return(func(_ interface{}) error { + close(done) + return assert.AnError + }). + Once() + + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + s.Require().True(controller.shutdown) + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + }) + + s.T().Run("write routine failed", func(*testing.T) { + controller := s.initializeController() + + // Mock configureConnection to succeed + s.mockConnectionSetup() + blocksDataProvider := s.mockBlockDataProviderSetup(uuid.New()) + + done := make(chan struct{}, 1) + // Mock keepalive to return a connection error + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(func(int, time.Time) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }).Once() + + requestMessage := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, + Topic: "blocks", + Arguments: nil, + } + + s.connection. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + reqMsg, ok := args.Get(0).(*json.RawMessage) + s.Require().True(ok) + msg, err := json.Marshal(requestMessage) + s.Require().NoError(err) + *reqMsg = msg + }). + Return(nil). + Once() + + s.connection. + On("ReadJSON", mock.Anything). + Return(func(interface{}) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }) + + s.connection.On("SetWriteDeadline", mock.Anything).Return(nil).Once() + s.connection. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + close(done) + return assert.AnError + }) + s.connection.On("Close").Return(nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + controller.HandleConnection(ctx) + + s.Require().True(controller.shutdown) + // Ensure all expectations are met + s.connection.AssertExpectations(s.T()) + s.dataProviderFactory.AssertExpectations(s.T()) + blocksDataProvider.AssertExpectations(s.T()) + }) } // TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. func (s *ControllerSuite) TestKeepaliveError() { + controller := s.initializeController() + + // Setup the mock connection with an error + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(assert.AnError).Once() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Setup the mock connection with an error expectedError := fmt.Errorf("failed to write ping message: %w", assert.AnError) - s.setupMockConnection(assert.AnError) - // Start the keepalive process - s.startKeepalive(ctx, expectedError) - - // Wait for the ping message or timeout - expectedCalls := 1 - s.waitForMockCalls(expectedCalls, PongWait*3/2, 1*time.Second, "ping message was not sent") + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().Equal(expectedError, err) // Assert expectations s.connection.AssertExpectations(s.T()) @@ -91,15 +217,38 @@ func (s *ControllerSuite) TestKeepaliveError() { // TestKeepaliveContextCancel tests the behavior of keepalive when the context is canceled before a ping is sent and // no ping message is sent after the context is canceled. func (s *ControllerSuite) TestKeepaliveContextCancel() { + controller := s.initializeController() + ctx, cancel := context.WithCancel(context.Background()) cancel() // Immediately cancel the context // Start the keepalive process with the context canceled - s.startKeepalive(ctx, context.Canceled) - - // Wait for the timeout to ensure no ping is sent - time.Sleep(PongWait) + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().Equal(context.Canceled, err) // Assert expectations s.connection.AssertExpectations(s.T()) // Should not invoke WriteMessage after context cancellation } + +// initializeController initializes the WebSocket controller. +func (s *ControllerSuite) initializeController() *Controller { + return NewWebSocketController(s.logger, s.config, s.dataProviderFactory, s.connection) +} + +// mockDataProviderSetup is a helper which mocks a blocks data provider setup. +func (s *ControllerSuite) mockBlockDataProviderSetup(id uuid.UUID) *dpmock.DataProvider { + dataProvider := dpmock.NewDataProvider(s.T()) + dataProvider.On("ID").Return(id).Once() + dataProvider.On("Close").Return(nil).Once() + s.dataProviderFactory.On("NewDataProvider", mock.Anything, mock.Anything).Return(dataProvider).Once() + dataProvider.On("Run", mock.Anything).Return().Once() + + return dataProvider +} + +// mockConnectionSetup is a helper which mocks connection setup for SetReadDeadline and SetPongHandler. +func (s *ControllerSuite) mockConnectionSetup() { + s.connection.On("SetReadDeadline", mock.Anything).Return(nil).Once() + s.connection.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() +} diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go index 6a2658b1b95..bb6a50ae9b0 100644 --- a/engine/access/rest/websockets/data_provider/factory.go +++ b/engine/access/rest/websockets/data_provider/factory.go @@ -7,24 +7,57 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream/backend" ) -type Factory struct { +// Constants defining various topic names used to specify different types of +// data providers. +const ( + BlocksTopic = "blocks" +) + +// TODO: Temporary implementation without godoc; should be replaced once PR #6636 is merged + +// DataProviderFactory defines an interface for creating data providers +// based on specified topics. The factory abstracts the creation process +// and ensures consistent access to required APIs. +type DataProviderFactory interface { + // NewDataProvider creates a new data provider based on the specified topic + // and configuration parameters. + // + // No errors are expected during normal operations. + NewDataProvider( + ch chan<- interface{}, + topic string) DataProvider +} + +var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) + +// DataProviderFactoryImpl is an implementation of the DataProviderFactory interface. +// It is responsible for creating data providers based on the +// requested topic. It manages access to logging and relevant APIs needed to retrieve data. +type DataProviderFactoryImpl struct { logger zerolog.Logger streamApi state_stream.API streamConfig backend.Config } -func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *Factory { - return &Factory{ +func NewDataProviderFactory( + logger zerolog.Logger, + streamApi state_stream.API, + streamConfig backend.Config, +) *DataProviderFactoryImpl { + return &DataProviderFactoryImpl{ logger: logger, streamApi: streamApi, streamConfig: streamConfig, } } -func (f *Factory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { +func (s *DataProviderFactoryImpl) NewDataProvider( + ch chan<- interface{}, + topic string, +) DataProvider { switch topic { - case "blocks": - return NewMockBlockProvider(ch, topic, f.logger, f.streamApi) + case BlocksTopic: + return NewMockBlockProvider(ch, topic, s.logger, s.streamApi) default: return nil } diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider_factory.go b/engine/access/rest/websockets/data_provider/mock/data_provider_factory.go new file mode 100644 index 00000000000..406231710dc --- /dev/null +++ b/engine/access/rest/websockets/data_provider/mock/data_provider_factory.go @@ -0,0 +1,47 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + data_provider "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + mock "github.com/stretchr/testify/mock" +) + +// DataProviderFactory is an autogenerated mock type for the DataProviderFactory type +type DataProviderFactory struct { + mock.Mock +} + +// NewDataProvider provides a mock function with given fields: ch, topic +func (_m *DataProviderFactory) NewDataProvider(ch chan<- interface{}, topic string) data_provider.DataProvider { + ret := _m.Called(ch, topic) + + if len(ret) == 0 { + panic("no return value specified for NewDataProvider") + } + + var r0 data_provider.DataProvider + if rf, ok := ret.Get(0).(func(chan<- interface{}, string) data_provider.DataProvider); ok { + r0 = rf(ch, topic) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(data_provider.DataProvider) + } + } + + return r0 +} + +// NewDataProviderFactory creates a new instance of DataProviderFactory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDataProviderFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *DataProviderFactory { + mock := &DataProviderFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index c7acb46e506..a408308ae1f 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -8,18 +8,16 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/common" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" "github.com/onflow/flow-go/model/flow" ) type Handler struct { *common.HttpHandler - logger zerolog.Logger - websocketConfig Config - streamApi state_stream.API - streamConfig backend.Config + logger zerolog.Logger + websocketConfig Config + dataProviderFactory data_provider.DataProviderFactory } var _ http.Handler = (*Handler)(nil) @@ -28,16 +26,14 @@ func NewWebSocketHandler( logger zerolog.Logger, config Config, chain flow.Chain, - streamApi state_stream.API, - streamConfig backend.Config, + dataProviderFactory data_provider.DataProviderFactory, maxRequestSize int64, ) *Handler { return &Handler{ - HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), - websocketConfig: config, - logger: logger, - streamApi: streamApi, - streamConfig: streamConfig, + HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), + websocketConfig: config, + logger: logger, + dataProviderFactory: dataProviderFactory, } } @@ -65,6 +61,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, NewWebsocketConnection(conn)) + controller := NewWebSocketController(logger, h.websocketConfig, h.dataProviderFactory, NewWebsocketConnection(conn)) controller.HandleConnection(context.TODO()) } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go deleted file mode 100644 index 6b9cce06572..00000000000 --- a/engine/access/rest/websockets/handler_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package websockets_test - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/gorilla/websocket" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - - "github.com/onflow/flow-go/engine/access/rest/websockets" - "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/model/flow" - "github.com/onflow/flow-go/utils/unittest" -) - -var ( - chainID = flow.Testnet -) - -type WsHandlerSuite struct { - suite.Suite - - logger zerolog.Logger - handler *websockets.Handler - wsConfig websockets.Config - streamApi *streammock.API - streamConfig backend.Config -} - -func (s *WsHandlerSuite) SetupTest() { - s.logger = unittest.Logger() - s.wsConfig = websockets.NewDefaultWebsocketConfig() - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} - s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024) -} - -func TestWsHandlerSuite(t *testing.T) { - suite.Run(t, new(WsHandlerSuite)) -} - -func ClientConnection(url string) (*websocket.Conn, *http.Response, error) { - wsURL := "ws" + strings.TrimPrefix(url, "http") - return websocket.DefaultDialer.Dial(wsURL, nil) -} - -func (s *WsHandlerSuite) TestSubscribeRequest() { - s.Run("Happy path", func() { - server := httptest.NewServer(s.handler) - defer server.Close() - - conn, _, err := ClientConnection(server.URL) - defer func(conn *websocket.Conn) { - err := conn.Close() - require.NoError(s.T(), err) - }(conn) - require.NoError(s.T(), err) - - args := map[string]interface{}{ - "start_block_height": 10, - } - body := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", - Arguments: args, - } - bodyJSON, err := json.Marshal(body) - require.NoError(s.T(), err) - - err = conn.WriteMessage(websocket.TextMessage, bodyJSON) - require.NoError(s.T(), err) - - _, msg, err := conn.ReadMessage() - require.NoError(s.T(), err) - - actualMsg := strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "block{height: 42}", actualMsg) - }) -} From 21259cecfe6a1b2a991fe8ac5e69b41032e7fc47 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 28 Nov 2024 15:47:33 +0200 Subject: [PATCH 11/15] Added happy case test for keepalive --- .../access/rest/websockets/controller_test.go | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index f31c87f5923..4121d97cb94 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -20,6 +20,7 @@ import ( "github.com/onflow/flow-go/utils/unittest" ) +// ControllerSuite is a test suite for the WebSocket Controller. type ControllerSuite struct { suite.Suite @@ -50,7 +51,7 @@ func (s *ControllerSuite) TestConfigureConnection() { // Mock configureConnection to succeed s.mockConnectionSetup() - // Call configureConnection + // Call configureConnection and check for errors err := controller.configureConnection() s.Require().NoError(err, "configureConnection should not return an error") @@ -194,6 +195,31 @@ func (s *ControllerSuite) TestControllerShutdown() { }) } +// TestKeepalive tests the behavior of the keepalive function. +func (s *ControllerSuite) TestKeepalive() { + // Create a context for the test + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + controller := s.initializeController() + s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(nil) + + // Start the keepalive process in a separate goroutine + go func() { + err := controller.keepalive(ctx) + s.Require().NoError(err) + }() + + // Use Eventually to wait for some ping messages + expectedCalls := 3 // expected 3 ping messages for 30 seconds + s.Require().Eventually(func() bool { + return len(s.connection.Calls) == expectedCalls + }, 30*time.Second, 1*time.Second, "not all ping messages were sent") + + // Assert that the ping was sent + s.connection.AssertExpectations(s.T()) +} + // TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. func (s *ControllerSuite) TestKeepaliveError() { controller := s.initializeController() From 1f5728d8662505919e54a42c9ed16b46f6edf466 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 28 Nov 2024 15:59:37 +0200 Subject: [PATCH 12/15] Updated unit test for keep alive --- engine/access/rest/websockets/controller_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 4121d97cb94..964802d04c5 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -195,11 +195,10 @@ func (s *ControllerSuite) TestControllerShutdown() { }) } -// TestKeepalive tests the behavior of the keepalive function. -func (s *ControllerSuite) TestKeepalive() { +// TestKeepaliveHappyCase tests the behavior of the keepalive function. +func (s *ControllerSuite) TestKeepaliveHappyCase() { // Create a context for the test - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := context.Background() controller := s.initializeController() s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(nil) @@ -216,6 +215,9 @@ func (s *ControllerSuite) TestKeepalive() { return len(s.connection.Calls) == expectedCalls }, 30*time.Second, 1*time.Second, "not all ping messages were sent") + s.connection.On("Close").Return(nil).Once() + controller.shutdownConnection() + // Assert that the ping was sent s.connection.AssertExpectations(s.T()) } From f384b0a24d1165444a57bba1e9eadd09893723a5 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 2 Dec 2024 17:59:33 +0200 Subject: [PATCH 13/15] Removed sendPing abstraction, updated godoc according to comments --- .../rest/websockets/{connections.go => connection.go} | 0 engine/access/rest/websockets/controller.go | 11 +++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) rename engine/access/rest/websockets/{connections.go => connection.go} (100%) diff --git a/engine/access/rest/websockets/connections.go b/engine/access/rest/websockets/connection.go similarity index 100% rename from engine/access/rest/websockets/connections.go rename to engine/access/rest/websockets/connection.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index b423aaecf3b..cdfa3e8cc91 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -21,7 +21,8 @@ const ( // This value must be less than pongWait. PingPeriod = (PongWait * 9) / 10 - // PongWait specifies the maximum time to wait for a pong message from the peer. + // PongWait specifies the maximum time to wait for a pong response message from the peer + // after sending a ping PongWait = 10 * time.Second // WriteWait specifies the maximum duration allowed to write a message to the peer. @@ -354,10 +355,12 @@ func (c *Controller) keepalive(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case <-pingTicker.C: - if err := c.sendPing(); err != nil { + err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) + if err != nil { // Log error and exit the loop on failure - c.logger.Error().Err(err).Msg("failed to send ping") - return err + c.logger.Debug().Err(err).Msg("failed to send ping") + + return fmt.Errorf("failed to write ping message: %w", err) } } } From 66d0607af169159a3fef213b953c8ebbe41e3ddc Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 2 Dec 2024 18:48:42 +0200 Subject: [PATCH 14/15] Updated last commit --- engine/access/rest/websockets/controller.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index cdfa3e8cc91..da4be0d3267 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -365,14 +365,3 @@ func (c *Controller) keepalive(ctx context.Context) error { } } } - -// sendPing sends a periodic ping message to the WebSocket client to keep the connection alive. -// -// No errors are expected during normal operation. -func (c *Controller) sendPing() error { - if err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)); err != nil { - return fmt.Errorf("failed to write ping message: %w", err) - } - - return nil -} From 3cfe98b6271eee46981689fe2b6fe96913939bef Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 2 Dec 2024 20:29:54 +0200 Subject: [PATCH 15/15] Extended godoc --- engine/access/rest/websockets/controller.go | 25 ++++++++++++++------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index da4be0d3267..4a6e7b4b074 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -18,14 +18,28 @@ import ( const ( // PingPeriod defines the interval at which ping messages are sent to the client. - // This value must be less than pongWait. + // This value must be less than pongWait, cause it that case the server ensures it sends a ping well before the PongWait + // timeout elapses. Each new pong message resets the server's read deadline, keeping the connection alive as long as + // the client is responsive. + // + // Example: + // At t=9, the server sends a ping, initial read deadline is t=10 (for the first message) + // At t=10, the client responds with a pong. The server resets its read deadline to t=20. + // At t=18, the server sends another ping. If the client responds with a pong at t=19, the read deadline is extended to t=29. + // + // In case of failure: + // If the client stops responding, the server will send a ping at t=9 but won't receive a pong by t=10. The server then closes the connection. PingPeriod = (PongWait * 9) / 10 // PongWait specifies the maximum time to wait for a pong response message from the peer // after sending a ping PongWait = 10 * time.Second - // WriteWait specifies the maximum duration allowed to write a message to the peer. + // WriteWait specifies a timeout for the write operation. If the write + // isn't completed within this duration, it fails with a timeout error. + // SetWriteDeadline ensures the write operation does not block indefinitely + // if the client is slow or unresponsive. This prevents resource exhaustion + // and allows the server to gracefully handle timeouts for delayed writes. WriteWait = 10 * time.Second ) @@ -88,12 +102,7 @@ func (c *Controller) HandleConnection(ctx context.Context) { // Wait for context cancellation or errors from goroutines. select { - case err, ok := <-c.errorChannel: - if !ok { - c.logger.Error().Msg("error channel closed") - //TODO: add error handling here - return - } + case err := <-c.errorChannel: c.logger.Error().Err(err).Msg("error detected in one of the goroutines") //TODO: add error handling here c.shutdownConnection()