Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Access] Add new websocket handler and skeleton for its deps #6630

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/observer/node_builder/observer_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig {
registerCacheSize: 0,
programCacheSize: 0,
registerDBPruneThreshold: pruner.DefaultThreshold,
websocketConfig: *websockets.NewDefaultWebsocketConfig(),
websocketConfig: websockets.NewDefaultWebsocketConfig(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/util/cmd/run-script/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func run(*cobra.Command, []string) {
metrics.NewNoopCollector(),
nil,
backend.Config{},
*websockets.NewDefaultWebsocketConfig(),
websockets.NewDefaultWebsocketConfig(),
)
if err != nil {
log.Fatal().Err(err).Msg("failed to create server")
Expand Down
2 changes: 1 addition & 1 deletion engine/access/handle_irrecoverable_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (suite *IrrecoverableStateTestSuite) SetupTest() {
RestConfig: rest.Config{
ListenAddress: unittest.DefaultAddress,
},
WebSocketConfig: *websockets.NewDefaultWebsocketConfig(),
WebSocketConfig: websockets.NewDefaultWebsocketConfig(),
}

// generate a server certificate that will be served by the GRPC server
Expand Down
2 changes: 1 addition & 1 deletion engine/access/integration_unsecure_grpc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() {
UnsecureGRPCListenAddr: unittest.DefaultAddress,
SecureGRPCListenAddr: unittest.DefaultAddress,
HTTPListenAddr: unittest.DefaultAddress,
WebSocketConfig: *websockets.NewDefaultWebsocketConfig(),
WebSocketConfig: websockets.NewDefaultWebsocketConfig(),
}

blockCount := 5
Expand Down
2 changes: 1 addition & 1 deletion engine/access/rest/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (b *RouterBuilder) AddLegacyWebsocketsRoutes(

func (b *RouterBuilder) AddWebsocketsRoute(
chain flow.Chain,
config *websockets.Config,
config websockets.Config,
streamApi state_stream.API,
streamConfig backend.Config,
maxRequestSize int64,
Expand Down
2 changes: 1 addition & 1 deletion engine/access/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewServer(serverAPI access.API,
builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize)
}

builder.AddWebsocketsRoute(chain, &wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize)
builder.AddWebsocketsRoute(chain, wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize)

c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
Expand Down
4 changes: 2 additions & 2 deletions engine/access/rest/websockets/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type Config struct {
MaxRequestSize int64
}

func NewDefaultWebsocketConfig() *Config {
return &Config{
func NewDefaultWebsocketConfig() Config {
return Config{
MaxSubscriptionsPerConnection: 1000,
MaxResponsesPerSecond: 1000,
SendMessageTimeout: 10 * time.Second,
Expand Down
98 changes: 58 additions & 40 deletions engine/access/rest/websockets/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,44 @@ import (
"github.com/rs/zerolog"

dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider"
"github.com/onflow/flow-go/engine/access/rest/websockets/models"
"github.com/onflow/flow-go/engine/access/state_stream"
"github.com/onflow/flow-go/engine/access/state_stream/backend"
"github.com/onflow/flow-go/utils/concurrentmap"
)

type Controller struct {
ctx context.Context
logger zerolog.Logger
config *Config
config Config
conn *websocket.Conn
communicationChannel chan interface{}
dataProviders *ThreadSafeMap[uuid.UUID, dp.DataProvider]
dataProviders *concurrentmap.ConcurrentMap[uuid.UUID, dp.DataProvider]
Guitarheroua marked this conversation as resolved.
Show resolved Hide resolved
dataProvidersFactory *dp.Factory
}

func NewWebSocketController(
ctx context.Context,
logger zerolog.Logger,
config *Config,
config Config,
streamApi state_stream.API,
streamConfig backend.Config,
conn *websocket.Conn,
) *Controller {
return &Controller{
ctx: ctx,
logger: logger.With().Str("component", "websocket-controller").Logger(),
config: config,
conn: conn,
communicationChannel: make(chan interface{}), //TODO: should it be buffered chan?
dataProviders: NewThreadSafeMap[uuid.UUID, dp.DataProvider](),
dataProviders: concurrentmap.NewConcurrentMap[uuid.UUID, dp.DataProvider](),
dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig),
}
}

// HandleConnection manages the WebSocket connection, adding context and error handling.
func (c *Controller) HandleConnection() {
func (c *Controller) HandleConnection(ctx context.Context) {
//TODO: configure the connection with ping-pong and deadlines

go c.readMessagesFromClient(c.ctx)
go c.writeMessagesToClient(c.ctx)
//TODO: spin up a response limit tracker routine
go c.readMessagesFromClient(ctx)
go c.writeMessagesToClient(ctx)
}

func (c *Controller) writeMessagesToClient(ctx context.Context) {
Expand Down Expand Up @@ -85,13 +84,13 @@ func (c *Controller) readMessagesFromClient(ctx context.Context) {
return
}

baseMsg, err := c.parseMessage(msg)
baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg)
if err != nil {
c.logger.Warn().Err(err).Msg("error parsing base message")
c.logger.Debug().Err(err).Msg("error parsing and validating client message")
return
}

if err := c.dispatchAction(baseMsg.Action, msg); err != nil {
if err := c.handleAction(ctx, baseMsg.Action, validatedMsg); err != nil {
c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action")
}
}
Expand All @@ -106,55 +105,68 @@ func (c *Controller) readMessage() (json.RawMessage, error) {
return message, nil
}

func (c *Controller) parseMessage(message json.RawMessage) (BaseMessageRequest, error) {
var baseMsg BaseMessageRequest
func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) {
var baseMsg models.BaseMessageRequest
if err := json.Unmarshal(message, &baseMsg); err != nil {
return BaseMessageRequest{}, fmt.Errorf("error unmarshalling base message: %w", err)
return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err)
}
return baseMsg, nil
}

// dispatchAction routes the action to the appropriate handler based on the action type.
func (c *Controller) dispatchAction(action string, message json.RawMessage) error {
switch action {
var validatedMsg interface{}
switch baseMsg.Action {
case "subscribe":
var subscribeMsg SubscribeMessageRequest
var subscribeMsg models.SubscribeMessageRequest
if err := json.Unmarshal(message, &subscribeMsg); err != nil {
return fmt.Errorf("error unmarshalling subscribe message: %w", err)
return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err)
}
c.handleSubscribe(subscribeMsg)
//TODO: add validation logic for `topic` field
validatedMsg = subscribeMsg

case "unsubscribe":
var unsubscribeMsg UnsubscribeMessageRequest
var unsubscribeMsg models.UnsubscribeMessageRequest
if err := json.Unmarshal(message, &unsubscribeMsg); err != nil {
return fmt.Errorf("error unmarshalling unsubscribe message: %w", err)
return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err)
}
c.handleUnsubscribe(unsubscribeMsg)
validatedMsg = unsubscribeMsg

case "list_subscriptions":
var listMsg ListSubscriptionsMessageRequest
var listMsg models.ListSubscriptionsMessageRequest
if err := json.Unmarshal(message, &listMsg); err != nil {
return fmt.Errorf("error unmarshalling list subscriptions message: %w", err)
return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err)
}
c.handleListSubscriptions(listMsg)
validatedMsg = listMsg

default:
c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type")
return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action)
}

return baseMsg, validatedMsg, nil
}

func (c *Controller) handleAction(ctx context.Context, action string, message interface{}) error {
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
switch action {
case "subscribe":
c.handleSubscribe(ctx, message.(models.SubscribeMessageRequest))
case "unsubscribe":
c.handleUnsubscribe(ctx, message.(models.UnsubscribeMessageRequest))
case "list_subscriptions":
c.handleListSubscriptions(ctx, message.(models.ListSubscriptionsMessageRequest))
default:
c.logger.Warn().Str("action", action).Msg("unknown action type")
return fmt.Errorf("unknown action type: %s", action)
}
return nil
}

func (c *Controller) handleSubscribe(msg SubscribeMessageRequest) {
dp := c.dataProvidersFactory.NewDataProvider(c.ctx, c.communicationChannel, msg.Topic)
c.dataProviders.Insert(dp.ID(), dp)
dp.Run()
func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) {
dp := c.dataProvidersFactory.NewDataProvider(ctx, c.communicationChannel, msg.Topic)
c.dataProviders.Add(dp.ID(), dp)
dp.Run(ctx)
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
}

func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) {
func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) {
id, err := uuid.Parse(msg.ID)
if err != nil {
c.logger.Warn().Err(err).Str("topic", msg.Topic).Msg("error parsing message ID")
c.logger.Debug().Err(err).Msg("error parsing message ID")
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
return
}

Expand All @@ -165,7 +177,8 @@ func (c *Controller) handleUnsubscribe(msg UnsubscribeMessageRequest) {
}
}
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved

func (c *Controller) handleListSubscriptions(msg ListSubscriptionsMessageRequest) {}
func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) {
}

func (c *Controller) shutdownConnection() {
defer close(c.communicationChannel)
Expand All @@ -175,8 +188,13 @@ func (c *Controller) shutdownConnection() {
}
}(c.conn)

c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) {
err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error {
dp.Close()
return nil
})
if err != nil {
c.logger.Error().Err(err).Msg("error closing data provider")
}

c.dataProviders.Clear()
}
5 changes: 2 additions & 3 deletions engine/access/rest/websockets/data_provider/blocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ func NewMockBlockProvider(
}
}

func (p *MockBlockProvider) Run() {
func (p *MockBlockProvider) Run(_ context.Context) {
select {
case <-p.ctx.Done():
return
default:
p.ch <- "hello"
p.ch <- "world"
p.ch <- "hello world"
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
4 changes: 3 additions & 1 deletion engine/access/rest/websockets/data_provider/provider.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package data_provider

import (
"context"

"github.com/google/uuid"
)

type DataProvider interface {
Run()
Run(ctx context.Context)
ID() uuid.UUID
Topic() string
Close()
Expand Down
9 changes: 4 additions & 5 deletions engine/access/rest/websockets/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Handler struct {
*common.HttpHandler

logger zerolog.Logger
websocketConfig *Config
websocketConfig Config
streamApi state_stream.API
streamConfig backend.Config
}
Expand All @@ -26,7 +26,7 @@ var _ http.Handler = (*Handler)(nil)

func NewWebSocketHandler(
logger zerolog.Logger,
config *Config,
config Config,
chain flow.Chain,
streamApi state_stream.API,
streamConfig backend.Config,
Expand Down Expand Up @@ -64,7 +64,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

ctx := context.Background()
controller := NewWebSocketController(ctx, logger, h.websocketConfig, h.streamApi, h.streamConfig, conn)
controller.HandleConnection()
controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn)
controller.HandleConnection(context.TODO())
illia-malachyn marked this conversation as resolved.
Show resolved Hide resolved
}
21 changes: 8 additions & 13 deletions engine/access/rest/websockets/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/suite"

"github.com/onflow/flow-go/engine/access/rest/websockets"
"github.com/onflow/flow-go/engine/access/rest/websockets/models"
"github.com/onflow/flow-go/engine/access/state_stream/backend"
streammock "github.com/onflow/flow-go/engine/access/state_stream/mock"
"github.com/onflow/flow-go/model/flow"
Expand All @@ -28,17 +29,17 @@ type WsHandlerSuite struct {

logger zerolog.Logger
handler *websockets.Handler
wsConfig *websockets.Config
wsConfig websockets.Config
streamApi *streammock.API
streamConfig *backend.Config
streamConfig backend.Config
}

func (s *WsHandlerSuite) SetupTest() {
s.logger = unittest.Logger()
s.wsConfig = websockets.NewDefaultWebsocketConfig()
s.streamApi = streammock.NewAPI(s.T())
s.streamConfig = &backend.Config{}
s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, *s.streamConfig, 1024)
s.streamConfig = backend.Config{}
s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024)
}

func TestWsHandlerSuite(t *testing.T) {
Expand All @@ -65,8 +66,8 @@ func (s *WsHandlerSuite) TestSubscribeRequest() {
args := map[string]interface{}{
"start_block_height": 10,
}
body := websockets.SubscribeMessageRequest{
BaseMessageRequest: websockets.BaseMessageRequest{Action: "subscribe"},
body := models.SubscribeMessageRequest{
BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"},
Topic: "blocks",
Arguments: args,
}
Expand All @@ -80,12 +81,6 @@ func (s *WsHandlerSuite) TestSubscribeRequest() {
require.NoError(s.T(), err)

actualMsg := strings.Trim(string(msg), "\n\"\\ ")
require.Equal(s.T(), "hello", actualMsg)

_, msg, err = conn.ReadMessage()
require.NoError(s.T(), err)

actualMsg = strings.Trim(string(msg), "\n\"\\ ")
require.Equal(s.T(), "world", actualMsg)
require.Equal(s.T(), "hello world", actualMsg)
})
}
Loading
Loading