diff --git a/engine/access/rest/http/request/event_type.go b/engine/access/rest/common/parser/event_type.go similarity index 98% rename from engine/access/rest/http/request/event_type.go rename to engine/access/rest/common/parser/event_type.go index c3f425d81c8..f1ba7ca1acb 100644 --- a/engine/access/rest/http/request/event_type.go +++ b/engine/access/rest/common/parser/event_type.go @@ -1,4 +1,4 @@ -package request +package parser import ( "fmt" diff --git a/engine/access/rest/http/request/get_events.go b/engine/access/rest/http/request/get_events.go index c864cf24a47..dee55f98ded 100644 --- a/engine/access/rest/http/request/get_events.go +++ b/engine/access/rest/http/request/get_events.go @@ -71,7 +71,7 @@ func (g *GetEvents) Parse(rawType string, rawStart string, rawEnd string, rawBlo if rawType == "" { return fmt.Errorf("event type must be provided") } - var eventType EventType + var eventType parser.EventType err = eventType.Parse(rawType) if err != nil { return err diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 4f0e2260ae5..c45919725b2 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -51,7 +51,13 @@ func NewServer(serverAPI access.API, builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } - dataProviderFactory := dp.NewDataProviderFactory(logger, stateStreamApi, serverAPI) + dataProviderFactory := dp.NewDataProviderFactory( + logger, + stateStreamApi, + serverAPI, + chain, + stateStreamConfig.EventFilterConfig, + stateStreamConfig.HeartbeatInterval) builder.AddWebsocketsRoute(chain, wsConfig, config.MaxRequestSize, dataProviderFactory) c := cors.New(cors.Options{ diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider.go b/engine/access/rest/websockets/data_providers/account_statuses_provider.go new file mode 100644 index 00000000000..1a3aee203c9 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider.go @@ -0,0 +1,207 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/util" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" +) + +type AccountStatusesArguments struct { + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from + Filter state_stream.AccountStatusFilter // Filter applied to events for a given subscription +} + +type AccountStatusesDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + stateStreamApi state_stream.API + + heartbeatInterval uint64 +} + +var _ DataProvider = (*AccountStatusesDataProvider)(nil) + +// NewAccountStatusesDataProvider creates a new instance of AccountStatusesDataProvider. +func NewAccountStatusesDataProvider( + ctx context.Context, + logger zerolog.Logger, + stateStreamApi state_stream.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, +) (*AccountStatusesDataProvider, error) { + p := &AccountStatusesDataProvider{ + logger: logger.With().Str("component", "account-statuses-data-provider").Logger(), + stateStreamApi: stateStreamApi, + heartbeatInterval: heartbeatInterval, + } + + // Initialize arguments passed to the provider. + accountStatusesArgs, err := parseAccountStatusesArguments(arguments, chain, eventFilterConfig) + if err != nil { + return nil, fmt.Errorf("invalid arguments for account statuses data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, accountStatusesArgs), // Set up a subscription to account statuses based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *AccountStatusesDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *AccountStatusesDataProvider) createSubscription(ctx context.Context, args AccountStatusesArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.stateStreamApi.SubscribeAccountStatusesFromStartBlockID(ctx, args.StartBlockID, args.Filter) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.stateStreamApi.SubscribeAccountStatusesFromStartHeight(ctx, args.StartBlockHeight, args.Filter) + } + + return p.stateStreamApi.SubscribeAccountStatusesFromLatestBlock(ctx, args.Filter) +} + +// handleResponse processes an account statuses and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *AccountStatusesDataProvider) handleResponse() func(accountStatusesResponse *backend.AccountStatusesResponse) error { + blocksSinceLastMessage := uint64(0) + messageIndex := counters.NewMonotonousCounter(1) + + return func(accountStatusesResponse *backend.AccountStatusesResponse) error { + // check if there are any events in the response. if not, do not send a message unless the last + // response was more than HeartbeatInterval blocks ago + if len(accountStatusesResponse.AccountEvents) == 0 { + blocksSinceLastMessage++ + if blocksSinceLastMessage < p.heartbeatInterval { + return nil + } + blocksSinceLastMessage = 0 + } + + index := messageIndex.Value() + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return status.Errorf(codes.Internal, "message index already incremented to %d", messageIndex.Value()) + } + + p.send <- &models.AccountStatusesResponse{ + BlockID: accountStatusesResponse.BlockID.String(), + Height: strconv.FormatUint(accountStatusesResponse.Height, 10), + AccountEvents: accountStatusesResponse.AccountEvents, + MessageIndex: strconv.FormatUint(index, 10), + } + + return nil + } +} + +// parseAccountStatusesArguments validates and initializes the account statuses arguments. +func parseAccountStatusesArguments( + arguments models.Arguments, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, +) (AccountStatusesArguments, error) { + var args AccountStatusesArguments + + // Check for mutual exclusivity of start_block_id and start_block_height early + startBlockIDIn, hasStartBlockID := arguments["start_block_id"] + startBlockHeightIn, hasStartBlockHeight := arguments["start_block_height"] + + if hasStartBlockID && hasStartBlockHeight { + return args, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") + } + + // Parse 'start_block_id' if provided + if hasStartBlockID { + result, ok := startBlockIDIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_id' must be a string") + } + var startBlockID parser.ID + err := startBlockID.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'start_block_id': %w", err) + } + args.StartBlockID = startBlockID.Flow() + } + + // Parse 'start_block_height' if provided + // Parse 'start_block_height' if provided + if hasStartBlockHeight { + result, ok := startBlockHeightIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_height' must be a string") + } + startBlockHeight, err := util.ToUint64(result) + if err != nil { + return args, fmt.Errorf("invalid 'start_block_height': %w", err) + } + args.StartBlockHeight = startBlockHeight + } else { + args.StartBlockHeight = request.EmptyHeight + } + + // Parse 'event_types' as a JSON array + var eventTypes parser.EventTypes + if eventTypesIn, ok := arguments["event_types"]; ok && eventTypesIn != "" { + result, ok := eventTypesIn.([]string) + if !ok { + return args, fmt.Errorf("'event_types' must be an array of string") + } + + err := eventTypes.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'event_types': %w", err) + } + } + + // Parse 'accountAddresses' as []string{} + var accountAddresses []string + if accountAddressesIn, ok := arguments["account_addresses"]; ok && accountAddressesIn != "" { + accountAddresses, ok = accountAddressesIn.([]string) + if !ok { + return args, fmt.Errorf("'account_addresses' must be an array of string") + } + } + + // Initialize the event filter with the parsed arguments + filter, err := state_stream.NewAccountStatusFilter(eventFilterConfig, chain, eventTypes.Flow(), accountAddresses) + if err != nil { + return args, fmt.Errorf("failed to create event filter: %w", err) + } + args.Filter = filter + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go new file mode 100644 index 00000000000..aeadd68f649 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go @@ -0,0 +1,303 @@ +package data_providers + +import ( + "context" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// AccountStatusesProviderSuite is a test suite for testing the account statuses providers functionality. +type AccountStatusesProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *ssmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestNewAccountStatusesDataProvider(t *testing.T) { + suite.Run(t, new(AccountStatusesProviderSuite)) +} + +func (s *AccountStatusesProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = ssmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + s.api, + nil, + flow.Testnet.Chain(), + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestAccountStatusesDataProvider_HappyPath tests the behavior of the account statuses data provider +// when it is configured correctly and operating under normal conditions. It +// validates that events are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_HappyPath() { + s.testHappyPath( + AccountStatusesTopic, + s.subscribeAccountStatusesDataProviderTestCases(), + s.requireAccountStatuses, + ) +} + +func (s *AccountStatusesProviderSuite) testHappyPath( + topic string, + tests []testType, + requireFn func(interface{}, *backend.AccountStatusesResponse), +) { + expectedEvents := []flow.Event{ + unittest.EventFixture(state_stream.CoreEventAccountCreated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(state_stream.CoreEventAccountKeyAdded, 0, 0, unittest.IdentifierFixture(), 0), + } + + var expectedAccountStatusesResponses []backend.AccountStatusesResponse + + for i := 0; i < len(expectedEvents); i++ { + expectedAccountStatusesResponses = append(expectedAccountStatusesResponses, backend.AccountStatusesResponse{ + Height: s.rootBlock.Header.Height, + BlockID: s.rootBlock.ID(), + AccountEvents: map[string]flow.EventsList{ + unittest.RandomAddressFixture().String(): expectedEvents, + }, + }) + } + + for _, test := range tests { + s.Run(test.name, func() { + ctx := context.Background() + send := make(chan interface{}, 10) + + // Create a channel to simulate the subscription's data channel + accStatusesChan := make(chan interface{}) + + // // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(accStatusesChan)) + sub.On("Err").Return(nil) + test.setupBackend(sub) + + // Create the data provider instance + provider, err := s.factory.NewDataProvider(ctx, topic, test.arguments, send) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the events channel + go func() { + defer close(accStatusesChan) + + for i := 0; i < len(expectedAccountStatusesResponses); i++ { + accStatusesChan <- &expectedAccountStatusesResponses[i] + } + }() + + // Collect responses + for _, e := range expectedAccountStatusesResponses { + v, ok := <-send + s.Require().True(ok, "channel closed while waiting for event %v: err: %v", e.BlockID, sub.Err()) + + requireFn(v, &e) + } + + // Ensure the provider is properly closed after the test + provider.Close() + }) + } +} + +func (s *AccountStatusesProviderSuite) subscribeAccountStatusesDataProviderTestCases() []testType { + return []testType{ + { + name: "SubscribeAccountStatusesFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "event_types": []string{"flow.AccountCreated", "flow.AccountKeyAdded"}, + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeAccountStatusesFromStartHeight happy path", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeAccountStatusesFromLatestBlock happy path", + arguments: models.Arguments{}, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeAccountStatusesFromLatestBlock", + mock.Anything, + mock.Anything, + ).Return(sub).Once() + }, + }, + } +} + +// requireAccountStatuses ensures that the received account statuses information matches the expected data. +func (s *AccountStatusesProviderSuite) requireAccountStatuses( + v interface{}, + expectedAccountStatusesResponse *backend.AccountStatusesResponse, +) { + _, ok := v.(*models.AccountStatusesResponse) + require.True(s.T(), ok, "Expected *models.AccountStatusesResponse, got %T", v) + + //s.Require().ElementsMatch(expectedAccountStatusesResponse.AccountEvents, actualResponse.AccountEvents) +} + +// TestAccountStatusesDataProvider_InvalidArguments tests the behavior of the account statuses data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Providing both 'start_block_id' and 'start_block_height' simultaneously. +// 2. Invalid 'start_block_id' argument. +// 3. Invalid 'start_block_height' argument. +func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := AccountStatusesTopic + + for _, test := range invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewAccountStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// TestMessageIndexAccountStatusesProviderResponse_HappyPath tests that MessageIndex values in response are strictly increasing. +func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderResponse_HappyPath() { + ctx := context.Background() + send := make(chan interface{}, 10) + topic := AccountStatusesTopic + accountStatusesCount := 4 + + // Create a channel to simulate the subscription's account statuses channel + accountStatusesChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(accountStatusesChan)) + sub.On("Err").Return(nil) + + s.api.On("SubscribeAccountStatusesFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) + + arguments := + map[string]interface{}{ + "start_block_id": s.rootBlock.ID().String(), + } + + // Create the AccountStatusesDataProvider instance + provider, err := NewAccountStatusesDataProvider( + ctx, + s.log, + s.api, + topic, + arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine to simulate subscription processing + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the account statuses channel + go func() { + defer close(accountStatusesChan) // Close the channel when done + + for i := 0; i < accountStatusesCount; i++ { + accountStatusesChan <- &backend.AccountStatusesResponse{} + } + }() + + // Collect responses + var responses []*models.AccountStatusesResponse + for i := 0; i < accountStatusesCount; i++ { + res := <-send + accountStatusesRes, ok := res.(*models.AccountStatusesResponse) + s.Require().True(ok, "Expected *models.AccountStatusesResponse, got %T", res) + responses = append(responses, accountStatusesRes) + } + + // Verifying that indices are starting from 1 + s.Require().Equal("1", responses[0].MessageIndex, "Expected MessageIndex to start with 1") + + // Verifying that indices are strictly increasing + for i := 1; i < len(responses); i++ { + prevIndex, _ := strconv.Atoi(responses[i-1].MessageIndex) + currentIndex, _ := strconv.Atoi(responses[i].MessageIndex) + s.Require().Equal(prevIndex+1, currentIndex, "Expected MessageIndex to increment by 1") + } + + // Ensure the provider is properly closed after the test + provider.Close() +} diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go index 72cfaa6f554..28e0a9a03d2 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -96,7 +96,11 @@ func ParseBlocksArguments(arguments models.Arguments) (BlocksArguments, error) { // Parse 'block_status' if blockStatusIn, ok := arguments["block_status"]; ok { - blockStatus, err := parser.ParseBlockStatus(blockStatusIn) + result, ok := blockStatusIn.(string) + if !ok { + return args, fmt.Errorf("'block_status' must be string") + } + blockStatus, err := parser.ParseBlockStatus(result) if err != nil { return args, err } @@ -113,24 +117,31 @@ func ParseBlocksArguments(arguments models.Arguments) (BlocksArguments, error) { return args, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") } - // Parse 'start_block_id' if provided if hasStartBlockID { + result, ok := startBlockIDIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_id' must be a string") + } var startBlockID parser.ID - err := startBlockID.Parse(startBlockIDIn) + err := startBlockID.Parse(result) if err != nil { return args, err } args.StartBlockID = startBlockID.Flow() } - // Parse 'start_block_height' if provided if hasStartBlockHeight { - var err error - args.StartBlockHeight, err = util.ToUint64(startBlockHeightIn) + result, ok := startBlockHeightIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_height' must be a string") + } + startBlockHeight, err := util.ToUint64(result) if err != nil { return args, fmt.Errorf("invalid 'start_block_height': %w", err) } + args.StartBlockHeight = startBlockHeight } else { + // Default value if 'start_block_height' is not provided args.StartBlockHeight = request.EmptyHeight } diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go index 9e07f9459e9..51cf9e63e44 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider_test.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -15,7 +15,9 @@ import ( accessmock "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" statestreamsmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -73,7 +75,13 @@ func (s *BlocksProviderSuite) SetupTest() { } s.finalizedBlock = parent - s.factory = NewDataProviderFactory(s.log, nil, s.api) + s.factory = NewDataProviderFactory( + s.log, + nil, + s.api, + flow.Testnet.Chain(), + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) s.Require().NotNil(s.factory) } diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go new file mode 100644 index 00000000000..6b62f45ffac --- /dev/null +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -0,0 +1,216 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/rest/common/parser" + "github.com/onflow/flow-go/engine/access/rest/http/request" + "github.com/onflow/flow-go/engine/access/rest/util" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/counters" +) + +// EventsArguments contains the arguments required for subscribing to events +type EventsArguments struct { + StartBlockID flow.Identifier // ID of the block to start subscription from + StartBlockHeight uint64 // Height of the block to start subscription from + Filter state_stream.EventFilter // Filter applied to events for a given subscription +} + +// EventsDataProvider is responsible for providing events +type EventsDataProvider struct { + *baseDataProvider + + logger zerolog.Logger + stateStreamApi state_stream.API + + heartbeatInterval uint64 +} + +var _ DataProvider = (*EventsDataProvider)(nil) + +// NewEventsDataProvider creates a new instance of EventsDataProvider. +func NewEventsDataProvider( + ctx context.Context, + logger zerolog.Logger, + stateStreamApi state_stream.API, + topic string, + arguments models.Arguments, + send chan<- interface{}, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, +) (*EventsDataProvider, error) { + p := &EventsDataProvider{ + logger: logger.With().Str("component", "events-data-provider").Logger(), + stateStreamApi: stateStreamApi, + heartbeatInterval: heartbeatInterval, + } + + // Initialize arguments passed to the provider. + eventArgs, err := parseEventsArguments(arguments, chain, eventFilterConfig) + if err != nil { + return nil, fmt.Errorf("invalid arguments for events data provider: %w", err) + } + + subCtx, cancel := context.WithCancel(ctx) + + p.baseDataProvider = newBaseDataProvider( + topic, + cancel, + send, + p.createSubscription(subCtx, eventArgs), // Set up a subscription to events based on arguments. + ) + + return p, nil +} + +// Run starts processing the subscription for events and handles responses. +// +// No errors are expected during normal operations. +func (p *EventsDataProvider) Run() error { + return subscription.HandleSubscription(p.subscription, p.handleResponse()) +} + +// handleResponse processes events and sends the formatted response. +// +// No errors are expected during normal operations. +func (p *EventsDataProvider) handleResponse() func(eventsResponse *backend.EventsResponse) error { + blocksSinceLastMessage := uint64(0) + messageIndex := counters.NewMonotonousCounter(1) + + return func(eventsResponse *backend.EventsResponse) error { + // check if there are any events in the response. if not, do not send a message unless the last + // response was more than HeartbeatInterval blocks ago + if len(eventsResponse.Events) == 0 { + blocksSinceLastMessage++ + if blocksSinceLastMessage < p.heartbeatInterval { + return nil + } + blocksSinceLastMessage = 0 + } + + index := messageIndex.Value() + if ok := messageIndex.Set(messageIndex.Value() + 1); !ok { + return fmt.Errorf("message index already incremented to: %d", messageIndex.Value()) + } + + p.send <- &models.EventResponse{ + BlockId: eventsResponse.BlockID.String(), + BlockHeight: strconv.FormatUint(eventsResponse.Height, 10), + BlockTimestamp: eventsResponse.BlockTimestamp, + Events: eventsResponse.Events, + MessageIndex: strconv.FormatUint(index, 10), + } + + return nil + } +} + +// createSubscription creates a new subscription using the specified input arguments. +func (p *EventsDataProvider) createSubscription(ctx context.Context, args EventsArguments) subscription.Subscription { + if args.StartBlockID != flow.ZeroID { + return p.stateStreamApi.SubscribeEventsFromStartBlockID(ctx, args.StartBlockID, args.Filter) + } + + if args.StartBlockHeight != request.EmptyHeight { + return p.stateStreamApi.SubscribeEventsFromStartHeight(ctx, args.StartBlockHeight, args.Filter) + } + + return p.stateStreamApi.SubscribeEventsFromLatest(ctx, args.Filter) +} + +// parseEventsArguments validates and initializes the events arguments. +func parseEventsArguments( + arguments models.Arguments, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, +) (EventsArguments, error) { + var args EventsArguments + + // Check for mutual exclusivity of start_block_id and start_block_height early + startBlockIDIn, hasStartBlockID := arguments["start_block_id"] + startBlockHeightIn, hasStartBlockHeight := arguments["start_block_height"] + + if hasStartBlockID && hasStartBlockHeight { + return args, fmt.Errorf("can only provide either 'start_block_id' or 'start_block_height'") + } + + // Parse 'start_block_id' if provided + if hasStartBlockID { + result, ok := startBlockIDIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_id' must be a string") + } + var startBlockID parser.ID + err := startBlockID.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'start_block_id': %w", err) + } + args.StartBlockID = startBlockID.Flow() + } + + // Parse 'start_block_height' if provided + if hasStartBlockHeight { + result, ok := startBlockHeightIn.(string) + if !ok { + return args, fmt.Errorf("'start_block_height' must be a string") + } + startBlockHeight, err := util.ToUint64(result) + if err != nil { + return args, fmt.Errorf("invalid 'start_block_height': %w", err) + } + args.StartBlockHeight = startBlockHeight + } else { + args.StartBlockHeight = request.EmptyHeight + } + + // Parse 'event_types' as a JSON array + var eventTypes parser.EventTypes + if eventTypesIn, ok := arguments["event_types"]; ok && eventTypesIn != "" { + result, ok := eventTypesIn.([]string) + if !ok { + return args, fmt.Errorf("'event_types' must be an array of string") + } + + err := eventTypes.Parse(result) + if err != nil { + return args, fmt.Errorf("invalid 'event_types': %w", err) + } + } + + // Parse 'addresses' as []string{} + var addresses []string + if addressesIn, ok := arguments["addresses"]; ok && addressesIn != "" { + addresses, ok = addressesIn.([]string) + if !ok { + return args, fmt.Errorf("'addresses' must be an array of string") + } + } + + // Parse 'contracts' as []string{} + var contracts []string + if contractsIn, ok := arguments["contracts"]; ok && contractsIn != "" { + contracts, ok = contractsIn.([]string) + if !ok { + return args, fmt.Errorf("'contracts' must be an array of string") + } + } + + // Initialize the event filter with the parsed arguments + filter, err := state_stream.NewEventFilter(eventFilterConfig, chain, eventTypes.Flow(), addresses, contracts) + if err != nil { + return args, fmt.Errorf("failed to create event filter: %w", err) + } + args.Filter = filter + + return args, nil +} diff --git a/engine/access/rest/websockets/data_providers/events_provider_test.go b/engine/access/rest/websockets/data_providers/events_provider_test.go new file mode 100644 index 00000000000..336744d24c7 --- /dev/null +++ b/engine/access/rest/websockets/data_providers/events_provider_test.go @@ -0,0 +1,349 @@ +package data_providers + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + ssmock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +// EventsProviderSuite is a test suite for testing the events providers functionality. +type EventsProviderSuite struct { + suite.Suite + + log zerolog.Logger + api *ssmock.API + + chain flow.Chain + rootBlock flow.Block + finalizedBlock *flow.Header + + factory *DataProviderFactoryImpl +} + +func TestEventsProviderSuite(t *testing.T) { + suite.Run(t, new(EventsProviderSuite)) +} + +func (s *EventsProviderSuite) SetupTest() { + s.log = unittest.Logger() + s.api = ssmock.NewAPI(s.T()) + + s.chain = flow.Testnet.Chain() + + s.rootBlock = unittest.BlockFixture() + s.rootBlock.Header.Height = 0 + + s.factory = NewDataProviderFactory( + s.log, + s.api, + nil, + flow.Testnet.Chain(), + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(s.factory) +} + +// TestEventsDataProvider_HappyPath tests the behavior of the events data provider +// when it is configured correctly and operating under normal conditions. It +// validates that events are correctly streamed to the channel and ensures +// no unexpected errors occur. +func (s *EventsProviderSuite) TestEventsDataProvider_HappyPath() { + s.testHappyPath( + EventsTopic, + s.subscribeEventsDataProviderTestCases(), + s.requireEvents, + ) +} + +// subscribeEventsDataProviderTestCases generates test cases for events data providers. +func (s *EventsProviderSuite) subscribeEventsDataProviderTestCases() []testType { + return []testType{ + { + name: "SubscribeBlocksFromStartBlockID happy path", + arguments: models.Arguments{ + "start_block_id": s.rootBlock.ID().String(), + "event_types": []string{"flow.AccountCreated", "flow.AccountUpdated"}, + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromStartBlockID", + mock.Anything, + s.rootBlock.ID(), + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeEventsFromStartHeight happy path", + arguments: models.Arguments{ + "start_block_height": strconv.FormatUint(s.rootBlock.Header.Height, 10), + }, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromStartHeight", + mock.Anything, + s.rootBlock.Header.Height, + mock.Anything, + ).Return(sub).Once() + }, + }, + { + name: "SubscribeEventsFromLatest happy path", + arguments: models.Arguments{}, + setupBackend: func(sub *ssmock.Subscription) { + s.api.On( + "SubscribeEventsFromLatest", + mock.Anything, + mock.Anything, + ).Return(sub).Once() + }, + }, + } +} + +// testHappyPath tests a variety of scenarios for data providers in +// happy path scenarios. This function runs parameterized test cases that +// simulate various configurations and verifies that the data provider operates +// as expected without encountering errors. +// +// Arguments: +// - topic: The topic associated with the data provider. +// - tests: A slice of test cases to run, each specifying setup and validation logic. +// - requireFn: A function to validate the output received in the send channel. +func (s *EventsProviderSuite) testHappyPath( + topic string, + tests []testType, + requireFn func(interface{}, *backend.EventsResponse), +) { + expectedEvents := []flow.Event{ + unittest.EventFixture(flow.EventAccountCreated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(flow.EventAccountUpdated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(flow.EventAccountCreated, 0, 0, unittest.IdentifierFixture(), 0), + unittest.EventFixture(flow.EventAccountUpdated, 0, 0, unittest.IdentifierFixture(), 0), + } + + var expectedEventsResponses []backend.EventsResponse + + for i := 0; i < len(expectedEvents); i++ { + expectedEventsResponses = append(expectedEventsResponses, backend.EventsResponse{ + Height: s.rootBlock.Header.Height, + BlockID: s.rootBlock.ID(), + Events: expectedEvents, + BlockTimestamp: s.rootBlock.Header.Timestamp, + }) + + } + + for _, test := range tests { + s.Run(test.name, func() { + ctx := context.Background() + send := make(chan interface{}, 10) + + // Create a channel to simulate the subscription's data channel + eventChan := make(chan interface{}) + + // // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(eventChan)) + sub.On("Err").Return(nil) + test.setupBackend(sub) + + // Create the data provider instance + provider, err := s.factory.NewDataProvider(ctx, topic, test.arguments, send) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting data to the events channel + go func() { + defer close(eventChan) + + for i := 0; i < len(expectedEventsResponses); i++ { + eventChan <- &expectedEventsResponses[i] + } + }() + + // Collect responses + for _, e := range expectedEventsResponses { + v, ok := <-send + s.Require().True(ok, "channel closed while waiting for event %v: err: %v", e.BlockID, sub.Err()) + + requireFn(v, &e) + } + + // Ensure the provider is properly closed after the test + provider.Close() + }) + } +} + +// requireEvents ensures that the received event information matches the expected data. +func (s *EventsProviderSuite) requireEvents(v interface{}, expectedEventsResponse *backend.EventsResponse) { + actualResponse, ok := v.(*models.EventResponse) + require.True(s.T(), ok, "Expected *models.EventResponse, got %T", v) + + s.Require().ElementsMatch(expectedEventsResponse.Events, actualResponse.Events) +} + +// invalidArgumentsTestCases returns a list of test cases with invalid argument combinations +// for testing the behavior of events data providers. Each test case includes a name, +// a set of input arguments, and the expected error message that should be returned. +// +// The test cases cover scenarios such as: +// 1. Supplying both 'start_block_id' and 'start_block_height' simultaneously, which is not allowed. +// 2. Providing invalid 'start_block_id' value. +// 3. Providing invalid 'start_block_height' value. +func invalidArgumentsTestCases() []testErrType { + return []testErrType{ + { + name: "provide both 'start_block_id' and 'start_block_height' arguments", + arguments: models.Arguments{ + "start_block_id": unittest.BlockFixture().ID().String(), + "start_block_height": fmt.Sprintf("%d", unittest.BlockFixture().Header.Height), + }, + expectedErrorMsg: "can only provide either 'start_block_id' or 'start_block_height'", + }, + { + name: "invalid 'start_block_id' argument", + arguments: map[string]interface{}{ + "start_block_id": "invalid_block_id", + }, + expectedErrorMsg: "invalid ID format", + }, + { + name: "invalid 'start_block_height' argument", + arguments: map[string]interface{}{ + "start_block_height": "-1", + }, + expectedErrorMsg: "value must be an unsigned 64 bit integer", + }, + } +} + +// TestEventsDataProvider_InvalidArguments tests the behavior of the event data provider +// when invalid arguments are provided. It verifies that appropriate errors are returned +// for missing or conflicting arguments. +// This test covers the test cases: +// 1. Providing both 'start_block_id' and 'start_block_height' simultaneously. +// 2. Invalid 'start_block_id' argument. +// 3. Invalid 'start_block_height' argument. +func (s *EventsProviderSuite) TestEventsDataProvider_InvalidArguments() { + ctx := context.Background() + send := make(chan interface{}) + + topic := EventsTopic + + for _, test := range invalidArgumentsTestCases() { + s.Run(test.name, func() { + provider, err := NewEventsDataProvider( + ctx, + s.log, + s.api, + topic, + test.arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) + s.Require().Nil(provider) + s.Require().Error(err) + s.Require().Contains(err.Error(), test.expectedErrorMsg) + }) + } +} + +// TestMessageIndexEventProviderResponse_HappyPath tests that MessageIndex values in response are strictly increasing. +func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() { + ctx := context.Background() + send := make(chan interface{}, 10) + topic := EventsTopic + eventsCount := 4 + + // Create a channel to simulate the subscription's event channel + eventChan := make(chan interface{}) + + // Create a mock subscription and mock the channel + sub := ssmock.NewSubscription(s.T()) + sub.On("Channel").Return((<-chan interface{})(eventChan)) + sub.On("Err").Return(nil) + + s.api.On("SubscribeEventsFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) + + arguments := + map[string]interface{}{ + "start_block_id": s.rootBlock.ID().String(), + } + + // Create the EventsDataProvider instance + provider, err := NewEventsDataProvider( + ctx, + s.log, + s.api, + topic, + arguments, + send, + s.chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval) + s.Require().NotNil(provider) + s.Require().NoError(err) + + // Run the provider in a separate goroutine to simulate subscription processing + go func() { + err = provider.Run() + s.Require().NoError(err) + }() + + // Simulate emitting events to the event channel + go func() { + defer close(eventChan) // Close the channel when done + + for i := 0; i < eventsCount; i++ { + eventChan <- &backend.EventsResponse{ + Height: s.rootBlock.Header.Height, + } + } + }() + + // Collect responses + var responses []*models.EventResponse + for i := 0; i < eventsCount; i++ { + res := <-send + eventRes, ok := res.(*models.EventResponse) + s.Require().True(ok, "Expected *models.EventResponse, got %T", res) + responses = append(responses, eventRes) + } + + // Verifying that indices are starting from 1 + s.Require().Equal("1", responses[0].MessageIndex, "Expected MessageIndex to start with 1") + + // Verifying that indices are strictly increasing + for i := 1; i < len(responses); i++ { + prevIndex, _ := strconv.Atoi(responses[i-1].MessageIndex) + currentIndex, _ := strconv.Atoi(responses[i].MessageIndex) + s.Require().Equal(prevIndex+1, currentIndex, "Expected MessageIndex to increment by 1") + } + + // Ensure the provider is properly closed after the test + provider.Close() +} diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index 72f4a6b7633..26aade4e090 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -9,6 +9,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/websockets/models" "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/model/flow" ) // Constants defining various topic names used to specify different types of @@ -43,6 +44,10 @@ type DataProviderFactoryImpl struct { stateStreamApi state_stream.API accessApi access.API + + chain flow.Chain + eventFilterConfig state_stream.EventFilterConfig + heartbeatInterval uint64 } // NewDataProviderFactory creates a new DataProviderFactory @@ -56,11 +61,17 @@ func NewDataProviderFactory( logger zerolog.Logger, stateStreamApi state_stream.API, accessApi access.API, + chain flow.Chain, + eventFilterConfig state_stream.EventFilterConfig, + heartbeatInterval uint64, ) *DataProviderFactoryImpl { return &DataProviderFactoryImpl{ - logger: logger, - stateStreamApi: stateStreamApi, - accessApi: accessApi, + logger: logger, + stateStreamApi: stateStreamApi, + accessApi: accessApi, + chain: chain, + eventFilterConfig: eventFilterConfig, + heartbeatInterval: heartbeatInterval, } } @@ -87,10 +98,12 @@ func (s *DataProviderFactoryImpl) NewDataProvider( return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) case BlockDigestsTopic: return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) - // TODO: Implemented handlers for each topic should be added in respective case - case EventsTopic, - AccountStatusesTopic, - TransactionStatusesTopic: + case EventsTopic: + return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + case AccountStatusesTopic: + return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + case TransactionStatusesTopic: + // TODO: Implemented handlers for each topic should be added in respective case return nil, fmt.Errorf(`topic "%s" not implemented yet`, topic) default: return nil, fmt.Errorf("unsupported topic \"%s\"", topic) diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go index 2ed2b075d0c..9421cef37f6 100644 --- a/engine/access/rest/websockets/data_providers/factory_test.go +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -11,7 +11,9 @@ import ( accessmock "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/common/parser" "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" statestreammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -43,7 +45,16 @@ func (s *DataProviderFactorySuite) SetupTest() { s.ctx = context.Background() s.ch = make(chan interface{}) - s.factory = NewDataProviderFactory(log, s.stateStreamApi, s.accessApi) + chain := flow.Testnet.Chain() + + s.factory = NewDataProviderFactory( + log, + s.stateStreamApi, + s.accessApi, + chain, + state_stream.DefaultEventFilterConfig, + subscription.DefaultHeartbeatInterval, + ) s.Require().NotNil(s.factory) } @@ -99,6 +110,28 @@ func (s *DataProviderFactorySuite) TestSupportedTopics() { s.accessApi.AssertExpectations(s.T()) }, }, + { + name: "events topic", + topic: EventsTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.stateStreamApi.On("SubscribeEventsFromLatest", mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, + { + name: "account statuses topic", + topic: AccountStatusesTopic, + arguments: models.Arguments{}, + setupSubscription: func() { + s.setupSubscription(s.stateStreamApi.On("SubscribeAccountStatusesFromLatestBlock", mock.Anything, mock.Anything)) + }, + assertExpectations: func() { + s.stateStreamApi.AssertExpectations(s.T()) + }, + }, } for _, test := range testCases { diff --git a/engine/access/rest/websockets/legacy/request/subscribe_events.go b/engine/access/rest/websockets/legacy/request/subscribe_events.go index 1110d3582d4..9e53e7c5fca 100644 --- a/engine/access/rest/websockets/legacy/request/subscribe_events.go +++ b/engine/access/rest/websockets/legacy/request/subscribe_events.go @@ -81,7 +81,7 @@ func (g *SubscribeEvents) Parse( g.StartHeight = 0 } - var eventTypes request.EventTypes + var eventTypes parser.EventTypes err = eventTypes.Parse(rawTypes) if err != nil { return err diff --git a/engine/access/rest/websockets/models/account_models.go b/engine/access/rest/websockets/models/account_models.go new file mode 100644 index 00000000000..712f7a1be6a --- /dev/null +++ b/engine/access/rest/websockets/models/account_models.go @@ -0,0 +1,11 @@ +package models + +import "github.com/onflow/flow-go/model/flow" + +// AccountStatusesResponse is the response message for 'events' topic. +type AccountStatusesResponse struct { + BlockID string `json:"blockID"` + Height string `json:"height"` + AccountEvents map[string]flow.EventsList `json:"account_events"` + MessageIndex string `json:"message_index"` +} diff --git a/engine/access/rest/websockets/models/event_models.go b/engine/access/rest/websockets/models/event_models.go new file mode 100644 index 00000000000..48d085d9b85 --- /dev/null +++ b/engine/access/rest/websockets/models/event_models.go @@ -0,0 +1,16 @@ +package models + +import ( + "time" + + "github.com/onflow/flow-go/model/flow" +) + +// EventResponse is the response message for 'events' topic. +type EventResponse struct { + BlockId string `json:"block_id"` + BlockHeight string `json:"block_height"` + BlockTimestamp time.Time `json:"block_timestamp"` + Events []flow.Event `json:"events"` + MessageIndex string `json:"message_index"` +} diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go index 95ad17e3708..03b37aee5f1 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe.go @@ -1,6 +1,6 @@ package models -type Arguments map[string]string +type Arguments map[string]interface{} // SubscribeMessageRequest represents a request to subscribe to a topic. type SubscribeMessageRequest struct {