From bc2e08507c8e0cbb38d6f63e4f85a2f4fecaa48e Mon Sep 17 00:00:00 2001
From: Omer <100387053+omerlavanet@users.noreply.github.com>
Date: Mon, 19 Aug 2024 17:35:29 +0300
Subject: [PATCH 01/12] feat: added option to configure static providers
 (#1629)

* added option to configure static providers
* who doesn't like some lint on comments?
* disabled verifications for static providers on the consumer, added static
  provider support on the provider side, disabled provider sessions in
  static provider code
* added unit tests for static providers
* fix lock hanging
* added tests
* lint
* added example prints and a script to run a static provider
---
 .../lava_consumer_static_peers.yml            |  23 ++
 protocol/common/conf.go                       |   1 +
 protocol/integration/mocks.go                 |  23 +-
 protocol/integration/protocol_test.go         |  75 ++++++-
 protocol/lavasession/consumer_types.go        |   2 +
 .../lavasession/single_consumer_session.go    |   1 +
 .../consumer_state_tracker_mock.go            |   2 +-
 protocol/rpcconsumer/rpcconsumer.go           |  22 +-
 protocol/rpcconsumer/rpcconsumer_server.go    |  14 +-
 .../rpcprovider/rewardserver/reward_server.go |  21 ++
 protocol/rpcprovider/rpcprovider.go           |  30 ++-
 protocol/rpcprovider/rpcprovider_server.go    |  36 +++-
 .../statetracker/consumer_state_tracker.go    |   6 +-
 .../statetracker/updaters/pairing_updater.go  | 124 ++++++++---
 .../updaters/pairing_updater_test.go          | 204 ++++++++++++++++++
 .../statetracker/updaters/updaters_mock.go    | 155 +++++++++++++
 .../pre_setups/init_lava_static_provider.sh   |  57 +++++
 17 files changed, 743 insertions(+), 53 deletions(-)
 create mode 100644 config/consumer_examples/lava_consumer_static_peers.yml
 create mode 100644 protocol/statetracker/updaters/pairing_updater_test.go
 create mode 100644 protocol/statetracker/updaters/updaters_mock.go
 create mode 100755 scripts/pre_setups/init_lava_static_provider.sh

diff --git a/config/consumer_examples/lava_consumer_static_peers.yml b/config/consumer_examples/lava_consumer_static_peers.yml
new file mode 100644
index 0000000000..5e3a6bfe5f
--- /dev/null
+++ b/config/consumer_examples/lava_consumer_static_peers.yml
@@ -0,0 +1,23 @@
+endpoints:
+  - chain-id: LAV1
+    api-interface: rest
+    network-address: 127.0.0.1:3360
+  - chain-id: LAV1
+    api-interface: tendermintrpc
+    network-address: 127.0.0.1:3361
+  - chain-id: LAV1
+    api-interface: grpc
+    network-address: 127.0.0.1:3362
+static-providers:
+  - api-interface: tendermintrpc
+    chain-id: LAV1
+    node-urls:
+      - url: 127.0.0.1:2220
+  - api-interface: grpc
+    chain-id: LAV1
+    node-urls:
+      - url: 127.0.0.1:2220
+  - api-interface: rest
+    chain-id: LAV1
+    node-urls:
+      - url: 127.0.0.1:2220
\ No newline at end of file
diff --git a/protocol/common/conf.go b/protocol/common/conf.go
index b2b4bad9bd..5df9ebf1b9 100644
--- a/protocol/common/conf.go
+++ b/protocol/common/conf.go
@@ -13,6 +13,7 @@ type Test_mode_ctx_key struct{}
 const (
 	PlainTextConnection       = "allow-plaintext-connection"
 	EndpointsConfigName       = "endpoints"
+	StaticProvidersConfigName = "static-providers"
 	SaveConfigFlagName        = "save-conf"
 	GeolocationFlag           = "geolocation"
 	TestModeFlagName          = "test-mode"
diff --git a/protocol/integration/mocks.go b/protocol/integration/mocks.go
index f1e31746a5..5790a09c17 100644
--- a/protocol/integration/mocks.go
+++ b/protocol/integration/mocks.go
@@ -3,6 +3,7 @@ package integration_test
 import (
 	"context"
 	"fmt"
+	"net"
 	"net/http"
 	"strconv"
 	"sync"
@@ -31,7 +32,7 @@ type mockConsumerStateTracker struct {
 }
 
 func (m *mockConsumerStateTracker) RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf)
{ } -func (m *mockConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) { +func (m *mockConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProviders []*lavasession.RPCProviderEndpoint) { } func (m *mockConsumerStateTracker) RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error { @@ -267,6 +268,19 @@ type uniqueAddressGenerator struct { lock sync.Mutex } +func isPortInUse(port int) bool { + // Attempt to listen on the port + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + // If there's an error, the port is likely in use + return true + } + + // Close the listener immediately if successful + ln.Close() + return false +} + func NewUniqueAddressGenerator() uniqueAddressGenerator { return uniqueAddressGenerator{ currentPort: minPort, @@ -277,6 +291,13 @@ func (ag *uniqueAddressGenerator) GetAddress() string { ag.lock.Lock() defer ag.lock.Unlock() + for { + if !isPortInUse(ag.currentPort) { + break + } + ag.currentPort++ + } + if ag.currentPort > maxPort { panic("all ports have been exhausted") } diff --git a/protocol/integration/protocol_test.go b/protocol/integration/protocol_test.go index 1b7f0989e5..8fece3cd9f 100644 --- a/protocol/integration/protocol_test.go +++ b/protocol/integration/protocol_test.go @@ -295,7 +295,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string require.NoError(t, err) reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser) mockReliabilityManager := NewMockReliabilityManager(reliabilityManager) - rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil) + rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false) listener := rpcprovider.NewProviderListener(ctx, rpcProviderEndpoint.NetworkAddress, "/health") err = listener.RegisterReceiver(rpcProviderServer, rpcProviderEndpoint) require.NoError(t, err) @@ -1149,3 +1149,76 @@ func TestSameProviderConflictReport(t *testing.T) { require.True(t, twoProvidersConflictSent) }) } + +func TestConsumerProviderStatic(t *testing.T) { + ctx := context.Background() + // can be any spec and api interface + specId := "LAV1" + apiInterface := spectypes.APIInterfaceTendermintRPC + epoch := uint64(100) + requiredResponses := 1 + lavaChainID := "lava" + + numProviders := 1 + + consumerListenAddress := addressGen.GetAddress() + pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{} + type providerData struct { + account sigs.Account + endpoint *lavasession.RPCProviderEndpoint + server *rpcprovider.RPCProviderServer + replySetter *ReplySetter + mockChainFetcher *MockChainFetcher + } + providers := []providerData{} + + for i := 0; i < numProviders; i++ { + account := sigs.GenerateDeterministicFloatingKey(randomizer) + providerDataI := providerData{account: account} + providers = 
append(providers, providerDataI)
+	}
+	consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	for i := 0; i < numProviders; i++ {
+		ctx := context.Background()
+		providerDataI := providers[i]
+		listenAddress := addressGen.GetAddress()
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+	}
+	// the pairing list entries are static providers
+	for i := 0; i < numProviders; i++ {
+		pairingList[uint64(i)] = &lavasession.ConsumerSessionsWithProvider{
+			PublicLavaAddress: "BANANA" + strconv.Itoa(i),
+			Endpoints: []*lavasession.Endpoint{
+				{
+					NetworkAddress: providers[i].endpoint.NetworkAddress.Address,
+					Enabled:        true,
+					Geolocation:    1,
+				},
+			},
+			Sessions:         map[int64]*lavasession.SingleConsumerSession{},
+			MaxComputeUnits:  10000,
+			UsedComputeUnits: 0,
+			PairingEpoch:     epoch,
+			StaticProvider:   true,
+		}
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+	require.NotNil(t, rpcconsumerServer)
+	client := http.Client{}
+	// the consumer sends the relay to a provider with an address of the form BANANA%d, so the provider needs to skip validations for this to work
+	resp, err := client.Get("http://" + consumerListenAddress + "/status")
+	require.NoError(t, err)
+	// we expect the provider to fail the request on verification, since it is not yet marked as static
+	require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+	for i := 0; i < numProviders; i++ {
+		providers[i].server.StaticProvider = true
+	}
+	resp, err = client.Get("http://" + consumerListenAddress + "/status")
+	require.NoError(t, err)
+	// now that the providers are static we expect the request to succeed
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+	bodyBytes, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	require.Equal(t, providers[0].replySetter.replyDataBuf, bodyBytes)
+	resp.Body.Close()
+}
diff --git a/protocol/lavasession/consumer_types.go b/protocol/lavasession/consumer_types.go
index ef3e6735bf..77a50ca685 100644
--- a/protocol/lavasession/consumer_types.go
+++ b/protocol/lavasession/consumer_types.go
@@ -219,6 +219,7 @@ type ConsumerSessionsWithProvider struct {
 	// blocked provider recovery status if 0 currently not used, if 1 a session has tried resume communication with this provider
 	// if the provider is not blocked at all this field is irrelevant
 	blockedAndUsedWithChanceForRecoveryStatus uint32
+	StaticProvider                            bool
 }
 
 func NewConsumerSessionWithProvider(publicLavaAddress string, pairingEndpoints []*Endpoint, maxCu uint64, epoch uint64, stakeSize sdk.Coin) *ConsumerSessionsWithProvider {
@@ -435,6 +436,7 @@ func (cswp *ConsumerSessionsWithProvider) GetConsumerSessionInstanceFromEndpoint
 		SessionId:          randomSessionId,
 		Parent:             cswp,
 		EndpointConnection: endpointConnection,
+		StaticProvider:     cswp.StaticProvider,
 	}
 	consumerSession.TryUseSession() // we must lock the session so other requests wont get it.
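
The two hunks above are the heart of the consumer-side change: the StaticProvider flag lives on the pairing entry and is copied into every session created from it. Below is a minimal sketch of that propagation, using simplified stand-in types rather than the real lavasession structs, for readers who want the flow without the surrounding diff:

    package main

    import "fmt"

    // stand-in for lavasession.ConsumerSessionsWithProvider
    type providerEntry struct {
    	publicLavaAddress string
    	staticProvider    bool // true for entries built from the static-providers config
    }

    // stand-in for lavasession.SingleConsumerSession
    type session struct {
    	parent         *providerEntry
    	staticProvider bool
    }

    // mirrors GetConsumerSessionInstanceFromEndpoint above: the session inherits
    // the flag, so later relay code can branch on it and skip reply-signature
    // verification for static providers
    func newSession(parent *providerEntry) *session {
    	return &session{parent: parent, staticProvider: parent.staticProvider}
    }

    func main() {
    	entry := &providerEntry{publicLavaAddress: "BANANA0", staticProvider: true}
    	s := newSession(entry)
    	fmt.Println(s.staticProvider) // true
    }
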
diff --git a/protocol/lavasession/single_consumer_session.go b/protocol/lavasession/single_consumer_session.go index b6b5141d73..92d698ed56 100644 --- a/protocol/lavasession/single_consumer_session.go +++ b/protocol/lavasession/single_consumer_session.go @@ -27,6 +27,7 @@ type SingleConsumerSession struct { errorsCount uint64 relayProcessor UsedProvidersInf providerUniqueId string + StaticProvider bool } // returns the expected latency to a threshold. diff --git a/protocol/rpcconsumer/consumer_state_tracker_mock.go b/protocol/rpcconsumer/consumer_state_tracker_mock.go index 7fa930f992..23d477b995 100644 --- a/protocol/rpcconsumer/consumer_state_tracker_mock.go +++ b/protocol/rpcconsumer/consumer_state_tracker_mock.go @@ -91,7 +91,7 @@ func (mr *MockConsumerStateTrackerInfMockRecorder) GetProtocolVersion(ctx any) * } // RegisterConsumerSessionManagerForPairingUpdates mocks base method. -func (m *MockConsumerStateTrackerInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) { +func (m *MockConsumerStateTrackerInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProviders []*lavasession.RPCProviderEndpoint) { m.ctrl.T.Helper() m.ctrl.Call(m, "RegisterConsumerSessionManagerForPairingUpdates", ctx, consumerSessionManager) } diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index fc01203918..8cee4014e9 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -24,6 +24,7 @@ import ( "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/performance" "github.com/lavanet/lava/v2/protocol/provideroptimizer" + "github.com/lavanet/lava/v2/protocol/rpcprovider" "github.com/lavanet/lava/v2/protocol/statetracker" "github.com/lavanet/lava/v2/protocol/statetracker/updaters" "github.com/lavanet/lava/v2/protocol/upgrade" @@ -89,7 +90,7 @@ func (s *strategyValue) Type() string { type ConsumerStateTrackerInf interface { RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf) - RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) + RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint) RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error @@ -121,6 +122,7 @@ type rpcConsumerStartOptions struct { cmdFlags common.ConsumerCmdFlags stateShare bool refererData *chainlib.RefererData + staticProvidersList []*lavasession.RPCProviderEndpoint // define static providers as backup to lava providers } // spawns a new RPCConsumer server with all it's processes and internals ready for communications @@ -287,7 +289,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt activeSubscriptionProvidersStorage := lavasession.NewActiveSubscriptionProvidersStorage() consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, consumerMetricsManager, 
consumerReportsManager, consumerAddr.String(), activeSubscriptionProvidersStorage)
 
 	// Register For Updates
-	rpcc.consumerStateTracker.RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager)
+	rpcc.consumerStateTracker.RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager, options.staticProvidersList)
 
 	var relaysMonitor *metrics.RelaysMonitor
 	if options.cmdFlags.RelaysHealthEnableFlag {
@@ -505,6 +507,20 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 			if gasPricesStr == "" {
 				gasPricesStr = statetracker.DefaultGasPrice
 			}
+
+			// check if StaticProvidersConfigName exists in viper; if it does, parse it with the ParseEndpointsCustomName function
+			var staticProviderEndpoints []*lavasession.RPCProviderEndpoint
+			if viper.IsSet(common.StaticProvidersConfigName) {
+				staticProviderEndpoints, err = rpcprovider.ParseEndpointsCustomName(viper.GetViper(), common.StaticProvidersConfigName, geolocation)
+				if err != nil {
+					return utils.LavaFormatError("invalid static providers definition", err)
+				}
+				for _, endpoint := range staticProviderEndpoints {
+					utils.LavaFormatInfo("Static Provider Endpoint:", utils.Attribute{Key: "Urls", Value: endpoint.NodeUrls}, utils.Attribute{Key: "Chain ID", Value: endpoint.ChainID}, utils.Attribute{Key: "API Interface", Value: endpoint.ApiInterface})
+				}
+			}
+
+			// set up the txFactory with gas adjustment and gas prices
 			txFactory = txFactory.WithGasAdjustment(viper.GetFloat64(flags.FlagGasAdjustment))
 			txFactory = txFactory.WithGasPrices(gasPricesStr)
 			utils.LavaFormatInfo("Setting gas for tx Factory", utils.LogAttr("gas-prices", gasPricesStr), utils.LogAttr("gas-adjustment", txFactory.GasAdjustment()))
@@ -560,7 +576,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 			}
 			rpcConsumerSharedState := viper.GetBool(common.SharedStateFlag)
-			err = rpcConsumer.Start(ctx, &rpcConsumerStartOptions{txFactory, clientCtx, rpcEndpoints, requiredResponses, cache, strategyFlag.Strategy, maxConcurrentProviders, analyticsServerAddressess, consumerPropagatedFlags, rpcConsumerSharedState, refererData})
+			err = rpcConsumer.Start(ctx, &rpcConsumerStartOptions{txFactory, clientCtx, rpcEndpoints, requiredResponses, cache, strategyFlag.Strategy, maxConcurrentProviders, analyticsServerAddressess, consumerPropagatedFlags, rpcConsumerSharedState, refererData, staticProviderEndpoints})
 			return err
 		},
 	}
diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go
index fb56e3c533..f00af0206f 100644
--- a/protocol/rpcconsumer/rpcconsumer_server.go
+++ b/protocol/rpcconsumer/rpcconsumer_server.go
@@ -1043,16 +1043,22 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe
 	filteredHeaders, _, ignoredHeaders := rpccs.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply)
 	reply.Metadata = filteredHeaders
-	err = lavaprotocol.VerifyRelayReply(ctx, reply, relayRequest, providerPublicAddress)
-	if err != nil {
-		return 0, err, false
+
+	// check the signature on the reply
+	if !singleConsumerSession.StaticProvider {
+		err = lavaprotocol.VerifyRelayReply(ctx, reply, relayRequest, providerPublicAddress)
+		if err != nil {
+			return 0, err, false
+		}
 	}
 	reply.Metadata = append(reply.Metadata, ignoredHeaders...)
// TODO: response data sanity, check its under an expected format add that format to spec enabled, _ := rpccs.chainParser.DataReliabilityParams() - if enabled { + if enabled && !singleConsumerSession.StaticProvider { + // TODO: allow static providers to detect hash mismatches, + // triggering conflict with them is impossible so we skip this for now, but this can be used to block malicious providers finalizedBlocks, err := finalizationverification.VerifyFinalizationData(reply, relayRequest, providerPublicAddress, rpccs.ConsumerAddress, existingSessionLatestBlock, int64(blockDistanceForFinalizedData), int64(blocksInFinalizationProof)) if err != nil { if sdkerrors.IsOf(err, protocolerrors.ProviderFinalizationDataAccountabilityError) { diff --git a/protocol/rpcprovider/rewardserver/reward_server.go b/protocol/rpcprovider/rewardserver/reward_server.go index 13901b7743..292b4f4649 100644 --- a/protocol/rpcprovider/rewardserver/reward_server.go +++ b/protocol/rpcprovider/rewardserver/reward_server.go @@ -84,6 +84,21 @@ type PaymentConfiguration struct { shouldAddExpectedPayment bool } +// used to disable provider rewards claiming +type DisabledRewardServer struct{} + +func (rws *DisabledRewardServer) SendNewProof(ctx context.Context, proof *pairingtypes.RelaySession, epoch uint64, consumerAddr string, apiInterface string) (existingCU uint64, updatedWithProof bool) { + return 0, true +} + +func (rws *DisabledRewardServer) SubscribeStarted(consumer string, epoch uint64, subscribeID string) { + // TODO: hold off reward claims for subscription while this is still active +} + +func (rws *DisabledRewardServer) SubscribeEnded(consumer string, epoch uint64, subscribeID string) { + // TODO: can collect now +} + type RewardServer struct { rewardsTxSender RewardsTxSender lock sync.RWMutex @@ -464,6 +479,9 @@ func (rws *RewardServer) updateCUPaid(cu uint64) { } func (rws *RewardServer) AddDataBase(specId string, providerPublicAddress string, shardID uint) { + if rws == nil { + return + } // the db itself doesn't need locks. as it self manages locks inside. // but opening a db can race. (NewLocalDB) so we lock this method. 
// Also, we construct the in-memory rewards from the DB, so that needs a lock as well @@ -477,6 +495,9 @@ func (rws *RewardServer) AddDataBase(specId string, providerPublicAddress string } func (rws *RewardServer) CloseAllDataBases() error { + if rws == nil { + return nil + } return rws.rewardDB.Close() } diff --git a/protocol/rpcprovider/rpcprovider.go b/protocol/rpcprovider/rpcprovider.go index ebb1cc147c..225f175d57 100644 --- a/protocol/rpcprovider/rpcprovider.go +++ b/protocol/rpcprovider/rpcprovider.go @@ -108,6 +108,7 @@ type rpcProviderStartOptions struct { rewardsSnapshotThreshold uint rewardsSnapshotTimeoutSec uint healthCheckMetricsOptions *rpcProviderHealthCheckMetricsOptions + staticProvider bool } type rpcProviderHealthCheckMetricsOptions struct { @@ -137,6 +138,7 @@ type RPCProvider struct { relaysHealthCheckInterval time.Duration grpcHealthCheckEndpoint string providerUniqueId string + staticProvider bool } func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { @@ -159,6 +161,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { rpcp.relaysHealthCheckInterval = options.healthCheckMetricsOptions.relaysHealthIntervalFlag rpcp.relaysMonitorAggregator = metrics.NewRelaysMonitorAggregator(rpcp.relaysHealthCheckInterval, rpcp.providerMetricsManager) rpcp.grpcHealthCheckEndpoint = options.healthCheckMetricsOptions.grpcHealthCheckEndpoint + rpcp.staticProvider = options.staticProvider // single state tracker lavaChainFetcher := chainlib.NewLavaChainFetcher(ctx, options.clientCtx) providerStateTracker, err := statetracker.NewProviderStateTracker(ctx, options.txFactory, options.clientCtx, lavaChainFetcher, rpcp.providerMetricsManager) @@ -176,10 +179,12 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { rpcp.providerStateTracker.RegisterForVersionUpdates(ctx, version.Version, &upgrade.ProtocolVersion{}) // single reward server - rewardDB := rewardserver.NewRewardDBWithTTL(options.rewardTTL) - rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp.chainTrackers) - rpcp.providerStateTracker.RegisterForEpochUpdates(ctx, rpcp.rewardServer) - rpcp.providerStateTracker.RegisterPaymentUpdatableForPayments(ctx, rpcp.rewardServer) + if !options.staticProvider { + rewardDB := rewardserver.NewRewardDBWithTTL(options.rewardTTL) + rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp.chainTrackers) + rpcp.providerStateTracker.RegisterForEpochUpdates(ctx, rpcp.rewardServer) + rpcp.providerStateTracker.RegisterPaymentUpdatableForPayments(ctx, rpcp.rewardServer) + } keyName, err := sigs.GetKeyName(options.clientCtx) if err != nil { utils.LavaFormatFatal("failed getting key name from clientCtx", err) @@ -467,7 +472,7 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint providerNodeSubscriptionManager = chainlib.NewProviderNodeSubscriptionManager(chainRouter, chainParser, rpcProviderServer, rpcp.privKey) } - rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rpcp.rewardServer, providerSessionManager, reliabilityManager, rpcp.privKey, rpcp.cache, chainRouter, rpcp.providerStateTracker, rpcp.addr, rpcp.lavaChainID, DEFAULT_ALLOWED_MISSING_CU, providerMetrics, 
relaysMonitor, providerNodeSubscriptionManager) + rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rpcp.rewardServer, providerSessionManager, reliabilityManager, rpcp.privKey, rpcp.cache, chainRouter, rpcp.providerStateTracker, rpcp.addr, rpcp.lavaChainID, DEFAULT_ALLOWED_MISSING_CU, providerMetrics, relaysMonitor, providerNodeSubscriptionManager, rpcp.staticProvider) // set up grpc listener var listener *ProviderListener func() { @@ -500,8 +505,8 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint return nil } -func ParseEndpoints(viper_endpoints *viper.Viper, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) { - err = viper_endpoints.UnmarshalKey(common.EndpointsConfigName, &endpoints) +func ParseEndpointsCustomName(viper_endpoints *viper.Viper, endpointsConfigName string, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) { + err = viper_endpoints.UnmarshalKey(endpointsConfigName, &endpoints) if err != nil { utils.LavaFormatFatal("could not unmarshal endpoints", err, utils.Attribute{Key: "viper_endpoints", Value: viper_endpoints.AllSettings()}) } @@ -511,6 +516,10 @@ func ParseEndpoints(viper_endpoints *viper.Viper, geolocation uint64) (endpoints return } +func ParseEndpoints(viper_endpoints *viper.Viper, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) { + return ParseEndpointsCustomName(viper_endpoints, common.EndpointsConfigName, geolocation) +} + func CreateRPCProviderCobraCommand() *cobra.Command { cmdRPCProvider := &cobra.Command{ Use: `rpcprovider [config-file] | { {listen-ip:listen-port spec-chain-id api-interface "comma-separated-node-urls"} ... } --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE`, @@ -690,6 +699,11 @@ rpcprovider 127.0.0.1:3333 OSMOSIS tendermintrpc "wss://www.node-path.com:80,htt enableRelaysHealth := viper.GetBool(common.RelaysHealthEnableFlag) relaysHealthInterval := viper.GetDuration(common.RelayHealthIntervalFlag) healthCheckURLPath := viper.GetString(HealthCheckURLPathFlagName) + staticProvider := viper.GetBool(common.StaticProvidersConfigName) + + if staticProvider { + utils.LavaFormatWarning("Running in static provider mode, skipping rewards and allowing requests from anyone", nil) + } rpcProviderHealthCheckMetricsOptions := rpcProviderHealthCheckMetricsOptions{ enableRelaysHealth, @@ -711,6 +725,7 @@ rpcprovider 127.0.0.1:3333 OSMOSIS tendermintrpc "wss://www.node-path.com:80,htt rewardsSnapshotThreshold, rewardsSnapshotTimeoutSec, &rpcProviderHealthCheckMetricsOptions, + staticProvider, } rpcProvider := RPCProvider{} @@ -722,6 +737,7 @@ rpcprovider 127.0.0.1:3333 OSMOSIS tendermintrpc "wss://www.node-path.com:80,htt // RPCProvider command flags flags.AddTxFlagsToCmd(cmdRPCProvider) cmdRPCProvider.MarkFlagRequired(flags.FlagFrom) + cmdRPCProvider.Flags().Bool(common.StaticProvidersConfigName, false, "set the provider as static, allowing it to get requests from anyone, and skipping rewards, can be used for local tests") cmdRPCProvider.Flags().Bool(common.SaveConfigFlagName, false, "save cmd args to a config file") cmdRPCProvider.Flags().Uint64(common.GeolocationFlag, 0, "geolocation to run from") cmdRPCProvider.MarkFlagRequired(common.GeolocationFlag) diff --git a/protocol/rpcprovider/rpcprovider_server.go b/protocol/rpcprovider/rpcprovider_server.go index 0190050906..272836f8c9 100644 --- a/protocol/rpcprovider/rpcprovider_server.go +++ b/protocol/rpcprovider/rpcprovider_server.go @@ -25,6 +25,7 @@ import 
( "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/performance" "github.com/lavanet/lava/v2/protocol/provideroptimizer" + rewardserver "github.com/lavanet/lava/v2/protocol/rpcprovider/rewardserver" "github.com/lavanet/lava/v2/protocol/upgrade" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/lavaslices" @@ -65,6 +66,7 @@ type RPCProviderServer struct { relaysMonitor *metrics.RelaysMonitor providerNodeSubscriptionManager *chainlib.ProviderNodeSubscriptionManager providerUniqueId string + StaticProvider bool } type ReliabilityManagerInf interface { @@ -105,12 +107,18 @@ func (rpcps *RPCProviderServer) ServeRPCRequests( providerMetrics *metrics.ProviderMetrics, relaysMonitor *metrics.RelaysMonitor, providerNodeSubscriptionManager *chainlib.ProviderNodeSubscriptionManager, + staticProvider bool, ) { rpcps.cache = cache rpcps.chainRouter = chainRouter rpcps.privKey = privKey rpcps.providerSessionManager = providerSessionManager rpcps.reliabilityManager = reliabilityManager + if rewardServer == nil { + utils.LavaFormatError("disabled rewards for provider, reward server not defined", nil) + rewardServer = &rewardserver.DisabledRewardServer{} + } + rpcps.StaticProvider = staticProvider rpcps.rewardServer = rewardServer rpcps.chainParser = chainParser rpcps.rpcProviderEndpoint = rpcProviderEndpoint @@ -220,6 +228,11 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes reply, err = rpcps.TryRelay(ctx, request, consumerAddress, chainMessage) } + // static provider doesnt handle sessions, so just return the response + if rpcps.StaticProvider { + return reply, rpcps.handleRelayErrorStatus(err) + } + if err != nil || common.ContextOutOfTime(ctx) { // failed to send relay. we need to adjust session state. cuSum and relayNumber. 
relayFailureError := rpcps.providerSessionManager.OnSessionFailure(relaySession, request.RelaySession.RelayNum) @@ -274,17 +287,18 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes } func (rpcps *RPCProviderServer) initRelay(ctx context.Context, request *pairingtypes.RelayRequest) (relaySession *lavasession.SingleProviderSession, consumerAddress sdk.AccAddress, chainMessage chainlib.ChainMessage, err error) { - relaySession, consumerAddress, err = rpcps.verifyRelaySession(ctx, request) - if err != nil { - return nil, nil, nil, err - } - defer func(relaySession *lavasession.SingleProviderSession) { - // if we error in here until PrepareSessionForUsage was called successfully we can't call OnSessionFailure + if !rpcps.StaticProvider { + relaySession, consumerAddress, err = rpcps.verifyRelaySession(ctx, request) if err != nil { - relaySession.DisbandSession() + return nil, nil, nil, err } - }(relaySession) // lock in the session address - + defer func(relaySession *lavasession.SingleProviderSession) { + // if we error in here until PrepareSessionForUsage was called successfully we can't call OnSessionFailure + if err != nil { + relaySession.DisbandSession() + } + }(relaySession) // lock in the session address + } extensionInfo := extensionslib.ExtensionInfo{LatestBlock: 0, ExtensionOverride: request.RelayData.Extensions} if extensionInfo.ExtensionOverride == nil { // in case consumer did not set an extension, we skip the extension parsing and we are sending it to the regular url extensionInfo.ExtensionOverride = []string{} @@ -294,6 +308,10 @@ func (rpcps *RPCProviderServer) initRelay(ctx context.Context, request *pairingt if err != nil { return nil, nil, nil, err } + // we only need the chainMessage for a static provider + if rpcps.StaticProvider { + return nil, nil, chainMessage, nil + } relayCU := chainMessage.GetApi().ComputeUnits virtualEpoch := rpcps.stateTracker.GetVirtualEpoch(uint64(request.RelaySession.Epoch)) err = relaySession.PrepareSessionForUsage(ctx, relayCU, request.RelaySession.CuSum, rpcps.allowedMissingCUThreshold, virtualEpoch) diff --git a/protocol/statetracker/consumer_state_tracker.go b/protocol/statetracker/consumer_state_tracker.go index 26dc8fd01f..d074602ec9 100644 --- a/protocol/statetracker/consumer_state_tracker.go +++ b/protocol/statetracker/consumer_state_tracker.go @@ -54,7 +54,7 @@ func NewConsumerStateTracker(ctx context.Context, txFactory tx.Factory, clientCt return cst, err } -func (cst *ConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) { +func (cst *ConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint) { // register this CSM to get the updated pairing list when a new epoch starts pairingUpdater := updaters.NewPairingUpdater(cst.stateQuery, consumerSessionManager.RPCEndpoint().ChainID) pairingUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, pairingUpdater) @@ -63,7 +63,7 @@ func (cst *ConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates utils.LavaFormatFatal("invalid updater type returned from RegisterForUpdates", nil, utils.Attribute{Key: "updater", Value: pairingUpdaterRaw}) } - err := pairingUpdater.RegisterPairing(ctx, consumerSessionManager) + err := pairingUpdater.RegisterPairing(ctx, consumerSessionManager, staticProvidersList) if err != nil { 
// if failed registering pairing, continue trying asynchronously
 		go func() {
 			numberOfAttempts := 0
 			for {
 				utils.LavaFormatError("Failed retry RegisterPairing", err, utils.LogAttr("attempt", numberOfAttempts), utils.Attribute{Key: "data", Value: consumerSessionManager.RPCEndpoint()})
 				time.Sleep(5 * time.Second) // sleep so we don't spam get pairing for no reason
-				err := pairingUpdater.RegisterPairing(ctx, consumerSessionManager)
+				err := pairingUpdater.RegisterPairing(ctx, consumerSessionManager, staticProvidersList)
 				if err == nil {
 					break
 				}
diff --git a/protocol/statetracker/updaters/pairing_updater.go b/protocol/statetracker/updaters/pairing_updater.go
index 1445e712df..ba8f65603e 100644
--- a/protocol/statetracker/updaters/pairing_updater.go
+++ b/protocol/statetracker/updaters/pairing_updater.go
@@ -1,9 +1,12 @@
 package updaters
 
 import (
+	"math"
+	"strconv"
 	"sync"
 	"time"
 
+	sdk "github.com/cosmos/cosmos-sdk/types"
 	"github.com/lavanet/lava/v2/protocol/lavasession"
 	"github.com/lavanet/lava/v2/utils"
 	epochstoragetypes "github.com/lavanet/lava/v2/x/epochstorage/types"
@@ -19,20 +22,43 @@ type PairingUpdatable interface {
 	UpdateEpoch(epoch uint64)
 }
 
+type ConsumerStateQueryInf interface {
+	GetPairing(ctx context.Context, chainID string, blockHeight int64) ([]epochstoragetypes.StakeEntry, uint64, uint64, error)
+	GetMaxCUForUser(ctx context.Context, chainID string, epoch uint64) (uint64, error)
+}
+
+type ConsumerSessionManagerInf interface {
+	RPCEndpoint() lavasession.RPCEndpoint
+	UpdateAllProviders(epoch uint64, pairingList map[uint64]*lavasession.ConsumerSessionsWithProvider) error
+}
+
 type PairingUpdater struct {
 	lock sync.RWMutex
-	consumerSessionManagersMap map[string][]*lavasession.ConsumerSessionManager // key is chainID so we don;t run getPairing more than once per chain
+	consumerSessionManagersMap map[string][]ConsumerSessionManagerInf // key is chainID so we don't run getPairing more than once per chain
 	nextBlockForUpdate uint64
-	stateQuery *ConsumerStateQuery
+	stateQuery ConsumerStateQueryInf
 	pairingUpdatables []*PairingUpdatable
 	specId string
+	staticProviders []*lavasession.RPCProviderEndpoint
+}
+
+func NewPairingUpdater(stateQuery ConsumerStateQueryInf, specId string) *PairingUpdater {
+	return &PairingUpdater{consumerSessionManagersMap: map[string][]ConsumerSessionManagerInf{}, stateQuery: stateQuery, specId: specId, staticProviders: []*lavasession.RPCProviderEndpoint{}}
 }
 
-func NewPairingUpdater(stateQuery *ConsumerStateQuery, specId string) *PairingUpdater {
-	return &PairingUpdater{consumerSessionManagersMap: map[string][]*lavasession.ConsumerSessionManager{}, stateQuery: stateQuery, specId: specId}
+func (pu *PairingUpdater) updateStaticProviders(staticProviders []*lavasession.RPCProviderEndpoint) {
+	pu.lock.Lock()
+	defer pu.lock.Unlock()
+	if len(staticProviders) > 0 && len(pu.staticProviders) == 0 {
+		for _, staticProvider := range staticProviders {
+			if staticProvider.ChainID == pu.specId {
+				pu.staticProviders = append(pu.staticProviders, staticProvider)
+			}
+		}
+	}
 }
 
-func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) error {
+func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionManager ConsumerSessionManagerInf, staticProviders []*lavasession.RPCProviderEndpoint) error {
 	chainID := consumerSessionManager.RPCEndpoint().ChainID
 	timeoutCtx, cancel :=
context.WithTimeout(context.Background(), time.Second*10) defer cancel() @@ -40,6 +66,7 @@ func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionMa if err != nil { return err } + pu.updateStaticProviders(staticProviders) pu.updateConsumerSessionManager(ctx, pairingList, consumerSessionManager, epoch) if nextBlockForUpdate > pu.nextBlockForUpdate { // make sure we don't update twice, this updates pu.nextBlockForUpdate @@ -49,7 +76,7 @@ func (pu *PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionMa defer pu.lock.Unlock() consumerSessionsManagersList, ok := pu.consumerSessionManagersMap[chainID] if !ok { - pu.consumerSessionManagersMap[chainID] = []*lavasession.ConsumerSessionManager{consumerSessionManager} + pu.consumerSessionManagersMap[chainID] = []ConsumerSessionManagerInf{consumerSessionManager} return nil } pu.consumerSessionManagersMap[chainID] = append(consumerSessionsManagersList, consumerSessionManager) @@ -134,39 +161,69 @@ func (pu *PairingUpdater) Update(latestBlock int64) { pu.updateInner(latestBlock) } -func (pu *PairingUpdater) updateConsumerSessionManager(ctx context.Context, pairingList []epochstoragetypes.StakeEntry, consumerSessionManager *lavasession.ConsumerSessionManager, epoch uint64) (err error) { +func (pu *PairingUpdater) updateConsumerSessionManager(ctx context.Context, pairingList []epochstoragetypes.StakeEntry, consumerSessionManager ConsumerSessionManagerInf, epoch uint64) (err error) { pairingListForThisCSM, err := pu.filterPairingListByEndpoint(ctx, planstypes.Geolocation(consumerSessionManager.RPCEndpoint().Geolocation), pairingList, consumerSessionManager.RPCEndpoint(), epoch) if err != nil { return err } + if len(pu.staticProviders) > 0 { + pairingListForThisCSM = pu.addStaticProvidersToPairingList(pairingListForThisCSM, consumerSessionManager.RPCEndpoint(), epoch) + } err = consumerSessionManager.UpdateAllProviders(epoch, pairingListForThisCSM) return } +func (pu *PairingUpdater) addStaticProvidersToPairingList(pairingList map[uint64]*lavasession.ConsumerSessionsWithProvider, rpcEndpoint lavasession.RPCEndpoint, epoch uint64) map[uint64]*lavasession.ConsumerSessionsWithProvider { + startIdx := uint64(0) + for key := range pairingList { + if key >= startIdx { + startIdx = key + 1 + } + } + for idx, provider := range pu.staticProviders { + // only take the provider entries relevant for this apiInterface + if provider.ApiInterface != rpcEndpoint.ApiInterface { + continue + } + endpoints := []*lavasession.Endpoint{} + for _, url := range provider.NodeUrls { + extensions := map[string]struct{}{} + for _, extension := range url.Addons { + extensions[extension] = struct{}{} + } + endpoint := &lavasession.Endpoint{ + NetworkAddress: url.Url, + Enabled: true, + Addons: map[string]struct{}{}, // TODO: does not support addons, if required need to add the functionality to differentiate the two + Extensions: extensions, + Connections: []*lavasession.EndpointConnection{}, + } + endpoints = append(endpoints, endpoint) + } + staticProviderEntry := lavasession.NewConsumerSessionWithProvider( + "StaticProvider_"+strconv.Itoa(idx), + endpoints, + math.MaxUint64/2, + epoch, + sdk.NewInt64Coin("ulava", 1000000000000000), // 1b LAVA + ) + staticProviderEntry.StaticProvider = true + pairingList[startIdx+uint64(idx)] = staticProviderEntry + } + return pairingList +} + func (pu *PairingUpdater) filterPairingListByEndpoint(ctx context.Context, currentGeo planstypes.Geolocation, pairingList []epochstoragetypes.StakeEntry, rpcEndpoint 
lavasession.RPCEndpoint, epoch uint64) (filteredList map[uint64]*lavasession.ConsumerSessionsWithProvider, err error) {
 	// go over stake entries, and filter endpoints that match geolocation and api interface
 	pairing := map[uint64]*lavasession.ConsumerSessionsWithProvider{}
 	for providerIdx, provider := range pairingList {
 		//
 		// Sanity
-		providerEndpoints := provider.GetEndpoints()
-		if len(providerEndpoints) == 0 {
-			utils.LavaFormatError("skipping provider with no endoints", nil, utils.Attribute{Key: "Address", Value: provider.Address}, utils.Attribute{Key: "ChainID", Value: provider.Chain})
-			continue
-		}
-
-		relevantEndpoints := []epochstoragetypes.Endpoint{}
-		for _, endpoint := range providerEndpoints {
-			// only take into account endpoints that use the same api interface and the same geolocation
-			for _, endpointApiInterface := range endpoint.ApiInterfaces {
-				if endpointApiInterface == rpcEndpoint.ApiInterface { // we take all geolocations provided by the chain. the provider optimizer will prioritize the relevant ones
-					relevantEndpoints = append(relevantEndpoints, endpoint)
-					break
-				}
-			}
-		}
+		// only take into account endpoints that use the same api interface and the same geolocation
+		// we take all geolocations provided by the chain. the provider optimizer will prioritize the relevant ones
+		relevantEndpoints := getRelevantEndpointsFromProvider(provider, rpcEndpoint)
 		if len(relevantEndpoints) == 0 {
-			utils.LavaFormatError("skipping provider, No relevant endpoints for apiInterface", nil, utils.Attribute{Key: "Address", Value: provider.Address}, utils.Attribute{Key: "ChainID", Value: provider.Chain}, utils.Attribute{Key: "apiInterface", Value: rpcEndpoint.ApiInterface}, utils.Attribute{Key: "Endpoints", Value: providerEndpoints})
+			utils.LavaFormatError("skipping provider, No relevant endpoints for apiInterface", nil, utils.Attribute{Key: "Address", Value: provider.Address}, utils.Attribute{Key: "ChainID", Value: provider.Chain}, utils.Attribute{Key: "apiInterface", Value: rpcEndpoint.ApiInterface}, utils.Attribute{Key: "Endpoints", Value: provider.GetEndpoints()})
 			continue
 		}
 
@@ -204,3 +261,22 @@ func (pu *PairingUpdater) filterPairingListByEndpoint(ctx context.Context, curre
 	// replace previous pairing with new providers
 	return pairing, nil
 }
+
+func getRelevantEndpointsFromProvider(provider epochstoragetypes.StakeEntry, rpcEndpoint lavasession.RPCEndpoint) []epochstoragetypes.Endpoint {
+	providerEndpoints := provider.GetEndpoints()
+	if len(providerEndpoints) == 0 {
+		utils.LavaFormatError("skipping provider with no endpoints", nil, utils.Attribute{Key: "Address", Value: provider.Address}, utils.Attribute{Key: "ChainID", Value: provider.Chain})
+		return nil
+	}
+
+	relevantEndpoints := []epochstoragetypes.Endpoint{}
+	for _, endpoint := range providerEndpoints {
+		for _, endpointApiInterface := range endpoint.ApiInterfaces {
+			if endpointApiInterface == rpcEndpoint.ApiInterface {
+				relevantEndpoints = append(relevantEndpoints, endpoint)
+				break
+			}
+		}
+	}
+	return relevantEndpoints
+}
diff --git a/protocol/statetracker/updaters/pairing_updater_test.go b/protocol/statetracker/updaters/pairing_updater_test.go
new file mode 100644
index 0000000000..fb4f574512
--- /dev/null
+++ b/protocol/statetracker/updaters/pairing_updater_test.go
@@ -0,0 +1,204 @@
+package updaters
+
+import (
+	"context"
+	"testing"
+
+	"github.com/golang/mock/gomock"
+	"github.com/lavanet/lava/v2/protocol/common"
+	"github.com/lavanet/lava/v2/protocol/lavasession"
+	"github.com/lavanet/lava/v2/utils/rand"
+
epochstoragetypes "github.com/lavanet/lava/v2/x/epochstorage/types" + "github.com/stretchr/testify/require" +) + +type matcher struct { + expected map[uint64]*lavasession.ConsumerSessionsWithProvider +} + +func (m matcher) Matches(arg interface{}) bool { + actual, ok := arg.(map[uint64]*lavasession.ConsumerSessionsWithProvider) + if !ok { + return false + } + if len(actual) != len(m.expected) { + return false + } + for k, v := range m.expected { + if actual[k].StaticProvider != v.StaticProvider { + return false + } + } + return true +} + +func (m matcher) String() string { + return "" +} + +func TestPairingUpdater(t *testing.T) { + rand.InitRandomSeed() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + specID := "test-spec" + apiInterface := "test-inf" + initialPairingList := []epochstoragetypes.StakeEntry{ + { + Address: "initial2", + Endpoints: []epochstoragetypes.Endpoint{ + { + IPPORT: "1234567", + Geolocation: 1, + Addons: []string{}, + ApiInterfaces: []string{"banana"}, + Extensions: []string{}, + }, + }, + Geolocation: 1, + Chain: specID, + }, + { + Address: "initial0", + Endpoints: []epochstoragetypes.Endpoint{ + { + IPPORT: "123", + Geolocation: 0, + Addons: []string{}, + ApiInterfaces: []string{apiInterface}, + Extensions: []string{}, + }, + }, + Geolocation: 0, + Chain: specID, + }, + { + Address: "initial1", + Endpoints: []epochstoragetypes.Endpoint{ + { + IPPORT: "1234", + Geolocation: 0, + Addons: []string{}, + ApiInterfaces: []string{apiInterface}, + Extensions: []string{}, + }, + }, + Geolocation: 0, + Chain: specID, + }, + } + // Create a new mock object + stateQuery := NewMockConsumerStateQueryInf(ctrl) + stateQuery.EXPECT().GetPairing(gomock.Any(), gomock.Any(), gomock.Any()).Return(initialPairingList, uint64(0), uint64(0), nil).AnyTimes() + stateQuery.EXPECT().GetMaxCUForUser(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint64(999999999), nil).AnyTimes() + + t.Run("UpdateStaticProviders", func(t *testing.T) { + pu := NewPairingUpdater(stateQuery, specID) + staticProviders := []*lavasession.RPCProviderEndpoint{ + { + ChainID: specID, + ApiInterface: apiInterface, + Geolocation: 0, + NodeUrls: []common.NodeUrl{ + { + Url: "0123", + }, + }, + }, + { + ChainID: "banana", + ApiInterface: apiInterface, + Geolocation: 0, + NodeUrls: []common.NodeUrl{ + { + Url: "01234", + }, + }, + }, + { + ChainID: "specID", + ApiInterface: "wrong", + Geolocation: 0, + NodeUrls: []common.NodeUrl{ + { + Url: "01235", + }, + }, + }, + } + + pu.updateStaticProviders(staticProviders) + + // only one of the specs is relevant + require.Len(t, pu.staticProviders, 1) + + staticProviders = []*lavasession.RPCProviderEndpoint{ + { + ChainID: specID, + ApiInterface: apiInterface, + Geolocation: 0, + NodeUrls: []common.NodeUrl{ + { + Url: "01236", + }, + }, + }, + { + ChainID: "banana", + ApiInterface: apiInterface, + Geolocation: 0, + NodeUrls: []common.NodeUrl{ + { + Url: "01237", + }, + }, + }, + } + + pu.updateStaticProviders(staticProviders) + + // can only update them once + require.Len(t, pu.staticProviders, 1) + }) + + t.Run("RegisterPairing", func(t *testing.T) { + pu := NewPairingUpdater(stateQuery, specID) + consumerSessionManager := NewMockConsumerSessionManagerInf(ctrl) + consumerSessionManager.EXPECT().RPCEndpoint().Return(lavasession.RPCEndpoint{ + ChainID: specID, + ApiInterface: apiInterface, + Geolocation: 0, + }).AnyTimes() + + staticProviders := []*lavasession.RPCProviderEndpoint{ + { + ChainID: specID, + ApiInterface: apiInterface, + Geolocation: 0, + NodeUrls: 
[]common.NodeUrl{ + { + Url: "00123", + }, + }, + }, + } + pairingMatcher := matcher{ + expected: map[uint64]*lavasession.ConsumerSessionsWithProvider{ + 1: {}, + 2: {}, + 3: {StaticProvider: true}, + }, + } + consumerSessionManager.EXPECT().UpdateAllProviders(gomock.Any(), pairingMatcher).Times(1).Return(nil) + err := pu.RegisterPairing(context.Background(), consumerSessionManager, staticProviders) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(pu.consumerSessionManagersMap) != 1 { + t.Errorf("Expected 1 consumer session manager, got %d", len(pu.consumerSessionManagersMap)) + } + + consumerSessionManager.EXPECT().UpdateAllProviders(gomock.Any(), pairingMatcher).Times(1).Return(nil) + pu.Update(20) + }) +} diff --git a/protocol/statetracker/updaters/updaters_mock.go b/protocol/statetracker/updaters/updaters_mock.go new file mode 100644 index 0000000000..002299d90b --- /dev/null +++ b/protocol/statetracker/updaters/updaters_mock.go @@ -0,0 +1,155 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: protocol/statetracker/updaters/pairing_updater.go + +// Package updaters is a generated GoMock package. +package updaters + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + lavasession "github.com/lavanet/lava/v2/protocol/lavasession" + types "github.com/lavanet/lava/v2/x/epochstorage/types" + context "golang.org/x/net/context" +) + +// MockPairingUpdatable is a mock of PairingUpdatable interface. +type MockPairingUpdatable struct { + ctrl *gomock.Controller + recorder *MockPairingUpdatableMockRecorder +} + +// MockPairingUpdatableMockRecorder is the mock recorder for MockPairingUpdatable. +type MockPairingUpdatableMockRecorder struct { + mock *MockPairingUpdatable +} + +// NewMockPairingUpdatable creates a new mock instance. +func NewMockPairingUpdatable(ctrl *gomock.Controller) *MockPairingUpdatable { + mock := &MockPairingUpdatable{ctrl: ctrl} + mock.recorder = &MockPairingUpdatableMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPairingUpdatable) EXPECT() *MockPairingUpdatableMockRecorder { + return m.recorder +} + +// UpdateEpoch mocks base method. +func (m *MockPairingUpdatable) UpdateEpoch(epoch uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateEpoch", epoch) +} + +// UpdateEpoch indicates an expected call of UpdateEpoch. +func (mr *MockPairingUpdatableMockRecorder) UpdateEpoch(epoch interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEpoch", reflect.TypeOf((*MockPairingUpdatable)(nil).UpdateEpoch), epoch) +} + +// MockConsumerStateQueryInf is a mock of ConsumerStateQueryInf interface. +type MockConsumerStateQueryInf struct { + ctrl *gomock.Controller + recorder *MockConsumerStateQueryInfMockRecorder +} + +// MockConsumerStateQueryInfMockRecorder is the mock recorder for MockConsumerStateQueryInf. +type MockConsumerStateQueryInfMockRecorder struct { + mock *MockConsumerStateQueryInf +} + +// NewMockConsumerStateQueryInf creates a new mock instance. +func NewMockConsumerStateQueryInf(ctrl *gomock.Controller) *MockConsumerStateQueryInf { + mock := &MockConsumerStateQueryInf{ctrl: ctrl} + mock.recorder = &MockConsumerStateQueryInfMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockConsumerStateQueryInf) EXPECT() *MockConsumerStateQueryInfMockRecorder { + return m.recorder +} + +// GetMaxCUForUser mocks base method. +func (m *MockConsumerStateQueryInf) GetMaxCUForUser(ctx context.Context, chainID string, epoch uint64) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMaxCUForUser", ctx, chainID, epoch) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMaxCUForUser indicates an expected call of GetMaxCUForUser. +func (mr *MockConsumerStateQueryInfMockRecorder) GetMaxCUForUser(ctx, chainID, epoch interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxCUForUser", reflect.TypeOf((*MockConsumerStateQueryInf)(nil).GetMaxCUForUser), ctx, chainID, epoch) +} + +// GetPairing mocks base method. +func (m *MockConsumerStateQueryInf) GetPairing(ctx context.Context, chainID string, blockHeight int64) ([]types.StakeEntry, uint64, uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPairing", ctx, chainID, blockHeight) + ret0, _ := ret[0].([]types.StakeEntry) + ret1, _ := ret[1].(uint64) + ret2, _ := ret[2].(uint64) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// GetPairing indicates an expected call of GetPairing. +func (mr *MockConsumerStateQueryInfMockRecorder) GetPairing(ctx, chainID, blockHeight interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPairing", reflect.TypeOf((*MockConsumerStateQueryInf)(nil).GetPairing), ctx, chainID, blockHeight) +} + +// MockConsumerSessionManagerInf is a mock of ConsumerSessionManagerInf interface. +type MockConsumerSessionManagerInf struct { + ctrl *gomock.Controller + recorder *MockConsumerSessionManagerInfMockRecorder +} + +// MockConsumerSessionManagerInfMockRecorder is the mock recorder for MockConsumerSessionManagerInf. +type MockConsumerSessionManagerInfMockRecorder struct { + mock *MockConsumerSessionManagerInf +} + +// NewMockConsumerSessionManagerInf creates a new mock instance. +func NewMockConsumerSessionManagerInf(ctrl *gomock.Controller) *MockConsumerSessionManagerInf { + mock := &MockConsumerSessionManagerInf{ctrl: ctrl} + mock.recorder = &MockConsumerSessionManagerInfMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConsumerSessionManagerInf) EXPECT() *MockConsumerSessionManagerInfMockRecorder { + return m.recorder +} + +// RPCEndpoint mocks base method. +func (m *MockConsumerSessionManagerInf) RPCEndpoint() lavasession.RPCEndpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RPCEndpoint") + ret0, _ := ret[0].(lavasession.RPCEndpoint) + return ret0 +} + +// RPCEndpoint indicates an expected call of RPCEndpoint. +func (mr *MockConsumerSessionManagerInfMockRecorder) RPCEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPCEndpoint", reflect.TypeOf((*MockConsumerSessionManagerInf)(nil).RPCEndpoint)) +} + +// UpdateAllProviders mocks base method. +func (m *MockConsumerSessionManagerInf) UpdateAllProviders(epoch uint64, pairingList map[uint64]*lavasession.ConsumerSessionsWithProvider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAllProviders", epoch, pairingList) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAllProviders indicates an expected call of UpdateAllProviders. 
+func (mr *MockConsumerSessionManagerInfMockRecorder) UpdateAllProviders(epoch, pairingList interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAllProviders", reflect.TypeOf((*MockConsumerSessionManagerInf)(nil).UpdateAllProviders), epoch, pairingList) +} diff --git a/scripts/pre_setups/init_lava_static_provider.sh b/scripts/pre_setups/init_lava_static_provider.sh new file mode 100755 index 0000000000..c3e1b57ab3 --- /dev/null +++ b/scripts/pre_setups/init_lava_static_provider.sh @@ -0,0 +1,57 @@ +#!/bin/bash +__dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source "$__dir"/../useful_commands.sh +. "${__dir}"/../vars/variables.sh + +LOGS_DIR=${__dir}/../../testutil/debugging/logs +mkdir -p $LOGS_DIR +rm $LOGS_DIR/*.log + +killall screen +screen -wipe + +echo "[Test Setup] installing all binaries" +make install-all + +echo "[Test Setup] setting up a new lava node" +screen -d -m -S node bash -c "./scripts/start_env_dev.sh" +screen -ls +echo "[Lavavisor Setup] sleeping 20 seconds for node to finish setup (if its not enough increase timeout)" +sleep 20 + +GASPRICE="0.00002ulava" +lavad tx gov submit-legacy-proposal spec-add ./cookbook/specs/ibc.json,./cookbook/specs/cosmoswasm.json,./cookbook/specs/tendermint.json,./cookbook/specs/cosmossdk.json,./cookbook/specs/cosmossdk_45.json,./cookbook/specs/cosmossdk_full.json,./cookbook/specs/ethermint.json,./cookbook/specs/ethereum.json,./cookbook/specs/cosmoshub.json,./cookbook/specs/lava.json,./cookbook/specs/osmosis.json,./cookbook/specs/fantom.json,./cookbook/specs/celo.json,./cookbook/specs/optimism.json,./cookbook/specs/arbitrum.json,./cookbook/specs/starknet.json,./cookbook/specs/aptos.json,./cookbook/specs/juno.json,./cookbook/specs/polygon.json,./cookbook/specs/evmos.json,./cookbook/specs/base.json,./cookbook/specs/canto.json,./cookbook/specs/sui.json,./cookbook/specs/solana.json,./cookbook/specs/bsc.json,./cookbook/specs/axelar.json,./cookbook/specs/avalanche.json,./cookbook/specs/fvm.json --lava-dev-test -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE & +wait_next_block +wait_next_block +lavad tx gov vote 1 yes -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +sleep 4 + +# Plans proposal +lavad tx gov submit-legacy-proposal plans-add ./cookbook/plans/test_plans/default.json,./cookbook/plans/test_plans/temporary-add.json -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +wait_next_block +wait_next_block +lavad tx gov vote 2 yes -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE + +sleep 4 + +CLIENTSTAKE="500000000000ulava" +PROVIDERSTAKE="500000000000ulava" + +PROVIDER1_LISTENER="127.0.0.1:2221" +# static configuration +PROVIDER4_LISTENER="127.0.0.1:2220" + +lavad tx subscription buy DefaultPlan $(lavad keys show user1 -a) -y --from user1 --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +wait_next_block +lavad tx pairing stake-provider "LAV1" $PROVIDERSTAKE "$PROVIDER1_LISTENER,1" --delegate-limit 0ulava 1 $(operator_address) -y --from servicer1 --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE + +sleep_until_next_epoch + +screen -d -m -S provider4 bash -c "source ~/.bashrc; lavap rpcprovider provider_examples/lava_example.yml\ +$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level debug --from servicer4 --static-providers --chain-id lava 2>&1 | tee $LOGS_DIR/PROVIDER4.log" && sleep 0.25 + +screen -d -m 
-S consumers bash -c "source ~/.bashrc; lavap rpcconsumer consumer_examples/lava_consumer_static_peers.yml \ +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level debug --from user1 --chain-id lava --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 + +echo "--- setting up screens done ---" +screen -ls \ No newline at end of file From e47fe18961c541ec1d0598f2ac2d67025fdfd0b5 Mon Sep 17 00:00:00 2001 From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:54:22 +0200 Subject: [PATCH 02/12] feat: PRT - cache block hash storage (#1637) * feat: PRT - cache block hash storage * remove multiplier --- ecosystem/cache/cache_test.go | 302 ++++++++++++++++-- ecosystem/cache/command.go | 1 + ecosystem/cache/handlers.go | 82 ++++- ecosystem/cache/server.go | 51 ++- .../consumer_ws_subscription_manager_test.go | 2 +- 5 files changed, 397 insertions(+), 41 deletions(-) diff --git a/ecosystem/cache/cache_test.go b/ecosystem/cache/cache_test.go index 3f6e763a4e..1f2cfbbbbc 100644 --- a/ecosystem/cache/cache_test.go +++ b/ecosystem/cache/cache_test.go @@ -31,7 +31,16 @@ const ( func initTest() (context.Context, *cache.RelayerCacheServer) { ctx := context.Background() cs := cache.CacheServer{CacheMaxCost: 2 * 1024 * 1024 * 1024} - cs.InitCache(ctx, cache.DefaultExpirationTimeFinalized, cache.DefaultExpirationForNonFinalized, cache.DisabledFlagOption, cache.DefaultExpirationTimeFinalizedMultiplier, cache.DefaultExpirationTimeNonFinalizedMultiplier) + cs.InitCache( + ctx, + cache.DefaultExpirationTimeFinalized, + cache.DefaultExpirationForNonFinalized, + cache.DefaultExpirationNodeErrors, + cache.DefaultExpirationBlocksHashesToHeights, + cache.DisabledFlagOption, + cache.DefaultExpirationTimeFinalizedMultiplier, + cache.DefaultExpirationTimeNonFinalizedMultiplier, + ) cacheServer := &cache.RelayerCacheServer{CacheServer: &cs} return ctx, cacheServer } @@ -85,11 +94,12 @@ func TestCacheSetGet(t *testing.T) { Finalized: tt.finalized, RequestedBlock: request.RequestBlock, } - _, err = cacheServer.GetRelay(ctx, &messageGet) + reply, err := cacheServer.GetRelay(ctx, &messageGet) + require.NoError(t, err) if tt.valid { - require.NoError(t, err) + require.NotNil(t, reply.Reply) } else { - require.Error(t, err) + require.Nil(t, reply.Reply) } }) } @@ -169,9 +179,9 @@ func TestCacheGetWithoutSet(t *testing.T) { Finalized: tt.finalized, RequestedBlock: request.RequestBlock, } - _, err := cacheServer.GetRelay(ctx, &messageGet) - - require.Error(t, err) + reply, err := cacheServer.GetRelay(ctx, &messageGet) + require.Nil(t, reply.Reply) + require.NoError(t, err) }) } } @@ -333,7 +343,7 @@ func TestCacheSetGetLatest(t *testing.T) { require.Equal(t, cacheReply.GetReply().LatestBlock, latestBlockForRelay) } } else { - require.Error(t, err) + require.Nil(t, cacheReply.Reply) } }) } @@ -410,7 +420,7 @@ func TestCacheSetGetLatestWhenAdvancingLatest(t *testing.T) { require.Equal(t, cacheReply.GetReply().LatestBlock, latestBlockForRelay) } } else { - require.Error(t, err) + require.Nil(t, cacheReply.Reply) } request2 := shallowCopy(request) @@ -435,8 +445,9 @@ func TestCacheSetGetLatestWhenAdvancingLatest(t *testing.T) { RequestedBlock: request.RequestBlock, } // repeat our latest block get, this time we expect it to look for a newer block and fail - _, err = cacheServer.GetRelay(ctx, &messageGet) - require.Error(t, err) + reply, err := cacheServer.GetRelay(ctx, &messageGet) + require.NoError(t, err) + require.Nil(t, reply.Reply) }) } 
} @@ -462,7 +473,7 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { {name: "NonFinalized With Hash", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}}, {name: "NonFinalized After delay With Hash", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}}, - // Null ID in get and set + // // Null ID in get and set {name: "Finalized No Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInGet: true, nullIdInSet: true}, {name: "Finalized After delay No Hash, with null id in get and set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInGet: true, nullIdInSet: true}, {name: "NonFinalized No Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInGet: true, nullIdInSet: true}, @@ -472,7 +483,7 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { {name: "NonFinalized With Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, {name: "NonFinalized After delay With Hash, with null id in get and set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, - // Null ID only in get + // // Null ID only in get {name: "Finalized No Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInGet: true}, {name: "Finalized After delay No Hash, with null id only in get", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInGet: true}, {name: "NonFinalized No Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInGet: true}, @@ -482,7 +493,7 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { {name: "NonFinalized With Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true}, {name: "NonFinalized After delay With Hash, with null id only in get", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true}, - // Null ID only in set + // // Null ID only in set {name: "Finalized No Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInSet: true}, {name: "Finalized After delay No Hash, with null id only in set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInSet: true}, {name: "NonFinalized No Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInSet: true}, @@ -547,20 +558,21 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { } cacheReply, err := cacheServer.GetRelay(ctx, &messageGet) + // because we always need a cache reply. we cant return an error in any case. 
+ // grpc do not allow returning errors + messages + require.NoError(t, err) + if tt.valid { cacheReply.Reply.Data = outputFormatter(cacheReply.Reply.Data) - require.NoError(t, err) - result := gjson.GetBytes(cacheReply.GetReply().Data, format.IDFieldName) extractedID := result.Raw - if tt.nullIdInGet { require.Equal(t, "null", extractedID) } else { require.Equal(t, strconv.FormatInt(changedID, 10), extractedID) } } else { - require.Error(t, err) + require.Nil(t, cacheReply.Reply) } }) } @@ -583,7 +595,16 @@ func TestCacheExpirationMultiplier(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cs := cache.CacheServer{CacheMaxCost: 2 * 1024 * 1024 * 1024} - cs.InitCache(context.Background(), cache.DefaultExpirationTimeFinalized, cache.DefaultExpirationForNonFinalized, cache.DisabledFlagOption, 1, tt.multiplier) + cs.InitCache( + context.Background(), + cache.DefaultExpirationTimeFinalized, + cache.DefaultExpirationForNonFinalized, + cache.DefaultExpirationNodeErrors, + cache.DefaultExpirationBlocksHashesToHeights, + cache.DisabledFlagOption, + cache.DefaultExpirationTimeFinalizedMultiplier, + tt.multiplier, + ) cacheServer := &cache.RelayerCacheServer{CacheServer: &cs} durationActual := cacheServer.CacheServer.ExpirationForChain(cache.DefaultExpirationForNonFinalized) @@ -591,3 +612,246 @@ func TestCacheExpirationMultiplier(t *testing.T) { }) } } + +func TestCacheSetGetBlocksHashesToHeightsHappyFlow(t *testing.T) { + t.Parallel() + const ( + SET_INPUT int = iota + GET_INPUT + EXPECTED_FROM_GET + ) + + type step struct { + blockHashesToHeights []*pairingtypes.BlockHashToHeight + inputOrExpected int + } + + steps := []step{ + { + inputOrExpected: SET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{}, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + {Hash: "H2"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: spectypes.NOT_APPLICABLE, + }, + { + Hash: "H2", + Height: spectypes.NOT_APPLICABLE, + }, + }, + }, + { + inputOrExpected: SET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 1, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + {Hash: "H2"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 1, + }, + { + Hash: "H2", + Height: spectypes.NOT_APPLICABLE, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 1, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H2"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H2", + Height: spectypes.NOT_APPLICABLE, + }, + }, + }, + { + inputOrExpected: SET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H3", + Height: 3, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + {Hash: "H2"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { 
+ Hash: "H1", + Height: 1, + }, + { + Hash: "H2", + Height: spectypes.NOT_APPLICABLE, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + {Hash: "H2"}, + {Hash: "H3"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 1, + }, + { + Hash: "H2", + Height: spectypes.NOT_APPLICABLE, + }, + { + Hash: "H3", + Height: 3, + }, + }, + }, + { + inputOrExpected: SET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 4, + }, + { + Hash: "H2", + Height: 2, + }, + { + Hash: "H5", + Height: 7, + }, + }, + }, + { + inputOrExpected: GET_INPUT, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + {Hash: "H1"}, + {Hash: "H2"}, + {Hash: "H3"}, + {Hash: "H5"}, + }, + }, + { + inputOrExpected: EXPECTED_FROM_GET, + blockHashesToHeights: []*pairingtypes.BlockHashToHeight{ + { + Hash: "H1", + Height: 4, + }, + { + Hash: "H2", + Height: 2, + }, + { + Hash: "H3", + Height: 3, + }, + { + Hash: "H5", + Height: 7, + }, + }, + }, + } + + t.Run("run cache steps", func(t *testing.T) { + ctx, cacheServer := initTest() + request := getRequest(1230, []byte(StubSig), StubApiInterface) + + var lastCacheResult []*pairingtypes.BlockHashToHeight + for stepNum, step := range steps { + switch step.inputOrExpected { + case SET_INPUT: + messageSet := pairingtypes.RelayCacheSet{ + RequestHash: HashRequest(t, request, StubChainID), + BlockHash: []byte("123456789"), + ChainId: StubChainID, + Response: &pairingtypes.RelayReply{}, + Finalized: true, + RequestedBlock: request.RequestBlock, + BlocksHashesToHeights: step.blockHashesToHeights, + } + + _, err := cacheServer.SetRelay(ctx, &messageSet) + require.NoError(t, err, "step: %d", stepNum) + + // sleep to make sure it's in the cache + time.Sleep(3 * time.Millisecond) + case GET_INPUT: + messageGet := pairingtypes.RelayCacheGet{ + RequestHash: HashRequest(t, request, StubChainID), + BlockHash: []byte("123456789"), + ChainId: StubChainID, + Finalized: true, + RequestedBlock: request.RequestBlock, + BlocksHashesToHeights: step.blockHashesToHeights, + } + + cacheResult, err := cacheServer.GetRelay(ctx, &messageGet) + require.NoError(t, err, "step: %d", stepNum) + lastCacheResult = cacheResult.BlocksHashesToHeights + case EXPECTED_FROM_GET: + require.Equal(t, step.blockHashesToHeights, lastCacheResult, "step: %d", stepNum) + } + } + }) +} diff --git a/ecosystem/cache/command.go b/ecosystem/cache/command.go index 6a5c9972dd..c138a59134 100644 --- a/ecosystem/cache/command.go +++ b/ecosystem/cache/command.go @@ -44,6 +44,7 @@ longer DefaultExpirationForNonFinalized will reduce sync QoS for "latest" reques cacheCmd.Flags().Duration(ExpirationNonFinalizedFlagName, DefaultExpirationForNonFinalized, "how long does a cache entry lasts in the cache for a non finalized entry") cacheCmd.Flags().Float64(ExpirationTimeFinalizedMultiplierFlagName, DefaultExpirationTimeFinalizedMultiplier, "Multiplier for finalized cache entry expiration. 1 means no change (default), 1.2 means 20% longer.") cacheCmd.Flags().Float64(ExpirationTimeNonFinalizedMultiplierFlagName, DefaultExpirationTimeNonFinalizedMultiplier, "Multiplier for non-finalized cache entry expiration. 
1 means no change (default), 1.2 means 20% longer.")
+	cacheCmd.Flags().Duration(ExpirationBlocksHashesToHeightsFlagName, DefaultExpirationBlocksHashesToHeights, "how long a block hash to height entry lasts in the cache")
 	cacheCmd.Flags().Duration(ExpirationNodeErrorsOnFinalizedFlagName, DefaultExpirationNodeErrors, "how long does a cache entry lasts in the cache for a finalized node error entry")
 	cacheCmd.Flags().String(FlagMetricsAddress, DisabledFlagOption, "address to listen to prometheus metrics 127.0.0.1:5555, later you can curl http://127.0.0.1:5555/metrics")
 	cacheCmd.Flags().Int64(FlagCacheSizeName, 2*1024*1024*1024, "the maximal amount of entries to save")
diff --git a/ecosystem/cache/handlers.go b/ecosystem/cache/handlers.go
index 9d4c3a3f73..49cbd6adba 100644
--- a/ecosystem/cache/handlers.go
+++ b/ecosystem/cache/handlers.go
@@ -85,12 +85,34 @@ func (s *RelayerCacheServer) getSeenBlockForSharedStateMode(chainId string, shar
 	return 0
 }
 
+func (s *RelayerCacheServer) getBlockHeightsFromHashes(chainId string, hashes []*pairingtypes.BlockHashToHeight) []*pairingtypes.BlockHashToHeight {
+	for _, hashToHeight := range hashes {
+		formattedKey := s.formatChainIdWithHashKey(chainId, hashToHeight.Hash)
+		value, found := getNonExpiredFromCache(s.CacheServer.blocksHashesToHeightsCache, formattedKey)
+		if found {
+			if cacheValue, ok := value.(int64); ok {
+				hashToHeight.Height = cacheValue
+			}
+		} else {
+			hashToHeight.Height = spectypes.NOT_APPLICABLE
+		}
+	}
+	return hashes
+}
+
 func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairingtypes.RelayCacheGet) (*pairingtypes.CacheRelayReply, error) {
 	cacheReply := &pairingtypes.CacheRelayReply{}
 	var cacheReplyTmp *pairingtypes.CacheRelayReply
 	var err error
 	var seenBlock int64
 
+	// validating that if we had an error, we do not return a reply.
+	defer func() {
+		if err != nil {
+			cacheReply.Reply = nil
+		}
+	}()
+
 	originalRequestedBlock := relayCacheGet.RequestedBlock // save requested block prior to swap
 	if originalRequestedBlock < 0 { // we need to fetch stored latest block information.
 		getLatestBlock := s.getLatestBlock(latestBlockKey(relayCacheGet.ChainId, ""))
@@ -104,11 +126,15 @@ func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairin
 		utils.Attribute{Key: "requested_block_parsed", Value: relayCacheGet.RequestedBlock},
 		utils.Attribute{Key: "seen_block", Value: relayCacheGet.SeenBlock},
 	)
+
+	var blockHashes []*pairingtypes.BlockHashToHeight
 	if relayCacheGet.RequestedBlock >= 0 { // we can only fetch
-		// check seen block is larger than our requested block, we don't need to fetch seen block prior as its already larger than requested block
+		// we don't need to fetch seen block prior, as it's already larger than the requested block
 		waitGroup := sync.WaitGroup{}
-		waitGroup.Add(2) // currently we have two groups getRelayInner and getSeenBlock
-		// fetch all reads at the same time.
+		waitGroup.Add(3) // currently we have three groups: getRelayInner, getSeenBlock and getBlockHeightsFromHashes
+
+		// fetch all reads at the same time:
+		// fetch the cache entry
 		go func() {
 			defer waitGroup.Done()
 			cacheReplyTmp, err = s.getRelayInner(relayCacheGet)
@@ -116,6 +142,8 @@ func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairin
 				cacheReply = cacheReplyTmp // set cache reply only if it's not nil, as we need to store seen block in it.
			}
 		}()
+
+		// fetch seen block
 		go func() {
 			defer waitGroup.Done()
 			// set seen block if required
@@ -124,8 +152,16 @@ func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairin
 				relayCacheGet.SeenBlock = seenBlock // update state.
 			}
 		}()
+
+		// fetch block hashes
+		go func() {
+			defer waitGroup.Done()
+			blockHashes = s.getBlockHeightsFromHashes(relayCacheGet.ChainId, relayCacheGet.BlocksHashesToHeights)
+		}()
+
 		// wait for all reads to complete before moving forward
 		waitGroup.Wait()
+
 		if err == nil { // in case we got a hit, validate the seen block of the reply.
 			// validate that the response seen block is larger or equal to our expectations.
 			if cacheReply.SeenBlock < lavaslices.Min([]int64{relayCacheGet.SeenBlock, relayCacheGet.RequestedBlock}) { // TODO unitest this.
@@ -138,6 +174,7 @@ func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairin
 				)
 			}
 		}
+
 		// set seen block.
 		if relayCacheGet.SeenBlock > cacheReply.SeenBlock {
 			cacheReply.SeenBlock = relayCacheGet.SeenBlock
@@ -148,22 +185,32 @@ func (s *RelayerCacheServer) GetRelay(ctx context.Context, relayCacheGet *pairin
 			utils.LogAttr("requested block", relayCacheGet.RequestedBlock),
 			utils.LogAttr("request_hash", string(relayCacheGet.RequestHash)),
 		)
+		// even if we don't have information on the requested block, we can still check if we have data on the block hash array.
+		blockHashes = s.getBlockHeightsFromHashes(relayCacheGet.ChainId, relayCacheGet.BlocksHashesToHeights)
+	}
+
+	cacheReply.BlocksHashesToHeights = blockHashes
+	if blockHashes != nil {
+		utils.LavaFormatDebug("block hashes:", utils.LogAttr("hashes", blockHashes))
 	}
 	// add prometheus metrics asynchronously
+	cacheHit := cacheReply.Reply != nil
 	go func() {
 		cacheMetricsContext, cancel := context.WithTimeout(context.Background(), time.Second)
 		defer cancel()
-		var hit bool
-		if err != nil {
-			s.cacheMiss(cacheMetricsContext, err)
-		} else {
-			hit = true
+
+		if cacheHit {
 			s.cacheHit(cacheMetricsContext)
+		} else {
+			s.cacheMiss(cacheMetricsContext, err)
 		}
-		s.CacheServer.CacheMetrics.AddApiSpecific(originalRequestedBlock, relayCacheGet.ChainId, hit)
+
+		s.CacheServer.CacheMetrics.AddApiSpecific(originalRequestedBlock, relayCacheGet.ChainId, cacheHit)
 	}()
-	return cacheReply, err
+	// No matter what, we return a nil error from the cache, since the caller needs the
+	// additional info (block hashes array, seen block, etc.) even on a cache miss.
+	return cacheReply, nil
 }
 
 // formatHashKey formats the hash key by adding latestBlock information.
@@ -173,6 +220,10 @@ func (s *RelayerCacheServer) formatHashKey(hash []byte, parsedRequestedBlock int
 	return hash
 }
 
+func (s *RelayerCacheServer) formatChainIdWithHashKey(chainId, hash string) string {
+	return chainId + "_" + hash
+}
+
 func (s *RelayerCacheServer) getRelayInner(relayCacheGet *pairingtypes.RelayCacheGet) (*pairingtypes.CacheRelayReply, error) {
 	// cache key is compressed from:
 	// 1. Request hash including all the information inside RelayPrivateData (Salt can cause issues if not dealt with on consumer side.)
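For illustration, a minimal consumer-side sketch of the new contract (not part of this patch): GetRelay now reserves its error return for transport failures, a cache miss surfaces as a nil Reply, and hash-to-height lookups piggyback on the same round trip. It assumes the generated RelayerCache gRPC client from x/pairing/types; the chain id, block values and hashes are placeholders.

package main

import (
	"context"
	"fmt"

	pairingtypes "github.com/lavanet/lava/v2/x/pairing/types"
	spectypes "github.com/lavanet/lava/v2/x/spec/types"
)

// lookupHashHeights stores one hash->height mapping, then queries two hashes.
// Unknown hashes come back with Height == spectypes.NOT_APPLICABLE, not an error.
func lookupHashHeights(ctx context.Context, client pairingtypes.RelayerCacheClient) error {
	_, err := client.SetRelay(ctx, &pairingtypes.RelayCacheSet{
		ChainId:        "LAV1",
		RequestedBlock: 1234,
		Finalized:      true,
		Response:       &pairingtypes.RelayReply{},
		BlocksHashesToHeights: []*pairingtypes.BlockHashToHeight{
			{Hash: "H1", Height: 1234}, // stored under "LAV1_H1" with the blocks-hashes-to-heights TTL
		},
	})
	if err != nil {
		return err
	}
	reply, err := client.GetRelay(ctx, &pairingtypes.RelayCacheGet{
		ChainId:               "LAV1",
		RequestedBlock:        1234,
		BlocksHashesToHeights: []*pairingtypes.BlockHashToHeight{{Hash: "H1"}, {Hash: "H2"}},
	})
	if err != nil {
		return err // transport failure only; a plain cache miss is reply.Reply == nil
	}
	for _, h := range reply.BlocksHashesToHeights {
		if h.Height == spectypes.NOT_APPLICABLE {
			fmt.Printf("hash %s unknown to the cache\n", h.Hash)
		} else {
			fmt.Printf("hash %s -> height %d\n", h.Hash, h.Height)
		}
	}
	return nil
}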
@@ -249,6 +300,15 @@ func (s *RelayerCacheServer) setSeenBlockOnSharedStateMode(chainId, sharedStateI s.performInt64WriteWithValidationAndRetry(get, set, seenBlock) } +func (s *RelayerCacheServer) setBlocksHashesToHeights(chainId string, blocksHashesToHeights []*pairingtypes.BlockHashToHeight) { + for _, hashToHeight := range blocksHashesToHeights { + if hashToHeight.Height >= 0 { + formattedKey := s.formatChainIdWithHashKey(chainId, hashToHeight.Hash) + s.CacheServer.blocksHashesToHeightsCache.SetWithTTL(formattedKey, hashToHeight.Height, 1, s.CacheServer.ExpirationBlocksHashesToHeights) + } + } +} + func (s *RelayerCacheServer) SetRelay(ctx context.Context, relayCacheSet *pairingtypes.RelayCacheSet) (*emptypb.Empty, error) { if relayCacheSet.RequestedBlock < 0 { return nil, utils.LavaFormatError("invalid relay cache set data, request block is negative", nil, utils.Attribute{Key: "requestBlock", Value: relayCacheSet.RequestedBlock}) @@ -265,6 +325,7 @@ func (s *RelayerCacheServer) SetRelay(ctx context.Context, relayCacheSet *pairin utils.Attribute{Key: "requestHash", Value: string(relayCacheSet.BlockHash)}, utils.Attribute{Key: "latestKnownBlock", Value: string(relayCacheSet.BlockHash)}, utils.Attribute{Key: "IsNodeError", Value: relayCacheSet.IsNodeError}, + utils.Attribute{Key: "BlocksHashesToHeights", Value: relayCacheSet.BlocksHashesToHeights}, ) // finalized entries can stay there if relayCacheSet.Finalized { @@ -282,6 +343,7 @@ func (s *RelayerCacheServer) SetRelay(ctx context.Context, relayCacheSet *pairin // Setting the seen block for shared state. s.setSeenBlockOnSharedStateMode(relayCacheSet.ChainId, relayCacheSet.SharedStateId, latestKnownBlock) s.setLatestBlock(latestBlockKey(relayCacheSet.ChainId, ""), latestKnownBlock) + s.setBlocksHashesToHeights(relayCacheSet.ChainId, relayCacheSet.BlocksHashesToHeights) return &emptypb.Empty{}, nil } diff --git a/ecosystem/cache/server.go b/ecosystem/cache/server.go index 53fcce34b2..5875a2d856 100644 --- a/ecosystem/cache/server.go +++ b/ecosystem/cache/server.go @@ -29,11 +29,13 @@ const ( ExpirationTimeFinalizedMultiplierFlagName = "expiration-multiplier" ExpirationNonFinalizedFlagName = "expiration-non-finalized" ExpirationTimeNonFinalizedMultiplierFlagName = "expiration-non-finalized-multiplier" + ExpirationBlocksHashesToHeightsFlagName = "expiration-blocks-hashes-to-heights" ExpirationNodeErrorsOnFinalizedFlagName = "expiration-finalized-node-errors" FlagCacheSizeName = "max-items" DefaultExpirationForNonFinalized = 500 * time.Millisecond DefaultExpirationTimeFinalizedMultiplier = 1.0 DefaultExpirationTimeNonFinalizedMultiplier = 1.0 + DefaultExpirationBlocksHashesToHeights = 48 * time.Hour DefaultExpirationTimeFinalized = time.Hour DefaultExpirationNodeErrors = 250 * time.Millisecond CacheNumCounters = 100000000 // expect 10M items @@ -41,31 +43,48 @@ const ( ) type CacheServer struct { - finalizedCache *ristretto.Cache - tempCache *ristretto.Cache - ExpirationFinalized time.Duration - ExpirationNonFinalized time.Duration - ExpirationNodeErrors time.Duration + finalizedCache *ristretto.Cache + tempCache *ristretto.Cache // cache for temporary inputs, such as latest blocks + blocksHashesToHeightsCache *ristretto.Cache + ExpirationFinalized time.Duration + ExpirationNonFinalized time.Duration + ExpirationNodeErrors time.Duration + ExpirationBlocksHashesToHeights time.Duration CacheMetrics *CacheMetrics CacheMaxCost int64 } -func (cs *CacheServer) InitCache(ctx context.Context, expiration time.Duration, expirationNonFinalized 
time.Duration, metricsAddr string, expirationFinalizedMultiplier float64, expirationNonFinalizedMultiplier float64) {
+func (cs *CacheServer) InitCache(
+	ctx context.Context,
+	expiration time.Duration,
+	expirationNonFinalized time.Duration,
+	expirationNodeErrorsOnFinalized time.Duration,
+	expirationBlocksHashesToHeights time.Duration,
+	metricsAddr string,
+	expirationFinalizedMultiplier float64,
+	expirationNonFinalizedMultiplier float64,
+) {
 	cs.ExpirationFinalized = time.Duration(float64(expiration) * expirationFinalizedMultiplier)
 	cs.ExpirationNonFinalized = time.Duration(float64(expirationNonFinalized) * expirationNonFinalizedMultiplier)
+	cs.ExpirationNodeErrors = expirationNodeErrorsOnFinalized
+	cs.ExpirationBlocksHashesToHeights = expirationBlocksHashesToHeights
 
-	cache, err := ristretto.NewCache(&ristretto.Config{NumCounters: CacheNumCounters, MaxCost: cs.CacheMaxCost, BufferItems: 64})
+	var err error
+	cs.tempCache, err = ristretto.NewCache(&ristretto.Config{NumCounters: CacheNumCounters, MaxCost: cs.CacheMaxCost, BufferItems: 64})
 	if err != nil {
 		utils.LavaFormatFatal("could not create cache", err)
 	}
-	cs.tempCache = cache
 
-	cache, err = ristretto.NewCache(&ristretto.Config{NumCounters: CacheNumCounters, MaxCost: cs.CacheMaxCost, BufferItems: 64})
+	cs.finalizedCache, err = ristretto.NewCache(&ristretto.Config{NumCounters: CacheNumCounters, MaxCost: cs.CacheMaxCost, BufferItems: 64})
 	if err != nil {
 		utils.LavaFormatFatal("could not create finalized cache", err)
 	}
-	cs.finalizedCache = cache
+
+	cs.blocksHashesToHeightsCache, err = ristretto.NewCache(&ristretto.Config{NumCounters: CacheNumCounters, MaxCost: cs.CacheMaxCost, BufferItems: 64})
+	if err != nil {
+		utils.LavaFormatFatal("could not create blocks hashes to heights cache", err)
+	}
 
 	// initialize prometheus
 	cs.CacheMetrics = NewCacheMetricsServer(metricsAddr)
@@ -183,6 +202,16 @@ func Server(
 		utils.LavaFormatFatal("failed to read flag", err, utils.Attribute{Key: "flag", Value: ExpirationNonFinalizedFlagName})
 	}
 
+	expirationNodeErrorsOnFinalized, err := flags.GetDuration(ExpirationNodeErrorsOnFinalizedFlagName)
+	if err != nil {
+		utils.LavaFormatFatal("failed to read flag", err, utils.Attribute{Key: "flag", Value: ExpirationNodeErrorsOnFinalizedFlagName})
+	}
+
+	expirationBlocksHashesToHeights, err := flags.GetDuration(ExpirationBlocksHashesToHeightsFlagName)
+	if err != nil {
+		utils.LavaFormatFatal("failed to read flag", err, utils.Attribute{Key: "flag", Value: ExpirationBlocksHashesToHeightsFlagName})
+	}
+
 	expirationFinalizedMultiplier, err := flags.GetFloat64(ExpirationTimeFinalizedMultiplierFlagName)
 	if err != nil {
 		utils.LavaFormatFatal("failed to read flag", err, utils.Attribute{Key: "flag", Value: ExpirationTimeFinalizedMultiplierFlagName})
@@ -199,7 +228,7 @@ func Server(
 	}
 
 	cs := CacheServer{CacheMaxCost: cacheMaxCost}
-	cs.InitCache(ctx, expiration, expirationNonFinalized, metricsAddr, expirationFinalizedMultiplier, expirationNonFinalizedMultiplier)
+	cs.InitCache(ctx, expiration, expirationNonFinalized, expirationNodeErrorsOnFinalized, expirationBlocksHashesToHeights, metricsAddr, expirationFinalizedMultiplier, expirationNonFinalizedMultiplier)
 	// TODO: have a state tracker
 	cs.Serve(ctx, listenAddr)
 }
diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go
index 02de8604e5..ed79e1bea5 100644
--- a/protocol/chainlib/consumer_ws_subscription_manager_test.go
+++
b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -151,7 +151,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[index], nil) go func() { for subMsg := range repliesChan { - utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) + // utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) } }() From 4a53d00a8b893828586c806c162805b6bbbdb2ac Mon Sep 17 00:00:00 2001 From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:55:50 +0200 Subject: [PATCH 03/12] feat: PRT - add offline spec feat (#1635) * added option to configure static providers * who doesnt like some lint on comments? * disabled verifications for static provider on consumer, added static provider on provider side, disabled provider sessions on static provider code * added unitests for static providers * fix lock hanging * added tests * lint * added examples prints and script to run static provider * feat: PRT - allow offline spec loading --------- Co-authored-by: omerlavanet Co-authored-by: Omer <100387053+omerlavanet@users.noreply.github.com> --- protocol/chainlib/common_test_utils.go | 5 +- protocol/chainlib/jsonRPC_test.go | 4 +- protocol/common/cobra_common.go | 2 + protocol/rpcconsumer/rpcconsumer.go | 24 +++++++-- testutil/common/tester.go | 3 +- {testutil => utils}/keeper/spec.go | 69 ++++++++++++++++++------- x/pairing/keeper/pairing_test.go | 3 +- x/spec/ante/ante_test.go | 4 +- x/spec/genesis_test.go | 4 +- x/spec/keeper/grpc_query_params_test.go | 4 +- x/spec/keeper/grpc_query_spec_test.go | 8 +-- x/spec/keeper/msg_server_test.go | 4 +- x/spec/keeper/params_test.go | 4 +- 13 files changed, 95 insertions(+), 43 deletions(-) rename {testutil => utils}/keeper/spec.go (56%) diff --git a/protocol/chainlib/common_test_utils.go b/protocol/chainlib/common_test_utils.go index 2782799b23..1e4db5dc47 100644 --- a/protocol/chainlib/common_test_utils.go +++ b/protocol/chainlib/common_test_utils.go @@ -26,6 +26,7 @@ import ( "github.com/lavanet/lava/v2/protocol/lavasession" testcommon "github.com/lavanet/lava/v2/testutil/common" keepertest "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" plantypes "github.com/lavanet/lava/v2/x/plans/types" spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" @@ -126,7 +127,7 @@ func CreateChainLibMocks( ) (cpar ChainParser, crout ChainRouter, cfetc chaintracker.ChainFetcher, closeServer func(), endpointRet *lavasession.RPCProviderEndpoint, errRet error) { utils.SetGlobalLoggingLevel("debug") closeServer = nil - spec, err := keepertest.GetASpec(specIndex, getToTopMostPath, nil, nil) + spec, err := specutils.GetASpec(specIndex, getToTopMostPath, nil, nil) if err != nil { return nil, nil, nil, nil, nil, err } @@ -250,7 +251,7 @@ func SetupForTests(t *testing.T, numOfProviders int, specID string, getToTopMost ts.Providers = append(ts.Providers, testcommon.CreateNewAccount(ts.Ctx, *ts.Keepers, balance)) } sdkContext := sdk.UnwrapSDKContext(ts.Ctx) - spec, err := keepertest.GetASpec(specID, getToTopMostPath, &sdkContext, &ts.Keepers.Spec) + spec, err := specutils.GetASpec(specID, getToTopMostPath, &sdkContext, &ts.Keepers.Spec) if err != nil { require.NoError(t, err) } diff --git 
a/protocol/chainlib/jsonRPC_test.go b/protocol/chainlib/jsonRPC_test.go index a110b22bca..15f9db0dc2 100644 --- a/protocol/chainlib/jsonRPC_test.go +++ b/protocol/chainlib/jsonRPC_test.go @@ -13,7 +13,7 @@ import ( "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/common" - keepertest "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" plantypes "github.com/lavanet/lava/v2/x/plans/types" spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/assert" @@ -253,7 +253,7 @@ func TestExtensions(t *testing.T) { configuredExtensions := map[string]struct{}{ "archive": {}, } - spec, err := keepertest.GetASpec(specname, "../../", nil, nil) + spec, err := specutils.GetASpec(specname, "../../", nil, nil) require.NoError(t, err) chainParser.SetPolicy(&plantypes.Policy{ChainPolicies: []plantypes.ChainPolicy{{ChainId: specname, Requirements: []plantypes.ChainRequirement{{Collection: spectypes.CollectionData{ApiInterface: "jsonrpc"}, Extensions: []string{"archive"}}}}}}, specname, "jsonrpc") diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index 19a64487d2..d7daa95e59 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -32,6 +32,7 @@ const ( // Disable relay retries when we get node errors. // This feature is suppose to help with successful relays in some chains that return node errors on rare race conditions on the serviced chains. DisableRetryOnNodeErrorsFlag = "disable-retry-on-node-error" + UseOfflineSpecFlag = "use-offline-spec" // allows the user to manually load a spec providing a path, this is useful to test spec changes before they hit the blockchain ) const ( @@ -55,6 +56,7 @@ type ConsumerCmdFlags struct { DebugRelays bool // enables debug mode for relays DisableConflictTransactions bool // disable conflict transactions DisableRetryOnNodeErrors bool // disable retries on node errors + OfflineSpecPath string // path to the spec file, works only when bootstrapping a single chain. } // default rolling logs behavior (if enabled) will store 3 files each 100MB for up to 1 day every time. diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 8cee4014e9..c5b5a32e5b 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -29,6 +29,7 @@ import ( "github.com/lavanet/lava/v2/protocol/statetracker/updaters" "github.com/lavanet/lava/v2/protocol/upgrade" "github.com/lavanet/lava/v2/utils" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/utils/rand" "github.com/lavanet/lava/v2/utils/sigs" conflicttypes "github.com/lavanet/lava/v2/x/conflict/types" @@ -214,8 +215,19 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt } else { policyUpdaters.Store(rpcEndpoint.ChainID, updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint)) } - // register for spec updates - err = rpcc.consumerStateTracker.RegisterForSpecUpdates(ctx, chainParser, *rpcEndpoint) + + if options.cmdFlags.OfflineSpecPath != "" { + // offline spec mode. 
+			parsedOfflineSpec, loadError := specutils.GetSpecFromPath(options.cmdFlags.OfflineSpecPath, rpcEndpoint.ChainID, nil, nil)
+			if loadError != nil {
+				err = utils.LavaFormatError("failed loading offline spec", loadError, utils.LogAttr("spec_path", options.cmdFlags.OfflineSpecPath), utils.LogAttr("spec_id", rpcEndpoint.ChainID))
+			} else {
+				utils.LavaFormatInfo("Loaded offline spec successfully", utils.LogAttr("spec_path", options.cmdFlags.OfflineSpecPath), utils.LogAttr("chain_id", parsedOfflineSpec.Index))
+				chainParser.SetSpec(parsedOfflineSpec)
+			}
+		} else {
+			// register for spec updates
+			err = rpcc.consumerStateTracker.RegisterForSpecUpdates(ctx, chainParser, *rpcEndpoint)
+		}
 		if err != nil {
 			err = utils.LavaFormatError("failed registering for spec updates", err, utils.Attribute{Key: "endpoint", Value: rpcEndpoint})
 			errCh <- err
@@ -561,7 +573,6 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 	}
 
 	maxConcurrentProviders := viper.GetUint(common.MaximumConcurrentProvidersFlagName)
-
 	consumerPropagatedFlags := common.ConsumerCmdFlags{
 		HeadersFlag:                 viper.GetString(common.CorsHeadersFlag),
 		CredentialsFlag:             viper.GetString(common.CorsCredentialsFlag),
@@ -573,6 +584,12 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 		DebugRelays:                 viper.GetBool(DebugRelaysFlagName),
 		DisableConflictTransactions: viper.GetBool(common.DisableConflictTransactionsFlag),
 		DisableRetryOnNodeErrors:    viper.GetBool(common.DisableRetryOnNodeErrorsFlag),
+		OfflineSpecPath:             viper.GetString(common.UseOfflineSpecFlag),
+	}
+
+	// validate that the user does not provide a multi-chain setup when using the offline spec feature.
+	if consumerPropagatedFlags.OfflineSpecPath != "" && len(rpcEndpoints) > 1 {
+		utils.LavaFormatFatal("offline spec modifications are supported only in single chain bootstrapping", nil, utils.LogAttr("len(rpcEndpoints)", len(rpcEndpoints)), utils.LogAttr("rpcEndpoints", rpcEndpoints))
 	}
 
 	rpcConsumerSharedState := viper.GetBool(common.SharedStateFlag)
@@ -615,6 +632,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 	cmdRPCConsumer.Flags().Bool(common.DisableConflictTransactionsFlag, false, "disabling conflict transactions, this flag should not be used as it harms the network's data reliability and therefore the service.")
 	cmdRPCConsumer.Flags().DurationVar(&updaters.TimeOutForFetchingLavaBlocks, common.TimeOutForFetchingLavaBlocksFlag, time.Second*5, "setting the timeout for fetching lava blocks")
 	cmdRPCConsumer.Flags().Bool(common.DisableRetryOnNodeErrorsFlag, false, "Disable relay retries on node errors, prevent the rpcconsumer trying a different provider")
+	cmdRPCConsumer.Flags().String(common.UseOfflineSpecFlag, "", "load an offline spec from the provided file path; used to test specs before they are proposed on chain")
 	common.AddRollingLogConfig(cmdRPCConsumer)
 
 	return cmdRPCConsumer
diff --git a/testutil/common/tester.go b/testutil/common/tester.go
index fd5ca31827..d6437ac909 100644
--- a/testutil/common/tester.go
+++ b/testutil/common/tester.go
@@ -16,6 +16,7 @@ import (
 	stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
 	testkeeper "github.com/lavanet/lava/v2/testutil/keeper"
 	"github.com/lavanet/lava/v2/utils"
+	specutils "github.com/lavanet/lava/v2/utils/keeper"
 	"github.com/lavanet/lava/v2/utils/lavaslices"
 	"github.com/lavanet/lava/v2/utils/sigs"
 	dualstakingante "github.com/lavanet/lava/v2/x/dualstaking/ante"
@@ -1127,7 +1128,7 @@ func (ts *Tester) SetupForTests(getToTopMostPath string, specId string, validato
 	}
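For reference (not part of the patch), wiring the new flag into a single-chain consumer looks like `lavap rpcconsumer config/consumer_examples/lava_consumer_static_peers.yml --use-offline-spec cookbook/specs/lava.json ...`; with more than one chain in the endpoints list the command exits fatally, per the validation above. The same loader can also be called directly; a minimal sketch using the GetSpecFromPath signature introduced in the utils/keeper/spec.go diff below (the path and chain id are examples):

package main

import (
	specutils "github.com/lavanet/lava/v2/utils/keeper"
	spectypes "github.com/lavanet/lava/v2/x/spec/types"
)

// loadOfflineSpec mirrors what the consumer does with --use-offline-spec:
// passing nil for the context and keeper makes the helper spin up an
// in-memory spec keeper, strictly decode the proposal JSON, and return
// the expanded spec.
func loadOfflineSpec(path, chainID string) (spectypes.Spec, error) {
	return specutils.GetSpecFromPath(path, chainID, nil, nil)
}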
sdkContext := sdk.UnwrapSDKContext(ts.Ctx) - spec, err := testkeeper.GetASpec(specId, getToTopMostPath, &sdkContext, &ts.Keepers.Spec) + spec, err := specutils.GetASpec(specId, getToTopMostPath, &sdkContext, &ts.Keepers.Spec) if err != nil { return err } diff --git a/testutil/keeper/spec.go b/utils/keeper/spec.go similarity index 56% rename from testutil/keeper/spec.go rename to utils/keeper/spec.go index 3574f15040..f76965c5a4 100644 --- a/testutil/keeper/spec.go +++ b/utils/keeper/spec.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "os" - "strings" "testing" tmdb "github.com/cometbft/cometbft-db" @@ -67,7 +66,20 @@ func specKeeper() (*keeper.Keeper, sdk.Context, error) { return k, ctx, nil } -func GetASpec(specIndex, getToTopMostPath string, ctxArg *sdk.Context, keeper *keeper.Keeper) (specRet spectypes.Spec, err error) { +func decodeProposal(path string) (utils.SpecAddProposalJSON, error) { + proposal := utils.SpecAddProposalJSON{} + contents, err := os.ReadFile(path) + if err != nil { + return proposal, err + } + decoder := json.NewDecoder(bytes.NewReader(contents)) + decoder.DisallowUnknownFields() // This will make the unmarshal fail if there are unused fields + + err = decoder.Decode(&proposal) + return proposal, err +} + +func GetSpecFromPath(path string, specIndex string, ctxArg *sdk.Context, keeper *keeper.Keeper) (specRet spectypes.Spec, err error) { var ctx sdk.Context if keeper == nil || ctxArg == nil { keeper, ctx, err = specKeeper() @@ -77,31 +89,48 @@ func GetASpec(specIndex, getToTopMostPath string, ctxArg *sdk.Context, keeper *k } else { ctx = *ctxArg } - proposalFile := "./cookbook/specs/ibc.json,./cookbook/specs/cosmoswasm.json,./cookbook/specs/tendermint.json,./cookbook/specs/cosmossdk.json,./cookbook/specs/cosmossdk_full.json,./cookbook/specs/ethereum.json,./cookbook/specs/cosmoshub.json,./cookbook/specs/lava.json,./cookbook/specs/osmosis.json,./cookbook/specs/fantom.json,./cookbook/specs/celo.json,./cookbook/specs/optimism.json,./cookbook/specs/arbitrum.json,./cookbook/specs/starknet.json,./cookbook/specs/aptos.json,./cookbook/specs/juno.json,./cookbook/specs/polygon.json,./cookbook/specs/evmos.json,./cookbook/specs/base.json,./cookbook/specs/canto.json,./cookbook/specs/sui.json,./cookbook/specs/solana.json,./cookbook/specs/bsc.json,./cookbook/specs/axelar.json,./cookbook/specs/avalanche.json,./cookbook/specs/fvm.json" - for _, fileName := range strings.Split(proposalFile, ",") { - proposal := utils.SpecAddProposalJSON{} - contents, err := os.ReadFile(getToTopMostPath + fileName) + proposal, err := decodeProposal(path) + if err != nil { + return spectypes.Spec{}, err + } + + for _, spec := range proposal.Proposal.Specs { + keeper.SetSpec(ctx, spec) + if specIndex != spec.Index { + continue + } + fullspec, err := keeper.ExpandSpec(ctx, spec) if err != nil { return spectypes.Spec{}, err } - decoder := json.NewDecoder(bytes.NewReader(contents)) - decoder.DisallowUnknownFields() // This will make the unmarshal fail if there are unused fields + return fullspec, nil + } + return spectypes.Spec{}, fmt.Errorf("spec not found %s", path) +} - if err := decoder.Decode(&proposal); err != nil { +func GetASpec(specIndex, getToTopMostPath string, ctxArg *sdk.Context, keeper *keeper.Keeper) (specRet spectypes.Spec, err error) { + var ctx sdk.Context + if keeper == nil || ctxArg == nil { + keeper, ctx, err = specKeeper() + if err != nil { return spectypes.Spec{}, err } - - for _, spec := range proposal.Proposal.Specs { - keeper.SetSpec(ctx, spec) - if specIndex != spec.Index { - 
continue - } - fullspec, err := keeper.ExpandSpec(ctx, spec) - if err != nil { - return spectypes.Spec{}, err - } - return fullspec, nil + } else { + ctx = *ctxArg + } + proposalDirectory := "cookbook/specs/" + proposalFiles := []string{ + "ibc.json", "cosmoswasm.json", "tendermint.json", "cosmossdk.json", "cosmossdk_full.json", + "ethereum.json", "cosmoshub.json", "lava.json", "osmosis.json", "fantom.json", "celo.json", + "optimism.json", "arbitrum.json", "starknet.json", "aptos.json", "juno.json", "polygon.json", + "evmos.json", "base.json", "canto.json", "sui.json", "solana.json", "bsc.json", "axelar.json", + "avalanche.json", "fvm.json", "near.json", + } + for _, fileName := range proposalFiles { + spec, err := GetSpecFromPath(getToTopMostPath+proposalDirectory+fileName, specIndex, &ctx, keeper) + if err == nil { + return spec, nil } } return spectypes.Spec{}, fmt.Errorf("spec not found %s", specIndex) diff --git a/x/pairing/keeper/pairing_test.go b/x/pairing/keeper/pairing_test.go index 74165525af..a9cf57336b 100644 --- a/x/pairing/keeper/pairing_test.go +++ b/x/pairing/keeper/pairing_test.go @@ -10,6 +10,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/v2/testutil/common" testkeeper "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/utils/lavaslices" "github.com/lavanet/lava/v2/utils/sigs" epochstoragetypes "github.com/lavanet/lava/v2/x/epochstorage/types" @@ -2212,7 +2213,7 @@ func TestMixBothExetensionAndAddonPairing(t *testing.T) { func TestMixSelectedProvidersAndArchivePairing(t *testing.T) { ts := newTester(t) ts.setupForPayments(1, 0, 0) // 1 provider, 0 client, default providers-to-pair - specEth, err := testkeeper.GetASpec("ETH1", "../../../", nil, nil) + specEth, err := specutils.GetASpec("ETH1", "../../../", nil, nil) if err != nil { require.NoError(t, err) } diff --git a/x/spec/ante/ante_test.go b/x/spec/ante/ante_test.go index 9047af0f1d..82571e5413 100644 --- a/x/spec/ante/ante_test.go +++ b/x/spec/ante/ante_test.go @@ -13,7 +13,7 @@ import ( v1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1" "github.com/cosmos/gogoproto/proto" "github.com/lavanet/lava/v2/app" - testkeeper "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" plantypes "github.com/lavanet/lava/v2/x/plans/types" "github.com/lavanet/lava/v2/x/spec/ante" spectypes "github.com/lavanet/lava/v2/x/spec/types" @@ -181,7 +181,7 @@ func TestNewExpeditedProposalFilterAnteDecorator(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { - k, ctx := testkeeper.SpecKeeper(t) + k, ctx := specutils.SpecKeeper(t) params := spectypes.DefaultParams() params.AllowlistedExpeditedMsgs = []string{ proto.MessageName(&banktypes.MsgSend{}), diff --git a/x/spec/genesis_test.go b/x/spec/genesis_test.go index 8604f72bc1..9faaa7e9ae 100644 --- a/x/spec/genesis_test.go +++ b/x/spec/genesis_test.go @@ -6,8 +6,8 @@ import ( types2 "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/gogoproto/proto" - keepertest "github.com/lavanet/lava/v2/testutil/keeper" "github.com/lavanet/lava/v2/testutil/nullify" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/x/spec" "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" @@ -32,7 +32,7 @@ func TestGenesis(t *testing.T) { // this line is used by starport scaffolding # genesis/test/state } - k, ctx := keepertest.SpecKeeper(t) + k, ctx := specutils.SpecKeeper(t) 
spec.InitGenesis(ctx, *k, genesisState) got := spec.ExportGenesis(ctx, *k) require.NotNil(t, got) diff --git a/x/spec/keeper/grpc_query_params_test.go b/x/spec/keeper/grpc_query_params_test.go index ada5f2d3f0..5d94e82188 100644 --- a/x/spec/keeper/grpc_query_params_test.go +++ b/x/spec/keeper/grpc_query_params_test.go @@ -4,13 +4,13 @@ import ( "testing" sdk "github.com/cosmos/cosmos-sdk/types" - testkeeper "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" ) func TestParamsQuery(t *testing.T) { - keeper, ctx := testkeeper.SpecKeeper(t) + keeper, ctx := specutils.SpecKeeper(t) wctx := sdk.WrapSDKContext(ctx) params := types.DefaultParams() keeper.SetParams(ctx, params) diff --git a/x/spec/keeper/grpc_query_spec_test.go b/x/spec/keeper/grpc_query_spec_test.go index 0b1e33d1a8..ad97b52fbf 100644 --- a/x/spec/keeper/grpc_query_spec_test.go +++ b/x/spec/keeper/grpc_query_spec_test.go @@ -10,8 +10,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - keepertest "github.com/lavanet/lava/v2/testutil/keeper" "github.com/lavanet/lava/v2/testutil/nullify" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/x/spec/types" ) @@ -19,7 +19,7 @@ import ( var _ = strconv.IntSize func TestSpecQuerySingle(t *testing.T) { - keeper, ctx := keepertest.SpecKeeper(t) + keeper, ctx := specutils.SpecKeeper(t) wctx := sdk.WrapSDKContext(ctx) msgs := createNSpec(keeper, ctx, 2) for _, tc := range []struct { @@ -70,7 +70,7 @@ func TestSpecQuerySingle(t *testing.T) { } func TestSpecQuerySingleRaw(t *testing.T) { - keeper, ctx := keepertest.SpecKeeper(t) + keeper, ctx := specutils.SpecKeeper(t) wctx := sdk.WrapSDKContext(ctx) msgs := createNSpec(keeper, ctx, 2) @@ -98,7 +98,7 @@ func TestSpecQuerySingleRaw(t *testing.T) { } func TestSpecQueryPaginated(t *testing.T) { - keeper, ctx := keepertest.SpecKeeper(t) + keeper, ctx := specutils.SpecKeeper(t) wctx := sdk.WrapSDKContext(ctx) msgs := createNSpec(keeper, ctx, 5) diff --git a/x/spec/keeper/msg_server_test.go b/x/spec/keeper/msg_server_test.go index 241f721f95..b15805f52b 100644 --- a/x/spec/keeper/msg_server_test.go +++ b/x/spec/keeper/msg_server_test.go @@ -5,12 +5,12 @@ import ( "testing" sdk "github.com/cosmos/cosmos-sdk/types" - keepertest "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/x/spec/keeper" "github.com/lavanet/lava/v2/x/spec/types" ) func setupMsgServer(t testing.TB) (types.MsgServer, context.Context) { - k, ctx := keepertest.SpecKeeper(t) + k, ctx := specutils.SpecKeeper(t) return keeper.NewMsgServerImpl(*k), sdk.WrapSDKContext(ctx) } diff --git a/x/spec/keeper/params_test.go b/x/spec/keeper/params_test.go index 089696024e..429f0d7410 100644 --- a/x/spec/keeper/params_test.go +++ b/x/spec/keeper/params_test.go @@ -3,13 +3,13 @@ package keeper_test import ( "testing" - testkeeper "github.com/lavanet/lava/v2/testutil/keeper" + specutils "github.com/lavanet/lava/v2/utils/keeper" "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" ) func TestGetParams(t *testing.T) { - k, ctx := testkeeper.SpecKeeper(t) + k, ctx := specutils.SpecKeeper(t) params := types.DefaultParams() k.SetParams(ctx, params) From 16542182c2f88b5560ccb595ebdef60db23a9c3d Mon Sep 17 00:00:00 2001 From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:43:43 +0200 
Subject: [PATCH 04/12] chore: PRT - adding logs to error for requested block mismatch (#1638)

---
 protocol/rpcprovider/rpcprovider_server.go | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/protocol/rpcprovider/rpcprovider_server.go b/protocol/rpcprovider/rpcprovider_server.go
index 272836f8c9..4de897dd9b 100644
--- a/protocol/rpcprovider/rpcprovider_server.go
+++ b/protocol/rpcprovider/rpcprovider_server.go
@@ -367,7 +367,18 @@ func (rpcps *RPCProviderServer) ValidateRequest(chainMessage chainlib.ChainMessa
 			utils.Attribute{Key: "provider_requested_block", Value: reqBlock},
 			utils.Attribute{Key: "consumer_requested_block", Value: request.RelayData.RequestBlock},
 			utils.Attribute{Key: "GUID", Value: ctx})
-		return utils.LavaFormatError("requested block mismatch between consumer and provider", nil, utils.LogAttr("method", chainMessage.GetApi().Name), utils.Attribute{Key: "provider_parsed_block_pre_update", Value: providerRequestedBlockPreUpdate}, utils.Attribute{Key: "provider_requested_block", Value: reqBlock}, utils.Attribute{Key: "consumer_requested_block", Value: request.RelayData.RequestBlock}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "metadata", Value: request.RelayData.Metadata})
+		// TODO: we need to return an error here. This was disabled so relays will pass, but it will cause data reliability issues.
+		// Once we understand the issue, return the error.
+		utils.LavaFormatError("requested block mismatch between consumer and provider", nil,
+			utils.LogAttr("request data", string(request.RelayData.Data)),
+			utils.LogAttr("request path", request.RelayData.ApiUrl),
+			utils.LogAttr("method", chainMessage.GetApi().Name),
+			utils.Attribute{Key: "provider_parsed_block_pre_update", Value: providerRequestedBlockPreUpdate},
+			utils.Attribute{Key: "provider_requested_block", Value: reqBlock},
+			utils.Attribute{Key: "consumer_requested_block", Value: request.RelayData.RequestBlock},
+			utils.Attribute{Key: "GUID", Value: ctx},
+			utils.Attribute{Key: "metadata", Value: request.RelayData.Metadata},
+		)
 	}
 }
 return nil

From 5937abe86e917aab29f0b9fcb6abd925f8f56610 Mon Sep 17 00:00:00 2001
From: Omer <100387053+omerlavanet@users.noreply.github.com>
Date: Wed, 21 Aug 2024 10:36:39 +0300
Subject: [PATCH 05/12] chore: consolidate-chain-message-data (#1636)

* create protocolMessage class
* fix bug
* fix arg mismatch
* fix test

---------

Co-authored-by: Ran Mishael
---
 protocol/chainlib/chainlib.go                 |   6 +-
 protocol/chainlib/chainlib_mock.go            | 141 ++++++++++++------
 .../chainlib/consumer_websocket_manager.go    |  24 +--
 .../consumer_ws_subscription_manager.go       |  86 +++++------
 .../consumer_ws_subscription_manager_test.go  |  87 ++++++-----
 protocol/chainlib/protocol_message.go         |  53 +++++++
 .../consumer_session_manager_test.go          |  23 ++-
 protocol/lavasession/used_providers.go        |  15 +-
 protocol/rpcconsumer/rpcconsumer_server.go    | 115 +++++++------
 x/pairing/types/relay_mock.pb.go              |  48 +++---
 10 files changed, 357 insertions(+), 241 deletions(-)
 create mode 100644 protocol/chainlib/protocol_message.go

diff --git a/protocol/chainlib/chainlib.go b/protocol/chainlib/chainlib.go
index 83aa8e1e30..8ed037669d 100644
--- a/protocol/chainlib/chainlib.go
+++ b/protocol/chainlib/chainlib.go
@@ -125,15 +125,13 @@ type RelaySender interface {
 		consumerIp string,
 		analytics *metrics.RelayMetrics,
 		metadata []pairingtypes.Metadata,
-	) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error)
+	) (ProtocolMessage, error)
 	SendParsedRelay(
 		ctx context.Context,
 		dappID
string, consumerIp string, analytics *metrics.RelayMetrics, - chainMessage ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage ProtocolMessage, ) (relayResult *common.RelayResult, errRet error) CreateDappKey(dappID, consumerIp string) string CancelSubscriptionContext(subscriptionKey string) diff --git a/protocol/chainlib/chainlib_mock.go b/protocol/chainlib/chainlib_mock.go index 284a6c9b26..757c2cd9e0 100644 --- a/protocol/chainlib/chainlib_mock.go +++ b/protocol/chainlib/chainlib_mock.go @@ -1,10 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. // Source: protocol/chainlib/chainlib.go -// -// Generated by this command: -// -// mockgen -source protocol/chainlib/chainlib.go -destination protocol/chainlib/chainlib_mock.go -package chainlib -// // Package chainlib is a generated GoMock package. package chainlib @@ -14,6 +9,7 @@ import ( reflect "reflect" time "time" + gomock "github.com/golang/mock/gomock" rpcInterfaceMessages "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" rpcclient "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" extensionslib "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" @@ -21,7 +17,6 @@ import ( metrics "github.com/lavanet/lava/v2/protocol/metrics" types "github.com/lavanet/lava/v2/x/pairing/types" types0 "github.com/lavanet/lava/v2/x/spec/types" - gomock "go.uber.org/mock/gomock" ) // MockChainParser is a mock of ChainParser interface. @@ -100,7 +95,7 @@ func (m *MockChainParser) CraftMessage(parser *types0.ParseDirective, connection } // CraftMessage indicates an expected call of CraftMessage. -func (mr *MockChainParserMockRecorder) CraftMessage(parser, connectionType, craftData, metadata any) *gomock.Call { +func (mr *MockChainParserMockRecorder) CraftMessage(parser, connectionType, craftData, metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CraftMessage", reflect.TypeOf((*MockChainParser)(nil).CraftMessage), parser, connectionType, craftData, metadata) } @@ -145,7 +140,7 @@ func (m *MockChainParser) GetParsingByTag(tag types0.FUNCTION_TAG) (*types0.Pars } // GetParsingByTag indicates an expected call of GetParsingByTag. -func (mr *MockChainParserMockRecorder) GetParsingByTag(tag any) *gomock.Call { +func (mr *MockChainParserMockRecorder) GetParsingByTag(tag interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParsingByTag", reflect.TypeOf((*MockChainParser)(nil).GetParsingByTag), tag) } @@ -174,7 +169,7 @@ func (m *MockChainParser) GetVerifications(supported []string) ([]VerificationCo } // GetVerifications indicates an expected call of GetVerifications. -func (mr *MockChainParserMockRecorder) GetVerifications(supported any) *gomock.Call { +func (mr *MockChainParserMockRecorder) GetVerifications(supported interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVerifications", reflect.TypeOf((*MockChainParser)(nil).GetVerifications), supported) } @@ -190,7 +185,7 @@ func (m *MockChainParser) HandleHeaders(metadata []types.Metadata, apiCollection } // HandleHeaders indicates an expected call of HandleHeaders. 
-func (mr *MockChainParserMockRecorder) HandleHeaders(metadata, apiCollection, headersDirection any) *gomock.Call { +func (mr *MockChainParserMockRecorder) HandleHeaders(metadata, apiCollection, headersDirection interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleHeaders", reflect.TypeOf((*MockChainParser)(nil).HandleHeaders), metadata, apiCollection, headersDirection) } @@ -205,7 +200,7 @@ func (m *MockChainParser) ParseMsg(url string, data []byte, connectionType strin } // ParseMsg indicates an expected call of ParseMsg. -func (mr *MockChainParserMockRecorder) ParseMsg(url, data, connectionType, metadata, extensionInfo any) *gomock.Call { +func (mr *MockChainParserMockRecorder) ParseMsg(url, data, connectionType, metadata, extensionInfo interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseMsg", reflect.TypeOf((*MockChainParser)(nil).ParseMsg), url, data, connectionType, metadata, extensionInfo) } @@ -221,7 +216,7 @@ func (m *MockChainParser) SeparateAddonsExtensions(supported []string) ([]string } // SeparateAddonsExtensions indicates an expected call of SeparateAddonsExtensions. -func (mr *MockChainParserMockRecorder) SeparateAddonsExtensions(supported any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SeparateAddonsExtensions(supported interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SeparateAddonsExtensions", reflect.TypeOf((*MockChainParser)(nil).SeparateAddonsExtensions), supported) } @@ -235,7 +230,7 @@ func (m *MockChainParser) SetPolicy(policy PolicyInf, chainId, apiInterface stri } // SetPolicy indicates an expected call of SetPolicy. -func (mr *MockChainParserMockRecorder) SetPolicy(policy, chainId, apiInterface any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SetPolicy(policy, chainId, apiInterface interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPolicy", reflect.TypeOf((*MockChainParser)(nil).SetPolicy), policy, chainId, apiInterface) } @@ -247,7 +242,7 @@ func (m *MockChainParser) SetSpec(spec types0.Spec) { } // SetSpec indicates an expected call of SetSpec. -func (mr *MockChainParserMockRecorder) SetSpec(spec any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SetSpec(spec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSpec", reflect.TypeOf((*MockChainParser)(nil).SetSpec), spec) } @@ -259,7 +254,7 @@ func (m *MockChainParser) UpdateBlockTime(newBlockTime time.Duration) { } // UpdateBlockTime indicates an expected call of UpdateBlockTime. -func (mr *MockChainParserMockRecorder) UpdateBlockTime(newBlockTime any) *gomock.Call { +func (mr *MockChainParserMockRecorder) UpdateBlockTime(newBlockTime interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateBlockTime", reflect.TypeOf((*MockChainParser)(nil).UpdateBlockTime), newBlockTime) } @@ -294,7 +289,7 @@ func (m *MockChainMessage) AppendHeader(metadata []types.Metadata) { } // AppendHeader indicates an expected call of AppendHeader. 
-func (mr *MockChainMessageMockRecorder) AppendHeader(metadata any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) AppendHeader(metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHeader", reflect.TypeOf((*MockChainMessage)(nil).AppendHeader), metadata) } @@ -309,7 +304,7 @@ func (m *MockChainMessage) CheckResponseError(data []byte, httpStatusCode int) ( } // CheckResponseError indicates an expected call of CheckResponseError. -func (mr *MockChainMessageMockRecorder) CheckResponseError(data, httpStatusCode any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) CheckResponseError(data, httpStatusCode interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckResponseError", reflect.TypeOf((*MockChainMessage)(nil).CheckResponseError), data, httpStatusCode) } @@ -410,6 +405,21 @@ func (mr *MockChainMessageMockRecorder) GetRPCMessage() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRPCMessage", reflect.TypeOf((*MockChainMessage)(nil).GetRPCMessage)) } +// GetRawRequestHash mocks base method. +func (m *MockChainMessage) GetRawRequestHash() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRawRequestHash") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRawRequestHash indicates an expected call of GetRawRequestHash. +func (mr *MockChainMessageMockRecorder) GetRawRequestHash() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawRequestHash", reflect.TypeOf((*MockChainMessage)(nil).GetRawRequestHash)) +} + // OverrideExtensions mocks base method. func (m *MockChainMessage) OverrideExtensions(extensionNames []string, extensionParser *extensionslib.ExtensionParser) { m.ctrl.T.Helper() @@ -417,7 +427,7 @@ func (m *MockChainMessage) OverrideExtensions(extensionNames []string, extension } // OverrideExtensions indicates an expected call of OverrideExtensions. -func (mr *MockChainMessageMockRecorder) OverrideExtensions(extensionNames, extensionParser any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) OverrideExtensions(extensionNames, extensionParser interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OverrideExtensions", reflect.TypeOf((*MockChainMessage)(nil).OverrideExtensions), extensionNames, extensionParser) } @@ -446,15 +456,29 @@ func (m *MockChainMessage) SetForceCacheRefresh(force bool) bool { } // SetForceCacheRefresh indicates an expected call of SetForceCacheRefresh. -func (mr *MockChainMessageMockRecorder) SetForceCacheRefresh(force any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) SetForceCacheRefresh(force interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetForceCacheRefresh", reflect.TypeOf((*MockChainMessage)(nil).SetForceCacheRefresh), force) } +// SubscriptionIdExtractor mocks base method. +func (m *MockChainMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscriptionIdExtractor", reply) + ret0, _ := ret[0].(string) + return ret0 +} + +// SubscriptionIdExtractor indicates an expected call of SubscriptionIdExtractor. 
+func (mr *MockChainMessageMockRecorder) SubscriptionIdExtractor(reply interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionIdExtractor", reflect.TypeOf((*MockChainMessage)(nil).SubscriptionIdExtractor), reply) +} + // TimeoutOverride mocks base method. func (m *MockChainMessage) TimeoutOverride(arg0 ...time.Duration) time.Duration { m.ctrl.T.Helper() - varargs := []any{} + varargs := []interface{}{} for _, a := range arg0 { varargs = append(varargs, a) } @@ -464,7 +488,7 @@ func (m *MockChainMessage) TimeoutOverride(arg0 ...time.Duration) time.Duration } // TimeoutOverride indicates an expected call of TimeoutOverride. -func (mr *MockChainMessageMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) TimeoutOverride(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessage)(nil).TimeoutOverride), arg0...) } @@ -478,7 +502,7 @@ func (m *MockChainMessage) UpdateLatestBlockInMessage(latestBlock int64, modifyC } // UpdateLatestBlockInMessage indicates an expected call of UpdateLatestBlockInMessage. -func (mr *MockChainMessageMockRecorder) UpdateLatestBlockInMessage(latestBlock, modifyContent any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) UpdateLatestBlockInMessage(latestBlock, modifyContent interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLatestBlockInMessage", reflect.TypeOf((*MockChainMessage)(nil).UpdateLatestBlockInMessage), latestBlock, modifyContent) } @@ -506,6 +530,21 @@ func (m *MockChainMessageForSend) EXPECT() *MockChainMessageForSendMockRecorder return m.recorder } +// CheckResponseError mocks base method. +func (m *MockChainMessageForSend) CheckResponseError(data []byte, httpStatusCode int) (bool, string) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckResponseError", data, httpStatusCode) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(string) + return ret0, ret1 +} + +// CheckResponseError indicates an expected call of CheckResponseError. +func (mr *MockChainMessageForSendMockRecorder) CheckResponseError(data, httpStatusCode interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckResponseError", reflect.TypeOf((*MockChainMessageForSend)(nil).CheckResponseError), data, httpStatusCode) +} + // GetApi mocks base method. func (m *MockChainMessageForSend) GetApi() *types0.Api { m.ctrl.T.Helper() @@ -565,7 +604,7 @@ func (mr *MockChainMessageForSendMockRecorder) GetRPCMessage() *gomock.Call { // TimeoutOverride mocks base method. func (m *MockChainMessageForSend) TimeoutOverride(arg0 ...time.Duration) time.Duration { m.ctrl.T.Helper() - varargs := []any{} + varargs := []interface{}{} for _, a := range arg0 { varargs = append(varargs, a) } @@ -575,7 +614,7 @@ func (m *MockChainMessageForSend) TimeoutOverride(arg0 ...time.Duration) time.Du } // TimeoutOverride indicates an expected call of TimeoutOverride. -func (mr *MockChainMessageForSendMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { +func (mr *MockChainMessageForSendMockRecorder) TimeoutOverride(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessageForSend)(nil).TimeoutOverride), arg0...) 
} @@ -647,7 +686,7 @@ func (m *MockRelaySender) CancelSubscriptionContext(subscriptionKey string) { } // CancelSubscriptionContext indicates an expected call of CancelSubscriptionContext. -func (mr *MockRelaySenderMockRecorder) CancelSubscriptionContext(subscriptionKey any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) CancelSubscriptionContext(subscriptionKey interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelSubscriptionContext", reflect.TypeOf((*MockRelaySender)(nil).CancelSubscriptionContext), subscriptionKey) } @@ -661,41 +700,39 @@ func (m *MockRelaySender) CreateDappKey(dappID, consumerIp string) string { } // CreateDappKey indicates an expected call of CreateDappKey. -func (mr *MockRelaySenderMockRecorder) CreateDappKey(dappID, consumerIp any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) CreateDappKey(dappID, consumerIp interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDappKey", reflect.TypeOf((*MockRelaySender)(nil).CreateDappKey), dappID, consumerIp) } // ParseRelay mocks base method. -func (m *MockRelaySender) ParseRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadata []types.Metadata) (ChainMessage, map[string]string, *types.RelayPrivateData, error) { +func (m *MockRelaySender) ParseRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadata []types.Metadata) (ProtocolMessage, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ParseRelay", ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) - ret0, _ := ret[0].(ChainMessage) - ret1, _ := ret[1].(map[string]string) - ret2, _ := ret[2].(*types.RelayPrivateData) - ret3, _ := ret[3].(error) - return ret0, ret1, ret2, ret3 + ret0, _ := ret[0].(ProtocolMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ParseRelay indicates an expected call of ParseRelay. -func (mr *MockRelaySenderMockRecorder) ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseRelay", reflect.TypeOf((*MockRelaySender)(nil).ParseRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) } // SendParsedRelay mocks base method. -func (m *MockRelaySender) SendParsedRelay(ctx context.Context, dappID, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *types.RelayPrivateData) (*common.RelayResult, error) { +func (m *MockRelaySender) SendParsedRelay(ctx context.Context, dappID, consumerIp string, analytics *metrics.RelayMetrics, protocolMessage ProtocolMessage) (*common.RelayResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendParsedRelay", ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + ret := m.ctrl.Call(m, "SendParsedRelay", ctx, dappID, consumerIp, analytics, protocolMessage) ret0, _ := ret[0].(*common.RelayResult) ret1, _ := ret[1].(error) return ret0, ret1 } // SendParsedRelay indicates an expected call of SendParsedRelay. 
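This pair of changes is the heart of the refactor: ParseRelay used to return a (ChainMessage, directiveHeaders, relayRequestData) triple that every caller had to thread through to SendParsedRelay, and both methods now exchange a single ProtocolMessage that bundles the three. A sketch of the new call shape against a RelaySender, with names taken from this patch and error handling trimmed:

    protocolMessage, err := relaySender.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata)
    if err != nil {
        return nil, err
    }

    // the formerly separate pieces are still reachable through the interface
    headers := protocolMessage.GetDirectiveHeaders()
    privateData := protocolMessage.RelayPrivateData()
    _, _ = headers, privateData

    relayResult, err := relaySender.SendParsedRelay(ctx, dappID, consumerIp, analytics, protocolMessage)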
-func (mr *MockRelaySenderMockRecorder) SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SendParsedRelay(ctx, dappID, consumerIp, analytics, protocolMessage interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParsedRelay", reflect.TypeOf((*MockRelaySender)(nil).SendParsedRelay), ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParsedRelay", reflect.TypeOf((*MockRelaySender)(nil).SendParsedRelay), ctx, dappID, consumerIp, analytics, protocolMessage) } // SendRelay mocks base method. @@ -708,7 +745,7 @@ func (m *MockRelaySender) SendRelay(ctx context.Context, url, req, connectionTyp } // SendRelay indicates an expected call of SendRelay. -func (mr *MockRelaySenderMockRecorder) SendRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SendRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRelay", reflect.TypeOf((*MockRelaySender)(nil).SendRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues) } @@ -720,7 +757,7 @@ func (m *MockRelaySender) SetConsistencySeenBlock(blockSeen int64, key string) { } // SetConsistencySeenBlock indicates an expected call of SetConsistencySeenBlock. -func (mr *MockRelaySenderMockRecorder) SetConsistencySeenBlock(blockSeen, key any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SetConsistencySeenBlock(blockSeen, key interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetConsistencySeenBlock", reflect.TypeOf((*MockRelaySender)(nil).SetConsistencySeenBlock), blockSeen, key) } @@ -748,6 +785,20 @@ func (m *MockChainListener) EXPECT() *MockChainListenerMockRecorder { return m.recorder } +// GetListeningAddress mocks base method. +func (m *MockChainListener) GetListeningAddress() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetListeningAddress") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetListeningAddress indicates an expected call of GetListeningAddress. +func (mr *MockChainListenerMockRecorder) GetListeningAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetListeningAddress", reflect.TypeOf((*MockChainListener)(nil).GetListeningAddress)) +} + // Serve mocks base method. func (m *MockChainListener) Serve(ctx context.Context, cmdFlags common.ConsumerCmdFlags) { m.ctrl.T.Helper() @@ -755,7 +806,7 @@ func (m *MockChainListener) Serve(ctx context.Context, cmdFlags common.ConsumerC } // Serve indicates an expected call of Serve. -func (mr *MockChainListenerMockRecorder) Serve(ctx, cmdFlags any) *gomock.Call { +func (mr *MockChainListenerMockRecorder) Serve(ctx, cmdFlags interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockChainListener)(nil).Serve), ctx, cmdFlags) } @@ -792,13 +843,13 @@ func (m *MockChainRouter) ExtensionsSupported(arg0 []string) bool { } // ExtensionsSupported indicates an expected call of ExtensionsSupported. 
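GetListeningAddress is a new ChainListener method; presumably it lets callers (integration tests in particular) bind to an ephemeral port and then discover the address actually in use, but that purpose is an assumption. Stubbing it is a one-liner:

    listener := NewMockChainListener(ctrl)
    listener.EXPECT().GetListeningAddress().Return("127.0.0.1:3360").AnyTimes()

    addr := listener.GetListeningAddress() // "127.0.0.1:3360", per the stub above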
-func (mr *MockChainRouterMockRecorder) ExtensionsSupported(arg0 any) *gomock.Call { +func (mr *MockChainRouterMockRecorder) ExtensionsSupported(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtensionsSupported", reflect.TypeOf((*MockChainRouter)(nil).ExtensionsSupported), arg0) } // SendNodeMsg mocks base method. -func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend, extensions []string) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, common.NodeUrl, string, error) { +func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend, extensions []string) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, common.NodeUrl, string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage, extensions) ret0, _ := ret[0].(*RelayReplyWrapper) @@ -811,7 +862,7 @@ func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan any, chainMes } // SendNodeMsg indicates an expected call of SendNodeMsg. -func (mr *MockChainRouterMockRecorder) SendNodeMsg(ctx, ch, chainMessage, extensions any) *gomock.Call { +func (mr *MockChainRouterMockRecorder) SendNodeMsg(ctx, ch, chainMessage, extensions interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainRouter)(nil).SendNodeMsg), ctx, ch, chainMessage, extensions) } @@ -855,7 +906,7 @@ func (mr *MockChainProxyMockRecorder) GetChainProxyInformation() *gomock.Call { } // SendNodeMsg mocks base method. -func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, error) { +func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage) ret0, _ := ret[0].(*RelayReplyWrapper) @@ -866,7 +917,7 @@ func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan any, chainMess } // SendNodeMsg indicates an expected call of SendNodeMsg. 
-func (mr *MockChainProxyMockRecorder) SendNodeMsg(ctx, ch, chainMessage any) *gomock.Call { +func (mr *MockChainProxyMockRecorder) SendNodeMsg(ctx, ch, chainMessage interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainProxy)(nil).SendNodeMsg), ctx, ch, chainMessage) } diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index 3017328db4..75dd3ca1ba 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -149,7 +149,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { metricsData := metrics.NewRelayAnalytics(dappID, cwm.chainId, cwm.apiInterface) - chainMessage, directiveHeaders, relayRequestData, err := cwm.relaySender.ParseRelay(webSocketCtx, "", string(msg), cwm.connectionType, dappID, userIp, metricsData, nil) + protocolMessage, err := cwm.relaySender.ParseRelay(webSocketCtx, "", string(msg), cwm.connectionType, dappID, userIp, metricsData, nil) if err != nil { formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not parse message", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) if formatterMsg != nil { @@ -159,9 +159,9 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } // check whether its a normal relay / unsubscribe / unsubscribe_all otherwise its a subscription flow. - if !IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { - if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { - err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + if !IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { + err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx)) if err == common.SubscriptionNotFoundError { @@ -174,7 +174,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } } continue - } else if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { + } else if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { err := cwm.consumerWsSubscriptionManager.UnsubscribeAll(webSocketCtx, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("error unsubscribing from all subscription", err, utils.LogAttr("GUID", webSocketCtx)) @@ -182,7 +182,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { continue } else { // Normal relay over websocket. 
(not subscription related) - relayResult, err := cwm.relaySender.SendParsedRelay(webSocketCtx, dappID, userIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + relayResult, err := cwm.relaySender.SendParsedRelay(webSocketCtx, dappID, userIp, metricsData, protocolMessage) if err != nil { formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not send parsed relay", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) if formatterMsg != nil { @@ -202,16 +202,16 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } // Subscription flow - inputFormatter, outputFormatter := formatter.FormatterForRelayRequestAndResponse(relayRequestData.ApiInterface) // we use this to preserve the original jsonrpc id - inputFormatter(relayRequestData.Data) // set the extracted jsonrpc id + inputFormatter, outputFormatter := formatter.FormatterForRelayRequestAndResponse(protocolMessage.GetApiCollection().CollectionData.ApiInterface) // we use this to preserve the original jsonrpc id + inputFormatter(protocolMessage.RelayPrivateData().Data) // set the extracted jsonrpc id - reply, subscriptionMsgsChan, err := cwm.consumerWsSubscriptionManager.StartSubscription(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + reply, subscriptionMsgsChan, err := cwm.consumerWsSubscriptionManager.StartSubscription(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("StartSubscription returned an error", err, utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not start subscription", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) @@ -239,7 +239,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) for subscriptionMsgReply := range subscriptionMsgsChan { @@ -250,7 +250,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) }() } diff --git a/protocol/chainlib/consumer_ws_subscription_manager.go b/protocol/chainlib/consumer_ws_subscription_manager.go index 6b993588cd..dda1405573 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager.go +++ b/protocol/chainlib/consumer_ws_subscription_manager.go @@ -19,9 +19,7 @@ import ( ) type unsubscribeRelayData struct { - chainMessage ChainMessage - directiveHeaders map[string]string - relayRequestData *pairingtypes.RelayPrivateData + protocolMessage ProtocolMessage } type activeSubscriptionHolder struct { @@ -186,15 +184,13 @@ func (cwsm *ConsumerWSSubscriptionManager) checkForActiveSubscriptionWithLock( func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( webSocketCtx context.Context, - 
chainMessage ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage ProtocolMessage, dappID string, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics, ) (firstReply *pairingtypes.RelayReply, repliesChan <-chan *pairingtypes.RelayReply, err error) { - hashedParams, _, err := cwsm.getHashedParams(chainMessage) + hashedParams, _, err := cwsm.getHashedParams(protocolMessage) if err != nil { return nil, nil, utils.LavaFormatError("could not marshal params", err) } @@ -229,7 +225,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( <-closeWebsocketRepliesChan utils.LavaFormatTrace("requested to close websocketRepliesChan", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -242,7 +238,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( <-webSocketCtx.Done() utils.LavaFormatTrace("websocket context is done, removing websocket from active subscriptions", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -258,7 +254,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( }() // Validated there are no active subscriptions that we can use. - firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) + firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) if firstSubscriptionReply != nil { if returnWebsocketRepliesChan { return firstSubscriptionReply, websocketRepliesChan, nil @@ -279,7 +275,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("Finished pending for subscription, have results", utils.LogAttr("success", res)) // Check res is valid, if not fall through logs and try again with a new client. 
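The surrounding flow is a double-checked join-or-create: look for an active subscription under the lock, otherwise wait on a pending creation, then re-check before finally creating a new relay. A condensed, purely illustrative sketch of that control flow; checkActiveUnderLock, waitForPending, and createSubscriptionRelay are hypothetical stand-ins for checkForActiveSubscriptionWithLock and the pending logic above:

    for {
        if reply, ok := checkActiveUnderLock(hashedParams); ok {
            return reply, nil // joined an existing subscription
        }
        if !waitForPending(hashedParams) {
            break // nobody else is creating one; this caller will
        }
        // a pending creation just finished: loop and re-check its result
    }
    return createSubscriptionRelay(hashedParams) // the SendParsedRelay path below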
if res { - firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) + firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) if firstSubscriptionReply != nil { if returnWebsocketRepliesChan { return firstSubscriptionReply, websocketRepliesChan, nil @@ -300,12 +296,12 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("could not find active subscription for given params, creating new one", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) - relayResult, err := cwsm.relaySender.SendParsedRelay(webSocketCtx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + relayResult, err := cwsm.relaySender.SendParsedRelay(webSocketCtx, dappID, consumerIp, metricsData, protocolMessage) if err != nil { onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not send subscription relay", err) @@ -313,7 +309,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("got relay result from SendParsedRelay", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("relayResult", relayResult), @@ -325,7 +321,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("reply server is nil, probably an error with the subscription initiation", nil, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -336,7 +332,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("Reply data is nil", nil, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -348,18 +344,18 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not copy relay request", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) } - err = cwsm.verifySubscriptionMessage(hashedParams, chainMessage, relayResult.Request, &reply, relayResult.ProviderInfo.ProviderAddress) + err = cwsm.verifySubscriptionMessage(hashedParams, protocolMessage, relayResult.Request, &reply, 
relayResult.ProviderInfo.ProviderAddress) if err != nil { onSubscriptionFailure() return nil, nil, utils.LavaFormatError("Failed VerifyRelayReply on subscription message", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("reply", string(reply.Data)), @@ -373,7 +369,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not parse reply into json", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("reply", reply.Data), @@ -391,7 +387,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( cwsm.lock.Lock() defer cwsm.lock.Unlock() - subscriptionId := chainMessage.SubscriptionIdExtractor(&replyJsonrpcMessage) + subscriptionId := protocolMessage.SubscriptionIdExtractor(&replyJsonrpcMessage) subscriptionId = common.UnSquareBracket(subscriptionId) if common.IsQuoted(subscriptionId) { subscriptionId, _ = strconv.Unquote(subscriptionId) } @@ -404,7 +400,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( firstSubscriptionReplyAsJsonrpcMessage: &replyJsonrpcMessage, replyServer: replyServer, subscriptionOriginalRequest: copiedRequest, - subscriptionOriginalRequestChainMessage: chainMessage, + subscriptionOriginalRequestChainMessage: protocolMessage, closeSubscriptionChan: closeSubscriptionChan, connectedDappKeys: map[string]struct{}{dappKey: {}}, subscriptionId: subscriptionId, @@ -458,9 +454,7 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( // we run the unsubscribe flow in an inner function so it won't prevent us from removing the activeSubscriptions at the end.
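unsubscribeRelayData shrank from three fields to one for the same reason as the interfaces above: the ProtocolMessage already carries the directive headers and private data. Constructing it for a user-initiated unsubscribe reduces to wrapping whatever was already parsed, as a sketch using the types from this patch:

    // the three former fields now travel inside one ProtocolMessage
    protocolMessage := NewProtocolMessage(chainMessage, directiveHeaders, relayRequestData)
    unsubData := &unsubscribeRelayData{protocolMessage}
    _ = unsubData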
func() { var err error - var chainMessage ChainMessage - var directiveHeaders map[string]string - var relayRequestData *pairingtypes.RelayPrivateData + var protocolMessage ProtocolMessage if unsubscribeData != nil { // This unsubscribe request was initiated by the user utils.LavaFormatTrace("unsubscribe request was made by the user", @@ -468,9 +462,7 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) - chainMessage = unsubscribeData.chainMessage - directiveHeaders = unsubscribeData.directiveHeaders - relayRequestData = unsubscribeData.relayRequestData + protocolMessage = unsubscribeData.protocolMessage } else { // This unsubscribe request was initiated by us utils.LavaFormatTrace("unsubscribe request was made automatically", @@ -478,13 +470,13 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) - chainMessage, directiveHeaders, relayRequestData, err = cwsm.craftUnsubscribeMessage(hashedParams, dappID, userIp, metricsData) + protocolMessage, err = cwsm.craftUnsubscribeMessage(hashedParams, dappID, userIp, metricsData) if err != nil { utils.LavaFormatError("could not craft unsubscribe message", err, utils.LogAttr("GUID", webSocketCtx)) return } - stringJson, err := gojson.Marshal(chainMessage.GetRPCMessage()) + stringJson, err := gojson.Marshal(protocolMessage.GetRPCMessage()) if err != nil { utils.LavaFormatError("could not marshal chain message", err, utils.LogAttr("GUID", webSocketCtx)) return @@ -498,16 +490,16 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( } unsubscribeRelayCtx := utils.WithUniqueIdentifier(context.Background(), utils.GenerateUniqueIdentifier()) - err = cwsm.sendUnsubscribeMessage(unsubscribeRelayCtx, dappID, userIp, chainMessage, directiveHeaders, relayRequestData, metricsData) + err = cwsm.sendUnsubscribeMessage(unsubscribeRelayCtx, dappID, userIp, protocolMessage, metricsData) if err != nil { utils.LavaFormatError("could not send unsubscribe message due to a relay error", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("relayRequestData", relayRequestData), + utils.LogAttr("relayRequestData", protocolMessage.RelayPrivateData()), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("api", chainMessage.GetApi().Name), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("api", protocolMessage.GetApi().Name), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) } else { utils.LavaFormatTrace("success sending unsubscribe message, deleting hashed params from activeSubscriptions", @@ -645,7 +637,7 @@ func (cwsm *ConsumerWSSubscriptionManager) getHashedParams(chainMessage ChainMes return hashedParams, params, nil } -func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Context, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { +func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Context, protocolMessage ProtocolMessage, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { utils.LavaFormatTrace("want to unsubscribe", utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), @@ -657,16 +649,16 @@ func 
(cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Cont cwsm.lock.Lock() defer cwsm.lock.Unlock() - hashedParams, err := cwsm.findActiveSubscriptionHashedParamsFromChainMessage(chainMessage) + hashedParams, err := cwsm.findActiveSubscriptionHashedParamsFromChainMessage(protocolMessage) if err != nil { return err } return cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, func() (*unsubscribeRelayData, error) { - return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + return &unsubscribeRelayData{protocolMessage}, nil }) } -func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, dappID, consumerIp string, metricsData *metrics.RelayMetrics) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error) { +func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, dappID, consumerIp string, metricsData *metrics.RelayMetrics) (ProtocolMessage, error) { request := cwsm.activeSubscriptions[hashedParams].subscriptionOriginalRequestChainMessage subscriptionId := cwsm.activeSubscriptions[hashedParams].subscriptionId @@ -682,14 +674,14 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, } if !found { - return nil, nil, nil, utils.LavaFormatError("could not find unsubscribe parse directive for given chain message", nil, + return nil, utils.LavaFormatError("could not find unsubscribe parse directive for given chain message", nil, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), ) } if unsubscribeRequestData == "" { - return nil, nil, nil, utils.LavaFormatError("unsubscribe request data is empty", nil, + return nil, utils.LavaFormatError("unsubscribe request data is empty", nil, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), ) @@ -697,9 +689,9 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, // Craft the unsubscribe chain message ctx := context.Background() - chainMessage, directiveHeaders, relayRequestData, err := cwsm.relaySender.ParseRelay(ctx, "", unsubscribeRequestData, cwsm.connectionType, dappID, consumerIp, metricsData, nil) + protocolMessage, err := cwsm.relaySender.ParseRelay(ctx, "", unsubscribeRequestData, cwsm.connectionType, dappID, consumerIp, metricsData, nil) if err != nil { - return nil, nil, nil, utils.LavaFormatError("could not craft unsubscribe chain message", err, + return nil, utils.LavaFormatError("could not craft unsubscribe chain message", err, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), utils.LogAttr("unsubscribeRequestData", unsubscribeRequestData), @@ -707,10 +699,10 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, ) } - return chainMessage, directiveHeaders, relayRequestData, nil + return protocolMessage, nil } -func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Context, dappID, consumerIp string, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, metricsData *metrics.RelayMetrics) error { +func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Context, dappID, consumerIp string, protocolMessage ProtocolMessage, metricsData *metrics.RelayMetrics) error { // Send the crafted unsubscribe relay utils.LavaFormatTrace("sending unsubscribe 
relay", utils.LogAttr("GUID", ctx), @@ -718,7 +710,7 @@ func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Co utils.LogAttr("consumerIp", consumerIp), ) - _, err := cwsm.relaySender.SendParsedRelay(ctx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + _, err := cwsm.relaySender.SendParsedRelay(ctx, dappID, consumerIp, metricsData, protocolMessage) if err != nil { return utils.LavaFormatError("could not send unsubscribe relay", err) } @@ -775,12 +767,12 @@ func (cwsm *ConsumerWSSubscriptionManager) UnsubscribeAll(webSocketCtx context.C ) unsubscribeRelayGetter := func() (*unsubscribeRelayData, error) { - chainMessage, directiveHeaders, relayRequestData, err := cwsm.craftUnsubscribeMessage(hashedParams, dappID, consumerIp, metricsData) + protocolMessage, err := cwsm.craftUnsubscribeMessage(hashedParams, dappID, consumerIp, metricsData) if err != nil { return nil, err } - return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + return &unsubscribeRelayData{protocolMessage}, nil } cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, unsubscribeRelayGetter) diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go index ed79e1bea5..81a59fea87 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager_test.go +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + gomock "github.com/golang/mock/gomock" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavaprotocol" @@ -20,7 +21,7 @@ import ( spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - gomock "go.uber.org/mock/gomock" + gomockuber "go.uber.org/mock/gomock" ) const ( @@ -66,7 +67,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + protocolMessage1 := NewProtocolMessage(chainMessage1, nil, nil) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). @@ -83,7 +84,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes relaySender. EXPECT(). ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(chainMessage1, nil, nil, nil). + Return(protocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -128,7 +129,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). 
Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -139,6 +140,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes uniqueIdentifiers := make([]string, numberOfParallelSubscriptions) wg := sync.WaitGroup{} wg.Add(numberOfParallelSubscriptions) + // Start a new subscription for the first time, called SendParsedRelay once while in parallel calling 10 times subscribe with the same message // expected result is to have SendParsedRelay only once and 9 other messages waiting the broadcast. for i := 0; i < numberOfParallelSubscriptions; i++ { @@ -148,7 +150,8 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[index], nil) + + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], nil) go func() { for subMsg := range repliesChan { // utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) @@ -166,7 +169,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes // now we have numberOfParallelSubscriptions subscriptions currently running require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions) // remove one - err = manager.Unsubscribe(ts.Ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[0], nil) + err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], nil) require.NoError(t, err) // now we have numberOfParallelSubscriptions - 1 require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-1) @@ -221,7 +224,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + protocolMessage1 := NewProtocolMessage(chainMessage1, nil, nil) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). @@ -238,7 +241,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { relaySender. EXPECT(). ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(chainMessage1, nil, nil, nil). + Return(protocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -283,7 +286,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). 
Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -302,7 +305,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) go func() { for subMsg := range repliesChan { require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) @@ -427,16 +430,21 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { subscribeChainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) + subscribeProtocolMessage1 := NewProtocolMessage(subscribeChainMessage1, nil, nil) + unsubscribeProtocolMessage1 := NewProtocolMessage(unsubscribeChainMessage1, nil, &pairingtypes.RelayPrivateData{ + Data: play.unsubscribeMessage1, + }) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { - relayPrivateData, ok := x.(*pairingtypes.RelayPrivateData) - if !ok || relayPrivateData == nil { + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { + protocolMsg, ok := x.(ProtocolMessage) + require.True(t, ok) + require.NotNil(t, protocolMsg) + if protocolMsg.RelayPrivateData() == nil { return false } - - if strings.Contains(string(relayPrivateData.Data), "unsubscribe") { + if strings.Contains(string(protocolMsg.RelayPrivateData().Data), "unsubscribe") { unsubscribeMessageWg.Done() } @@ -460,26 +468,24 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.unsubscribeMessage1) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(unsubscribeChainMessage1, nil, &pairingtypes.RelayPrivateData{ - Data: play.unsubscribeMessage1, - }, nil). + Return(unsubscribeProtocolMessage1, nil). AnyTimes() relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.subscriptionRequestData1) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subscribeChainMessage1, nil, nil, nil). + Return(subscribeProtocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -524,7 +530,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). 
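Note the two import aliases at work here: the regenerated mocks take matchers from github.com/golang/mock/gomock, while go.uber.org/mock (aliased gomockuber) is kept purely for its Cond matcher, which the golang fork lacks. The two forks' Matcher interfaces are structurally identical, so they can be mixed within one expectation:

    import (
        gomock "github.com/golang/mock/gomock"
        gomockuber "go.uber.org/mock/gomock"
    )

    relaySender.EXPECT().
        ParseRelay(gomock.Any(), gomock.Any(),
            gomockuber.Cond(func(x any) bool { // match only the unsubscribe payload
                reqData, ok := x.(string)
                return ok && reqData == string(play.unsubscribeMessage1)
            }),
            gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
        Return(unsubscribeProtocolMessage1, nil).
        AnyTimes()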
Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -535,7 +541,8 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a new subscription for the first time, called SendParsedRelay once ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + + firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) @@ -545,13 +552,13 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(0) // Should not call SendParsedRelay, because it is already subscribed // Start a subscription again, same params, same dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.Nil(t, repliesChan2) // Same subscription, same dappKey, no need for a new channel @@ -560,7 +567,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a subscription again, same params, different dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.NotNil(t, repliesChan3) // Same subscription, but different dappKey, so will create new channel @@ -571,32 +578,30 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for the next subscription unsubscribeChainMessage2, err := chainParser.ParseMsg("", play.unsubscribeMessage2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + unsubscribeProtocolMessage2 := NewProtocolMessage(unsubscribeChainMessage2, nil, &pairingtypes.RelayPrivateData{Data: play.unsubscribeMessage2}) subscribeChainMessage2, err := chainParser.ParseMsg("", play.subscriptionRequestData2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + subscribeProtocolMessage2 := NewProtocolMessage(subscribeChainMessage2, nil, nil) relaySender. EXPECT(). 
- ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.unsubscribeMessage2) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(unsubscribeChainMessage2, nil, &pairingtypes.RelayPrivateData{ - Data: play.unsubscribeMessage2, - }, nil). + Return(unsubscribeProtocolMessage2, nil). AnyTimes() relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.subscriptionRequestData2) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subscribeChainMessage2, nil, nil, nil). + Return(subscribeProtocolMessage2, nil). AnyTimes() mockRelayerClient2 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -639,13 +644,14 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult2, nil). Times(1) // Should call SendParsedRelay, because it is the first time we subscribe // Start a subscription again, different params, same dappKey, should call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeChainMessage2, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + + firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply2), string(firstReply.Data)) @@ -658,12 +664,13 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for unsubscribe from the first subscription relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(0) // Should call SendParsedRelay, because it unsubscribed ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - err = manager.Unsubscribe(ctx, unsubscribeChainMessage1, nil, relayResult1.Request.RelayData, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + unsubProtocolMessage := NewProtocolMessage(unsubscribeChainMessage1, nil, relayResult1.Request.RelayData) + err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) require.NoError(t, err) listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) @@ -681,8 +688,8 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for unsubscribe from the second subscription relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
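One detail worth isolating from the unsubscribe assertion above: the test rebuilds the protocol message from the unsubscribe chain message plus the RelayData of the original subscription request, presumably so the unsubscribe relay carries the same private data the subscription was created with. The shape, on its own:

    // relayResult1 is the result of the original subscribe relay
    unsubProtocolMessage := NewProtocolMessage(unsubscribeChainMessage1, nil, relayResult1.Request.RelayData)
    err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, nil)
    require.NoError(t, err)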
- DoAndReturn(func(ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData) (relayResult *common.RelayResult, errRet error) { + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, protocolMessage ProtocolMessage) (relayResult *common.RelayResult, errRet error) { wg.Done() return relayResult2, nil }). diff --git a/protocol/chainlib/protocol_message.go b/protocol/chainlib/protocol_message.go new file mode 100644 index 0000000000..c9ed2ea01d --- /dev/null +++ b/protocol/chainlib/protocol_message.go @@ -0,0 +1,53 @@ +package chainlib + +import ( + "strings" + + "github.com/lavanet/lava/v2/protocol/common" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" +) + +type BaseProtocolMessage struct { + ChainMessage + directiveHeaders map[string]string + relayRequestData *pairingtypes.RelayPrivateData +} + +func (bpm *BaseProtocolMessage) GetDirectiveHeaders() map[string]string { + return bpm.directiveHeaders +} + +func (bpm *BaseProtocolMessage) RelayPrivateData() *pairingtypes.RelayPrivateData { + return bpm.relayRequestData +} + +func (bpm *BaseProtocolMessage) HashCacheRequest(chainId string) ([]byte, func([]byte) []byte, error) { + return HashCacheRequest(bpm.relayRequestData, chainId) +} + +func (bpm *BaseProtocolMessage) GetBlockedProviders() []string { + if bpm.directiveHeaders == nil { + return nil + } + blockedProviders, ok := bpm.directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] + if ok { + return strings.Split(blockedProviders, ",") + } + return nil +} + +func NewProtocolMessage(chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData) ProtocolMessage { + return &BaseProtocolMessage{ + ChainMessage: chainMessage, + directiveHeaders: directiveHeaders, + relayRequestData: relayRequestData, + } +} + +type ProtocolMessage interface { + ChainMessage + GetDirectiveHeaders() map[string]string + RelayPrivateData() *pairingtypes.RelayPrivateData + HashCacheRequest(chainId string) ([]byte, func([]byte) []byte, error) + GetBlockedProviders() []string +} diff --git a/protocol/lavasession/consumer_session_manager_test.go b/protocol/lavasession/consumer_session_manager_test.go index 9bdf1b7fcf..ad44b5010e 100644 --- a/protocol/lavasession/consumer_session_manager_test.go +++ b/protocol/lavasession/consumer_session_manager_test.go @@ -8,6 +8,7 @@ import ( "net" "os" "strconv" + "strings" "testing" "time" @@ -221,6 +222,21 @@ func createGRPCServer(changeListener string, probeDelay time.Duration) error { const providerStr = "provider" +type DirectiveHeaders struct { + directiveHeaders map[string]string +} + +func (bpm DirectiveHeaders) GetBlockedProviders() []string { + if bpm.directiveHeaders == nil { + return nil + } + blockedProviders, ok := bpm.directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] + if ok { + return strings.Split(blockedProviders, ",") + } + return nil +} + func createPairingList(providerPrefixAddress string, enabled bool) map[uint64]*ConsumerSessionsWithProvider { cswpList := make(map[uint64]*ConsumerSessionsWithProvider, 0) pairingEndpoints := make([]*Endpoint, 1) @@ -322,7 +338,9 @@ func TestSecondChanceRecoveryFlow(t *testing.T) { timeLimit := time.Second * 30 loopStartTime := time.Now() for { - usedProviders := 
NewUsedProviders(map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}) + // implement a struct that returns: map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress} in the implementation for the DirectiveHeadersInf interface + directiveHeaders := DirectiveHeaders{map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}} + usedProviders := NewUsedProviders(directiveHeaders) css, err := csm.GetSessions(ctx, cuForFirstRequest, usedProviders, servicedBlockNumber, "", nil, common.NO_STATE, 0) // get a session require.NoError(t, err) _, expectedProviderAddress := css[pairingList[0].PublicLavaAddress] @@ -372,7 +390,8 @@ func TestSecondChanceRecoveryFlow(t *testing.T) { loopStartTime = time.Now() for { utils.LavaFormatDebug("Test", utils.LogAttr("csm.validAddresses", csm.validAddresses), utils.LogAttr("csm.currentlyBlockedProviderAddresses", csm.currentlyBlockedProviderAddresses), utils.LogAttr("csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus", csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus)) - usedProviders := NewUsedProviders(map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}) + directiveHeaders := DirectiveHeaders{map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}} + usedProviders := NewUsedProviders(directiveHeaders) require.Equal(t, BlockedProviderSessionUnusedStatus, csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus) css, err := csm.GetSessions(ctx, cuForFirstRequest, usedProviders, servicedBlockNumber, "", nil, common.NO_STATE, 0) // get a session require.Equal(t, BlockedProviderSessionUnusedStatus, csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus) diff --git a/protocol/lavasession/used_providers.go b/protocol/lavasession/used_providers.go index 854e4823ac..b1d72de953 100644 --- a/protocol/lavasession/used_providers.go +++ b/protocol/lavasession/used_providers.go @@ -2,22 +2,23 @@ package lavasession import ( "context" - "strings" "sync" "time" - "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/utils" ) const MaximumNumberOfSelectionLockAttempts = 500 -func NewUsedProviders(directiveHeaders map[string]string) *UsedProviders { +type BlockedProvidersInf interface { + GetBlockedProviders() []string +} + +func NewUsedProviders(blockedProviders BlockedProvidersInf) *UsedProviders { unwantedProviders := map[string]struct{}{} - if len(directiveHeaders) > 0 { - blockedProviders, ok := directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] - if ok { - providerAddressesToBlock := strings.Split(blockedProviders, ",") + if blockedProviders != nil { + providerAddressesToBlock := blockedProviders.GetBlockedProviders() + if len(providerAddressesToBlock) > 0 { for _, providerAddress := range providerAddressesToBlock { unwantedProviders[providerAddress] = struct{}{} } diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index f00af0206f..595038b5a7 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -223,16 +223,16 @@ func (rpccs *RPCConsumerServer) craftRelay(ctx context.Context) (ok bool, relay return } -func (rpccs *RPCConsumerServer) sendRelayWithRetries(ctx context.Context, retries int, initialRelays bool, relay *pairingtypes.RelayPrivateData, chainMessage chainlib.ChainMessage) (bool, error) { +func 
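NewUsedProviders no longer parses directive headers itself: it accepts any value implementing BlockedProvidersInf (the test comment above says DirectiveHeadersInf, but BlockedProvidersInf is the name this hunk introduces), and ProtocolMessage satisfies it through GetBlockedProviders, which splits the lava-providers-block header. A sketch of the call sites, assuming it runs inside the lavasession test package where the DirectiveHeaders helper is defined:

    // a test helper carrying raw directive headers
    dh := DirectiveHeaders{map[string]string{"lava-providers-block": "lava@prov1,lava@prov2"}} // addresses are illustrative
    usedProviders := NewUsedProviders(dh) // both providers start out unwanted

    // any GetBlockedProviders() []string implementor works, e.g. a chainlib.ProtocolMessage;
    // passing nil blocks nothing up front
    usedProviders = NewUsedProviders(nil)
    _ = usedProviders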
(rpccs *RPCConsumerServer) sendRelayWithRetries(ctx context.Context, retries int, initialRelays bool, protocolMessage chainlib.ProtocolMessage) (bool, error) { success := false var err error - relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(nil), 1, chainMessage, rpccs.consumerConsistency, "-init-", "", rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) + relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(nil), 1, protocolMessage, rpccs.consumerConsistency, "-init-", "", rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) for i := 0; i < retries; i++ { - err = rpccs.sendRelayToProvider(ctx, chainMessage, relay, "-init-", "", relayProcessor, nil) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, "-init-", "", relayProcessor, nil) if lavasession.PairingListEmptyError.Is(err) { // we don't have pairings anymore, could be related to unwanted providers relayProcessor.GetUsedProviders().ClearUnwanted() - err = rpccs.sendRelayToProvider(ctx, chainMessage, relay, "-init-", "", relayProcessor, nil) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, "-init-", "", relayProcessor, nil) } if err != nil { utils.LavaFormatError("[-] failed sending init relay", err, []utils.Attribute{{Key: "chainID", Value: rpccs.listenEndpoint.ChainID}, {Key: "APIInterface", Value: rpccs.listenEndpoint.ApiInterface}, {Key: "relayProcessor", Value: relayProcessor}}...) @@ -285,8 +285,8 @@ func (rpccs *RPCConsumerServer) sendCraftedRelays(retries int, initialRelays boo } return false, err } - - return rpccs.sendRelayWithRetries(ctx, retries, initialRelays, relay, chainMessage) + protocolMessage := chainlib.NewProtocolMessage(chainMessage, nil, relay) + return rpccs.sendRelayWithRetries(ctx, retries, initialRelays, protocolMessage) } func (rpccs *RPCConsumerServer) getLatestBlock() uint64 { @@ -308,12 +308,12 @@ func (rpccs *RPCConsumerServer) SendRelay( analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, ) (relayResult *common.RelayResult, errRet error) { - chainMessage, directiveHeaders, relayRequestData, err := rpccs.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) + protocolMessage, err := rpccs.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) if err != nil { return nil, err } - return rpccs.SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + return rpccs.SendParsedRelay(ctx, dappID, consumerIp, analytics, protocolMessage) } func (rpccs *RPCConsumerServer) ParseRelay( @@ -325,16 +325,16 @@ func (rpccs *RPCConsumerServer) ParseRelay( consumerIp string, analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, -) (chainMessage chainlib.ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, err error) { +) (protocolMessage chainlib.ProtocolMessage, err error) { // gets the relay request data from the ChainListener // parses the request into an APIMessage, and validating it corresponds to the spec currently in use // construct the common data for a relay message, common data is identical across multiple sends and data reliability // remove lava directive headers - metadata, directiveHeaders = rpccs.LavaDirectiveHeaders(metadata) - chainMessage, err = rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) + 
metadata, directiveHeaders := rpccs.LavaDirectiveHeaders(metadata) + chainMessage, err := rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) if err != nil { - return nil, nil, nil, err + return nil, err } rpccs.HandleDirectiveHeadersForMessage(chainMessage, directiveHeaders) @@ -346,8 +346,9 @@ func (rpccs *RPCConsumerServer) ParseRelay( seenBlock = 0 } - relayRequestData = lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) - return chainMessage, directiveHeaders, relayRequestData, nil + relayRequestData := lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) + protocolMessage = chainlib.NewProtocolMessage(chainMessage, directiveHeaders, relayRequestData) + return protocolMessage, nil } func (rpccs *RPCConsumerServer) SendParsedRelay( @@ -355,9 +356,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( dappID string, consumerIp string, analytics *metrics.RelayMetrics, - chainMessage chainlib.ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage chainlib.ProtocolMessage, ) (relayResult *common.RelayResult, errRet error) { // sends a relay message to a provider // compares the result with other providers if defined so @@ -365,7 +364,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( // asynchronously sends data reliability if necessary relaySentTime := time.Now() - relayProcessor, err := rpccs.ProcessRelaySend(ctx, directiveHeaders, chainMessage, relayRequestData, dappID, consumerIp, analytics) + relayProcessor, err := rpccs.ProcessRelaySend(ctx, protocolMessage, dappID, consumerIp, analytics) if err != nil && !relayProcessor.HasResults() { // we can't send anymore, and we don't have any responses utils.LavaFormatError("failed getting responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key()), utils.LogAttr("userIp", consumerIp), utils.LogAttr("relayProcessor", relayProcessor)) @@ -383,11 +382,11 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( if found { dataReliabilityContext = utils.WithUniqueIdentifier(dataReliabilityContext, guid) } - go rpccs.sendDataReliabilityRelayIfApplicable(dataReliabilityContext, dappID, consumerIp, chainMessage, dataReliabilityThreshold, relayProcessor) // runs asynchronously + go rpccs.sendDataReliabilityRelayIfApplicable(dataReliabilityContext, dappID, consumerIp, protocolMessage, dataReliabilityThreshold, relayProcessor) // runs asynchronously } returnedResult, err := relayProcessor.ProcessingResult() - rpccs.appendHeadersToRelayResult(ctx, returnedResult, relayProcessor.ProtocolErrors(), relayProcessor, directiveHeaders) + rpccs.appendHeadersToRelayResult(ctx, returnedResult, relayProcessor.ProtocolErrors(), relayProcessor, protocolMessage.GetDirectiveHeaders()) if err != nil { return returnedResult, utils.LavaFormatError("failed processing responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key())) } @@ -395,7 +394,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( if 
analytics != nil { currentLatency := time.Since(relaySentTime) analytics.Latency = currentLatency.Milliseconds() - api := chainMessage.GetApi() + api := protocolMessage.GetApi() analytics.ComputeUnits = api.ComputeUnits analytics.ApiMethod = api.Name } @@ -407,11 +406,11 @@ func (rpccs *RPCConsumerServer) GetChainIdAndApiInterface() (string, string) { return rpccs.listenEndpoint.ChainID, rpccs.listenEndpoint.ApiInterface } -func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveHeaders map[string]string, chainMessage chainlib.ChainMessage, relayRequestData *pairingtypes.RelayPrivateData, dappID string, consumerIp string, analytics *metrics.RelayMetrics) (*RelayProcessor, error) { +func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, protocolMessage chainlib.ProtocolMessage, dappID string, consumerIp string, analytics *metrics.RelayMetrics) (*RelayProcessor, error) { // make sure all of the child contexts are cancelled when we exit ctx, cancel := context.WithCancel(ctx) defer cancel() - relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(directiveHeaders), rpccs.requiredResponses, chainMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) + relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(protocolMessage), rpccs.requiredResponses, protocolMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) var err error // try sending a relay 3 times. if failed return the error for retryFirstRelayAttempt := 0; retryFirstRelayAttempt < SendRelayAttempts; retryFirstRelayAttempt++ { @@ -419,7 +418,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH if analytics != nil && retryFirstRelayAttempt > 0 { analytics = nil } - err = rpccs.sendRelayToProvider(ctx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, analytics) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, dappID, consumerIp, relayProcessor, analytics) // check if we had an error. if we did, try again. 
if err == nil { @@ -434,7 +433,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH // a channel to be notified processing was done, true means we have results and can return gotResults := make(chan bool) - processingTimeout, relayTimeout := rpccs.getProcessingTimeout(chainMessage) + processingTimeout, relayTimeout := rpccs.getProcessingTimeout(protocolMessage) if rpccs.debugRelays { utils.LavaFormatDebug("Relay initiated with the following timeout schedule", utils.LogAttr("processingTimeout", processingTimeout), utils.LogAttr("newRelayTimeout", relayTimeout)) } @@ -486,7 +485,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH return relayProcessor, nil } // otherwise continue sending another relay - err := rpccs.sendRelayToProvider(processingCtx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, nil) + err := rpccs.sendRelayToProvider(processingCtx, protocolMessage, dappID, consumerIp, relayProcessor, nil) go validateReturnCondition(err) go readResultsFromProcessor() // increase number of retries launched only if we still have pairing available, if we exhausted the list we don't want to break early @@ -499,7 +498,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH if relayProcessor.ShouldRetry(numberOfRetriesLaunched) { // limit the number of retries called from the new batch ticker flow. // if we pass the limit we just wait for the relays we sent to return. - err := rpccs.sendRelayToProvider(processingCtx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, nil) + err := rpccs.sendRelayToProvider(processingCtx, protocolMessage, dappID, consumerIp, relayProcessor, nil) go validateReturnCondition(err) // add ticker launch metrics go rpccs.rpcConsumerLogs.SetRelaySentByNewBatchTickerMetric(rpccs.GetChainIdAndApiInterface()) @@ -524,7 +523,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH utils.LogAttr("processingTimeout", processingTimeout), utils.LogAttr("dappId", dappID), utils.LogAttr("consumerIp", consumerIp), - utils.LogAttr("chainMessage.GetApi().Name", chainMessage.GetApi().Name), + utils.LogAttr("protocolMessage.GetApi().Name", protocolMessage.GetApi().Name), utils.LogAttr("GUID", ctx), utils.LogAttr("relayProcessor", relayProcessor), ) @@ -553,8 +552,7 @@ func (rpccs *RPCConsumerServer) CancelSubscriptionContext(subscriptionKey string func (rpccs *RPCConsumerServer) sendRelayToProvider( ctx context.Context, - chainMessage chainlib.ChainMessage, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage chainlib.ProtocolMessage, dappID string, consumerIp string, relayProcessor *RelayProcessor, @@ -581,27 +579,27 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( lavaChainID := rpccs.lavaChainID // Get Session. we get session here so we can use the epoch in the callbacks - reqBlock, _ := chainMessage.RequestedBlock() + reqBlock, _ := protocolMessage.RequestedBlock() // try using cache before sending relay var cacheError error if rpccs.cache.CacheActive() { // use cache only if its defined. 
- if !chainMessage.GetForceCacheRefresh() { // don't use cache if user specified + if !protocolMessage.GetForceCacheRefresh() { // don't use cache if user specified if reqBlock != spectypes.NOT_APPLICABLE { // don't use cache if requested block is not applicable var cacheReply *pairingtypes.CacheRelayReply - hashKey, outputFormatter, err := chainlib.HashCacheRequest(relayRequestData, chainId) + hashKey, outputFormatter, err := protocolMessage.HashCacheRequest(chainId) if err != nil { utils.LavaFormatError("sendRelayToProvider Failed getting Hash for cache request", err) } else { cacheCtx, cancel := context.WithTimeout(ctx, common.CacheTimeout) cacheReply, cacheError = rpccs.cache.GetEntry(cacheCtx, &pairingtypes.RelayCacheGet{ RequestHash: hashKey, - RequestedBlock: relayRequestData.RequestBlock, + RequestedBlock: reqBlock, ChainId: chainId, BlockHash: nil, Finalized: false, SharedStateId: sharedStateId, - SeenBlock: relayRequestData.SeenBlock, + SeenBlock: protocolMessage.RelayPrivateData().SeenBlock, }) // caching in the portal doesn't care about hashes, and we don't have data on finalization yet cancel() reply := cacheReply.GetReply() @@ -610,9 +608,9 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( cacheSeenBlock := cacheReply.GetSeenBlock() // check if the cache seen block is greater than my local seen block, this means the user requested this // request spoke with another consumer instance and use that block for inter consumer consistency. - if rpccs.sharedState && cacheSeenBlock > relayRequestData.SeenBlock { - utils.LavaFormatDebug("shared state seen block is newer", utils.LogAttr("cache_seen_block", cacheSeenBlock), utils.LogAttr("local_seen_block", relayRequestData.SeenBlock)) - relayRequestData.SeenBlock = cacheSeenBlock + if rpccs.sharedState && cacheSeenBlock > protocolMessage.RelayPrivateData().SeenBlock { + utils.LavaFormatDebug("shared state seen block is newer", utils.LogAttr("cache_seen_block", cacheSeenBlock), utils.LogAttr("local_seen_block", protocolMessage.RelayPrivateData().SeenBlock)) + protocolMessage.RelayPrivateData().SeenBlock = cacheSeenBlock // setting the fetched seen block from the cache server to our local cache as well. 
rpccs.consumerConsistency.SetSeenBlock(cacheSeenBlock, dappID, consumerIp) } @@ -625,7 +623,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( relayResult := common.RelayResult{ Reply: reply, Request: &pairingtypes.RelayRequest{ - RelayData: relayRequestData, + RelayData: protocolMessage.RelayPrivateData(), }, Finalized: false, // set false to skip data reliability StatusCode: 200, @@ -643,33 +641,33 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( } } } else { - utils.LavaFormatDebug("skipping cache due to requested block being NOT_APPLICABLE", utils.Attribute{Key: "api name", Value: chainMessage.GetApi().Name}) + utils.LavaFormatDebug("skipping cache due to requested block being NOT_APPLICABLE", utils.Attribute{Key: "api name", Value: protocolMessage.GetApi().Name}) } } } - if reqBlock == spectypes.LATEST_BLOCK && relayRequestData.SeenBlock != 0 { + if reqBlock == spectypes.LATEST_BLOCK && protocolMessage.RelayPrivateData().SeenBlock != 0 { // make optimizer select a provider that is likely to have the latest seen block - reqBlock = relayRequestData.SeenBlock + reqBlock = protocolMessage.RelayPrivateData().SeenBlock } // consumerEmergencyTracker always use latest virtual epoch virtualEpoch := rpccs.consumerTxSender.GetLatestVirtualEpoch() - addon := chainlib.GetAddon(chainMessage) - extensions := chainMessage.GetExtensions() + addon := chainlib.GetAddon(protocolMessage) + extensions := protocolMessage.GetExtensions() usedProviders := relayProcessor.GetUsedProviders() - sessions, err := rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(chainMessage), usedProviders, reqBlock, addon, extensions, chainlib.GetStateful(chainMessage), virtualEpoch) + sessions, err := rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(protocolMessage), usedProviders, reqBlock, addon, extensions, chainlib.GetStateful(protocolMessage), virtualEpoch) if err != nil { if lavasession.PairingListEmptyError.Is(err) { if addon != "" { return utils.LavaFormatError("No Providers For Addon", err, utils.LogAttr("addon", addon), utils.LogAttr("extensions", extensions), utils.LogAttr("userIp", consumerIp)) } else if len(extensions) > 0 && relayProcessor.GetAllowSessionDegradation() { // if we have no providers for that extension, use a regular provider, otherwise return the extension results - sessions, err = rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(chainMessage), usedProviders, reqBlock, addon, []*spectypes.Extension{}, chainlib.GetStateful(chainMessage), virtualEpoch) + sessions, err = rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(protocolMessage), usedProviders, reqBlock, addon, []*spectypes.Extension{}, chainlib.GetStateful(protocolMessage), virtualEpoch) if err != nil { return err } - relayProcessor.setSkipDataReliability(true) // disabling data reliability when disabling extensions. - relayRequestData.Extensions = []string{} // reset request data extensions - extensions = []*spectypes.Extension{} // reset extensions too so we wont hit SetDisallowDegradation + relayProcessor.setSkipDataReliability(true) // disabling data reliability when disabling extensions. 
+ protocolMessage.RelayPrivateData().Extensions = []string{} // reset request data extensions + extensions = []*spectypes.Extension{} // reset extensions too so we wont hit SetDisallowDegradation } else { return err } @@ -726,7 +724,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( goroutineCtxCancel() }() - localRelayRequestData := *relayRequestData + localRelayRequestData := *protocolMessage.RelayPrivateData() // Extract fields from the sessionInfo singleConsumerSession := sessionInfo.Session @@ -744,10 +742,10 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // set relay sent metric go rpccs.rpcConsumerLogs.SetRelaySentToProviderMetric(chainId, apiInterface) - if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + if chainlib.IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { utils.LavaFormatTrace("inside sendRelayToProvider, relay is subscription", utils.LogAttr("requestData", localRelayRequestData.Data)) - params, err := json.Marshal(chainMessage.GetRPCMessage().GetParams()) + params, err := json.Marshal(protocolMessage.GetRPCMessage().GetParams()) if err != nil { utils.LavaFormatError("could not marshal params", err) return @@ -781,7 +779,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // unique per dappId and ip consumerToken := common.GetUniqueToken(dappID, consumerIp) - processingTimeout, expectedRelayTimeoutForQOS := rpccs.getProcessingTimeout(chainMessage) + processingTimeout, expectedRelayTimeoutForQOS := rpccs.getProcessingTimeout(protocolMessage) deadline, ok := ctx.Deadline() if ok { // we have ctx deadline. we cant go past it. processingTimeout = time.Until(deadline) @@ -796,7 +794,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( } } // send relay - relayLatency, errResponse, backoff := rpccs.relayInner(goroutineCtx, singleConsumerSession, localRelayResult, processingTimeout, chainMessage, consumerToken, analytics) + relayLatency, errResponse, backoff := rpccs.relayInner(goroutineCtx, singleConsumerSession, localRelayResult, processingTimeout, protocolMessage, consumerToken, analytics) if errResponse != nil { failRelaySession := func(origErr error, backoff_ bool) { backOffDuration := 0 * time.Second @@ -840,10 +838,10 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( ) } - errResponse = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, latestBlock, chainlib.GetComputeUnits(chainMessage), relayLatency, singleConsumerSession.CalculateExpectedLatency(expectedRelayTimeoutForQOS), expectedBH, numOfProviders, pairingAddressesLen, chainMessage.GetApi().Category.HangingApi) // session done successfully + errResponse = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, latestBlock, chainlib.GetComputeUnits(protocolMessage), relayLatency, singleConsumerSession.CalculateExpectedLatency(expectedRelayTimeoutForQOS), expectedBH, numOfProviders, pairingAddressesLen, protocolMessage.GetApi().Category.HangingApi) // session done successfully if rpccs.cache.CacheActive() && rpcclient.ValidateStatusCodes(localRelayResult.StatusCode, true) == nil { - isNodeError, _ := chainMessage.CheckResponseError(localRelayResult.Reply.Data, localRelayResult.StatusCode) + isNodeError, _ := protocolMessage.CheckResponseError(localRelayResult.Reply.Data, localRelayResult.StatusCode) // in case the error is a node error we don't want to cache if !isNodeError { // copy reply data so if it changes it doesn't panic mid async send @@ -863,7 +861,7 @@ func (rpccs *RPCConsumerServer) 
sendRelayToProvider( ) return } - chainMessageRequestedBlock, _ := chainMessage.RequestedBlock() + chainMessageRequestedBlock, _ := protocolMessage.RequestedBlock() if chainMessageRequestedBlock == spectypes.NOT_APPLICABLE { return } @@ -1234,8 +1232,9 @@ func (rpccs *RPCConsumerServer) sendDataReliabilityRelayIfApplicable(ctx context relayResult := results[0] if len(results) < 2 { relayRequestData := lavaprotocol.NewRelayData(ctx, relayResult.Request.RelayData.ConnectionType, relayResult.Request.RelayData.ApiUrl, relayResult.Request.RelayData.Data, relayResult.Request.RelayData.SeenBlock, reqBlock, relayResult.Request.RelayData.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), relayResult.Request.RelayData.Addon, relayResult.Request.RelayData.Extensions) + protocolMessage := chainlib.NewProtocolMessage(chainMessage, nil, relayRequestData) relayProcessorDataReliability := NewRelayProcessor(ctx, relayProcessor.usedProviders, 1, chainMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) - err := rpccs.sendRelayToProvider(ctx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessorDataReliability, nil) + err := rpccs.sendRelayToProvider(ctx, protocolMessage, dappID, consumerIp, relayProcessorDataReliability, nil) if err != nil { return utils.LavaFormatWarning("failed data reliability relay to provider", err, utils.LogAttr("relayProcessorDataReliability", relayProcessorDataReliability)) } diff --git a/x/pairing/types/relay_mock.pb.go b/x/pairing/types/relay_mock.pb.go index e49f7212e9..ad76b049fa 100644 --- a/x/pairing/types/relay_mock.pb.go +++ b/x/pairing/types/relay_mock.pb.go @@ -1,10 +1,6 @@ // Code generated by MockGen. DO NOT EDIT. // Source: x/pairing/types/relay.pb.go -// -// Generated by this command: -// -// mockgen -source=x/pairing/types/relay.pb.go -destination x/pairing/types/relay_mock.pb.go -package types -// + // Package types is a generated GoMock package. package types @@ -12,7 +8,7 @@ import ( context "context" reflect "reflect" - gomock "go.uber.org/mock/gomock" + gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" metadata "google.golang.org/grpc/metadata" ) @@ -43,7 +39,7 @@ func (m *MockRelayerClient) EXPECT() *MockRelayerClientMockRecorder { // Probe mocks base method. func (m *MockRelayerClient) Probe(ctx context.Context, in *ProbeRequest, opts ...grpc.CallOption) (*ProbeReply, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -54,16 +50,16 @@ func (m *MockRelayerClient) Probe(ctx context.Context, in *ProbeRequest, opts .. } // Probe indicates an expected call of Probe. -func (mr *MockRelayerClientMockRecorder) Probe(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) Probe(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerClient)(nil).Probe), varargs...) } // Relay mocks base method. 
func (m *MockRelayerClient) Relay(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (*RelayReply, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -74,16 +70,16 @@ func (m *MockRelayerClient) Relay(ctx context.Context, in *RelayRequest, opts .. } // Relay indicates an expected call of Relay. -func (mr *MockRelayerClientMockRecorder) Relay(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) Relay(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerClient)(nil).Relay), varargs...) } // RelaySubscribe mocks base method. func (m *MockRelayerClient) RelaySubscribe(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (Relayer_RelaySubscribeClient, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -94,9 +90,9 @@ func (m *MockRelayerClient) RelaySubscribe(ctx context.Context, in *RelayRequest } // RelaySubscribe indicates an expected call of RelaySubscribe. -func (mr *MockRelayerClientMockRecorder) RelaySubscribe(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) RelaySubscribe(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerClient)(nil).RelaySubscribe), varargs...) } @@ -190,7 +186,7 @@ func (m_2 *MockRelayer_RelaySubscribeClient) RecvMsg(m any) error { } // RecvMsg indicates an expected call of RecvMsg. -func (mr *MockRelayer_RelaySubscribeClientMockRecorder) RecvMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) RecvMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).RecvMsg), m) } @@ -204,7 +200,7 @@ func (m_2 *MockRelayer_RelaySubscribeClient) SendMsg(m any) error { } // SendMsg indicates an expected call of SendMsg. -func (mr *MockRelayer_RelaySubscribeClientMockRecorder) SendMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) SendMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).SendMsg), m) } @@ -256,7 +252,7 @@ func (m *MockRelayerServer) Probe(arg0 context.Context, arg1 *ProbeRequest) (*Pr } // Probe indicates an expected call of Probe. -func (mr *MockRelayerServerMockRecorder) Probe(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) Probe(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerServer)(nil).Probe), arg0, arg1) } @@ -271,7 +267,7 @@ func (m *MockRelayerServer) Relay(arg0 context.Context, arg1 *RelayRequest) (*Re } // Relay indicates an expected call of Relay. 
-func (mr *MockRelayerServerMockRecorder) Relay(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) Relay(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerServer)(nil).Relay), arg0, arg1) } @@ -285,7 +281,7 @@ func (m *MockRelayerServer) RelaySubscribe(arg0 *RelayRequest, arg1 Relayer_Rela } // RelaySubscribe indicates an expected call of RelaySubscribe. -func (mr *MockRelayerServerMockRecorder) RelaySubscribe(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) RelaySubscribe(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerServer)(nil).RelaySubscribe), arg0, arg1) } @@ -336,7 +332,7 @@ func (m_2 *MockRelayer_RelaySubscribeServer) RecvMsg(m any) error { } // RecvMsg indicates an expected call of RecvMsg. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) RecvMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) RecvMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).RecvMsg), m) } @@ -350,7 +346,7 @@ func (m *MockRelayer_RelaySubscribeServer) Send(arg0 *RelayReply) error { } // Send indicates an expected call of Send. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Send(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Send(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).Send), arg0) } @@ -364,7 +360,7 @@ func (m *MockRelayer_RelaySubscribeServer) SendHeader(arg0 metadata.MD) error { } // SendHeader indicates an expected call of SendHeader. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendHeader(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendHeader(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendHeader), arg0) } @@ -378,7 +374,7 @@ func (m_2 *MockRelayer_RelaySubscribeServer) SendMsg(m any) error { } // SendMsg indicates an expected call of SendMsg. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendMsg), m) } @@ -392,7 +388,7 @@ func (m *MockRelayer_RelaySubscribeServer) SetHeader(arg0 metadata.MD) error { } // SetHeader indicates an expected call of SetHeader. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetHeader(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetHeader(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetHeader), arg0) } @@ -404,7 +400,7 @@ func (m *MockRelayer_RelaySubscribeServer) SetTrailer(arg0 metadata.MD) { } // SetTrailer indicates an expected call of SetTrailer. 
-func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetTrailer(arg0 any) *gomock.Call {
+func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetTrailer(arg0 interface{}) *gomock.Call {
 mr.mock.ctrl.T.Helper()
 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetTrailer), arg0)
 }

From ea235a1302712e58c7446b94a8feece30aef26df Mon Sep 17 00:00:00 2001
From: Yaroms <103432884+Yaroms@users.noreply.github.com>
Date: Sun, 25 Aug 2024 11:46:18 +0300
Subject: [PATCH 06/12] fix: CNS-fix-tracked-cu-query (#1639)

* fix and remove warning

* fix CLI

* support min-cost not mandatory

* lint

* fix cli name

* add title and desc flags

---------

Co-authored-by: Yarom Swisa
---
 x/epochstorage/keeper/stake_entries.go | 4 --
 .../client/cli/query_spec_tracked_info.go | 16 ++++----
 x/rewards/client/cli/tx.go | 41 +++++++++++++------
 ...ard.go => grpc_query_spec_tracked_info.go} | 0
 .../cli/query_estimated_validators_rewards.go | 2 +-
 .../keeper/grpc_query_tracked_usage.go | 2 +-
 6 files changed, 38 insertions(+), 27 deletions(-)
 rename x/rewards/keeper/{grpc_query_provider_reward.go => grpc_query_spec_tracked_info.go} (100%)

diff --git a/x/epochstorage/keeper/stake_entries.go b/x/epochstorage/keeper/stake_entries.go
index cc8d0684d8..0c4a49630f 100644
--- a/x/epochstorage/keeper/stake_entries.go
+++ b/x/epochstorage/keeper/stake_entries.go
@@ -224,10 +224,6 @@ func (k Keeper) GetAllStakeEntriesCurrentForChainId(ctx sdk.Context, chainID str
 func (k Keeper) GetStakeEntryCurrentForChainIdByVault(ctx sdk.Context, chainID string, vault string) (val types.StakeEntry, found bool) {
 pk, err := k.stakeEntriesCurrent.Indexes.Index.MatchExact(ctx, collections.Join(chainID, vault))
 if err != nil {
- utils.LavaFormatWarning("GetStakeEntryCurrentForChainIdByVault: MatchExact with primary key failed", err,
- utils.LogAttr("chain_id", chainID),
- utils.LogAttr("vault", vault),
- )
 return types.StakeEntry{}, false
 }

diff --git a/x/rewards/client/cli/query_spec_tracked_info.go b/x/rewards/client/cli/query_spec_tracked_info.go
index d65543a0d6..3045bf1f72 100644
--- a/x/rewards/client/cli/query_spec_tracked_info.go
+++ b/x/rewards/client/cli/query_spec_tracked_info.go
@@ -28,18 +28,18 @@ func CmdSpecTrackedInfo() *cobra.Command {
 Args: cobra.RangeArgs(1, 2),
 RunE: func(cmd *cobra.Command, args []string) (err error) {
- reqChainID := ""
- if len(args) == 2 {
- reqChainID = args[0]
- }
-
 clientCtx, err := client.GetClientQueryContext(cmd)
 if err != nil {
 return err
 }
- reqProvider, err := utils.ParseCLIAddress(clientCtx, args[1])
- if err != nil {
- return err
+
+ reqChainID := args[0]
+ reqProvider := ""
+ if len(args) == 2 {
+ reqProvider, err = utils.ParseCLIAddress(clientCtx, args[1])
+ if err != nil {
+ return err
+ }
 }

 queryClient := types.NewQueryClient(clientCtx)
diff --git a/x/rewards/client/cli/tx.go b/x/rewards/client/cli/tx.go
index 5d477b1e14..0849a7be87 100644
--- a/x/rewards/client/cli/tx.go
+++ b/x/rewards/client/cli/tx.go
@@ -27,6 +27,8 @@ const (
 listSeparator = ","
 expeditedFlagName = "expedited"
 minIprpcCostFlagName = "min-cost"
+ titleFlagName = "title"
+ descriptionFlagName = "description"
 addIprpcSubscriptionsFlagName = "add-subscriptions"
 removeIprpcSubscriptionsFlagName = "remove-subscriptions"
 )
@@ -80,16 +82,6 @@ $ %s tx gov submit-legacy-proposal set-iprpc-data --min-cost 0ulava --add-subscr
 return err
 }

- // get min cost
- costStr, err := cmd.Flags().GetString(minIprpcCostFlagName)
- if err != nil {
- return err
- }
- cost,
err := sdk.ParseCoinNormalized(costStr) - if err != nil { - return err - } - // get current iprpc subscriptions q := types.NewQueryClient(clientCtx) res, err := q.ShowIprpcData(context.Background(), &types.QueryShowIprpcDataRequest{}) @@ -98,6 +90,19 @@ $ %s tx gov submit-legacy-proposal set-iprpc-data --min-cost 0ulava --add-subscr } subs := res.IprpcSubscriptions + // get min cost + costStr, err := cmd.Flags().GetString(minIprpcCostFlagName) + if err != nil { + return err + } + cost := res.MinCost + if costStr != "" { + cost, err = sdk.ParseCoinNormalized(costStr) + if err != nil { + return err + } + } + // add from msg subsToAdd, err := cmd.Flags().GetStringSlice(addIprpcSubscriptionsFlagName) if err != nil { @@ -135,7 +140,16 @@ $ %s tx gov submit-legacy-proposal set-iprpc-data --min-cost 0ulava --add-subscr MinIprpcCost: cost, } - submitPropMsg, err := govv1.NewMsgSubmitProposal([]sdk.Msg{&msg}, deposit, from.String(), "", "Set IPRPC data", "Set IPRPC data", isExpedited) + title, err := cmd.Flags().GetString(titleFlagName) + if err != nil { + return err + } + + description, err := cmd.Flags().GetString(descriptionFlagName) + if err != nil { + return err + } + submitPropMsg, err := govv1.NewMsgSubmitProposal([]sdk.Msg{&msg}, deposit, from.String(), "", title, description, isExpedited) if err != nil { return err } @@ -143,10 +157,11 @@ $ %s tx gov submit-legacy-proposal set-iprpc-data --min-cost 0ulava --add-subscr return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), submitPropMsg) }, } - cmd.Flags().String(minIprpcCostFlagName, "0ulava", "set minimum iprpc cost") + cmd.Flags().String(minIprpcCostFlagName, "", "set minimum iprpc cost") + cmd.Flags().String(titleFlagName, "Set IPRPC data", "proposal title") + cmd.Flags().String(descriptionFlagName, "Set IPRPC data", "proposal description") cmd.Flags().StringSlice(addIprpcSubscriptionsFlagName, []string{}, "add iprpc eligible subscriptions") cmd.Flags().StringSlice(removeIprpcSubscriptionsFlagName, []string{}, "remove iprpc eligible subscriptions") cmd.Flags().Bool(expeditedFlagName, false, "set to true to make the spec proposal expedited") - cmd.MarkFlagRequired(minIprpcCostFlagName) return cmd } diff --git a/x/rewards/keeper/grpc_query_provider_reward.go b/x/rewards/keeper/grpc_query_spec_tracked_info.go similarity index 100% rename from x/rewards/keeper/grpc_query_provider_reward.go rename to x/rewards/keeper/grpc_query_spec_tracked_info.go diff --git a/x/subscription/client/cli/query_estimated_validators_rewards.go b/x/subscription/client/cli/query_estimated_validators_rewards.go index da08bf7803..1b805f8742 100644 --- a/x/subscription/client/cli/query_estimated_validators_rewards.go +++ b/x/subscription/client/cli/query_estimated_validators_rewards.go @@ -10,7 +10,7 @@ import ( func CmdEstimatedValidatorsRewards() *cobra.Command { cmd := &cobra.Command{ - Use: "estimated--validator-rewards [validator] {optional: amount/delegator}", + Use: "estimated-validator-rewards [validator] {optional: amount/delegator}", Short: "calculates the rewards estimation for a validator delegation", Long: `Query to estimate the rewards a delegator will get for 1 month from the validator, if used without optional args the calculations will be for the validator itself. optional args can be amount for new delegation or address for an existing one. 
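For illustration, a usage sketch of the corrected query command above. This is hedged: the `lavad` binary name and the `subscription` query-module routing are inferred from the file path (x/subscription/client/cli), and the validator/delegator addresses and amount are placeholders, none of which are confirmed by the patch itself:

    # rewards estimation for the validator itself (no optional args)
    lavad q subscription estimated-validator-rewards <validator-valoper-address>

    # rewards estimation for a prospective new delegation amount
    lavad q subscription estimated-validator-rewards <validator-valoper-address> 500000000ulava

    # rewards estimation for an existing delegator address
    lavad q subscription estimated-validator-rewards <validator-valoper-address> <delegator-address>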
diff --git a/x/subscription/keeper/grpc_query_tracked_usage.go b/x/subscription/keeper/grpc_query_tracked_usage.go index 255ba69de5..5312864f53 100644 --- a/x/subscription/keeper/grpc_query_tracked_usage.go +++ b/x/subscription/keeper/grpc_query_tracked_usage.go @@ -20,7 +20,7 @@ func (k Keeper) TrackedUsage(goCtx context.Context, req *types.QuerySubscription sub, _ := k.GetSubscription(ctx, req.Subscription) res.Subscription = &sub - res.Usage, res.TotalUsage = k.GetSubTrackedCuInfo(ctx, req.Subscription, uint64(ctx.BlockHeader().Height)) + res.Usage, res.TotalUsage = k.GetSubTrackedCuInfo(ctx, req.Subscription, sub.Block) return &res, nil } From 4dd69db27091b04c78b5a04ec1157b5e15e2a766 Mon Sep 17 00:00:00 2001 From: oren-lava <111131399+oren-lava@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:38:14 +0300 Subject: [PATCH 07/12] fix: Unstake CLI fee grant fix (#1648) * unstake cli fee grant fix * fix lint --- x/pairing/client/cli/tx_unstake_provider.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/x/pairing/client/cli/tx_unstake_provider.go b/x/pairing/client/cli/tx_unstake_provider.go index b00ed981d8..9e21d04c53 100644 --- a/x/pairing/client/cli/tx_unstake_provider.go +++ b/x/pairing/client/cli/tx_unstake_provider.go @@ -134,6 +134,10 @@ func CreateRevokeFeeGrantMsg(clientCtx client.Context, chainID string) (*feegran feegrantQuerier := feegrant.NewQueryClient(clientCtx) res, err := feegrantQuerier.Allowance(ctx, &feegrant.QueryAllowanceRequest{Granter: vault, Grantee: providerEntry.Address}) if err != nil { + if strings.Contains(err.Error(), "fee-grant not found") { + // fee grant not found, do nothing + return nil, nil //nolint + } return nil, utils.LavaFormatError("failed querying feegrant for gas fees for granter", err, utils.LogAttr("granter", vault), ) From ca1828c79b2bcbfe405403e992b57d6790c37774 Mon Sep 17 00:00:00 2001 From: Omer <100387053+omerlavanet@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:38:29 +0300 Subject: [PATCH 08/12] feat: support provider method routes (#1642) * finished, wip on tests * fixed bugs, added tests and a config * lint --- .../lava_example_archive_methodroute.yml | 35 ++ protocol/chainlib/chain_router.go | 140 +++++-- protocol/chainlib/chain_router_test.go | 371 ++++++++++++++++++ .../extensionslib/extension_parser.go | 2 +- protocol/common/endpoints.go | 1 + protocol/lavasession/router_key.go | 9 +- 6 files changed, 520 insertions(+), 38 deletions(-) create mode 100644 config/provider_examples/lava_example_archive_methodroute.yml diff --git a/config/provider_examples/lava_example_archive_methodroute.yml b/config/provider_examples/lava_example_archive_methodroute.yml new file mode 100644 index 0000000000..e8cbf3bad9 --- /dev/null +++ b/config/provider_examples/lava_example_archive_methodroute.yml @@ -0,0 +1,35 @@ +endpoints: + - api-interface: tendermintrpc + chain-id: LAV1 + network-address: + address: "127.0.0.1:2220" + node-urls: + - url: ws://127.0.0.1:26657/websocket + - url: http://127.0.0.1:26657 + - url: http://127.0.0.1:26657 + addons: + - archive + - url: https://trustless-api.com + methods: + - block + - block_by_hash + addons: + - archive + - api-interface: grpc + chain-id: LAV1 + network-address: + address: "127.0.0.1:2220" + node-urls: + - url: 127.0.0.1:9090 + - url: 127.0.0.1:9090 + addons: + - archive + - api-interface: rest + chain-id: LAV1 + network-address: + address: "127.0.0.1:2220" + node-urls: + - url: http://127.0.0.1:1317 + - url: http://127.0.0.1:1317 + addons: + - archive diff --git 
a/protocol/chainlib/chain_router.go b/protocol/chainlib/chain_router.go index f4be579440..47a56f1032 100644 --- a/protocol/chainlib/chain_router.go +++ b/protocol/chainlib/chain_router.go @@ -16,9 +16,15 @@ import ( "google.golang.org/grpc/metadata" ) +type MethodRoute struct { + lavasession.RouterKey + method string +} + type chainRouterEntry struct { ChainProxy addonsSupported map[string]struct{} + methodsRouted map[string]struct{} } func (cre *chainRouterEntry) isSupporting(addon string) bool { @@ -36,13 +42,26 @@ type chainRouterImpl struct { chainProxyRouter map[lavasession.RouterKey][]chainRouterEntry } -func (cri *chainRouterImpl) getChainProxySupporting(ctx context.Context, addon string, extensions []string) (ChainProxy, error) { +func (cri *chainRouterImpl) GetChainProxySupporting(ctx context.Context, addon string, extensions []string, method string) (ChainProxy, error) { cri.lock.RLock() defer cri.lock.RUnlock() + + // check if that specific method has a special route, if it does apply it to the router key wantedRouterKey := lavasession.NewRouterKey(extensions) if chainProxyEntries, ok := cri.chainProxyRouter[wantedRouterKey]; ok { for _, chainRouterEntry := range chainProxyEntries { if chainRouterEntry.isSupporting(addon) { + // check if the method is supported + if len(chainRouterEntry.methodsRouted) > 0 { + if _, ok := chainRouterEntry.methodsRouted[method]; !ok { + continue + } + utils.LavaFormatTrace("chainProxy supporting method routing selected", + utils.LogAttr("addon", addon), + utils.LogAttr("wantedRouterKey", wantedRouterKey), + utils.LogAttr("method", method), + ) + } if wantedRouterKey != lavasession.GetEmptyRouterKey() { // add trailer only when router key is not default (||) grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeExtension, string(wantedRouterKey))) } @@ -70,7 +89,7 @@ func (cri chainRouterImpl) ExtensionsSupported(extensions []string) bool { func (cri chainRouterImpl) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend, extensions []string) (relayReply *RelayReplyWrapper, subscriptionID string, relayReplyServer *rpcclient.ClientSubscription, proxyUrl common.NodeUrl, chainId string, err error) { // add the parsed addon from the apiCollection addon := chainMessage.GetApiCollection().CollectionData.AddOn - selectedChainProxy, err := cri.getChainProxySupporting(ctx, addon, extensions) + selectedChainProxy, err := cri.GetChainProxySupporting(ctx, addon, extensions, chainMessage.GetApi().Name) if err != nil { return nil, "", nil, common.NodeUrl{}, "", err } @@ -80,55 +99,83 @@ func (cri chainRouterImpl) SendNodeMsg(ctx context.Context, ch chan interface{}, } // batch nodeUrls with the same addons together in a copy -func batchNodeUrlsByServices(rpcProviderEndpoint lavasession.RPCProviderEndpoint) map[lavasession.RouterKey]lavasession.RPCProviderEndpoint { +func (cri *chainRouterImpl) BatchNodeUrlsByServices(rpcProviderEndpoint lavasession.RPCProviderEndpoint) (map[lavasession.RouterKey]lavasession.RPCProviderEndpoint, error) { returnedBatch := map[lavasession.RouterKey]lavasession.RPCProviderEndpoint{} + routesToCheck := map[lavasession.RouterKey]bool{} + methodRoutes := map[string]int{} for _, nodeUrl := range rpcProviderEndpoint.NodeUrls { routerKey := lavasession.NewRouterKey(nodeUrl.Addons) - - u, err := url.Parse(nodeUrl.Url) - // Some parsing may fail because of gRPC - if err == nil && (u.Scheme == "ws" || u.Scheme == "wss") { - // if websocket, check if we have a router key for http already. 
if not add a websocket router key - // so in case we didn't get an http endpoint, we can use the ws one. - if _, ok := returnedBatch[routerKey]; !ok { - returnedBatch[routerKey] = lavasession.RPCProviderEndpoint{ - NetworkAddress: rpcProviderEndpoint.NetworkAddress, - ChainID: rpcProviderEndpoint.ChainID, - ApiInterface: rpcProviderEndpoint.ApiInterface, - Geolocation: rpcProviderEndpoint.Geolocation, - NodeUrls: []common.NodeUrl{nodeUrl}, // add existing nodeUrl to the batch - } + if len(nodeUrl.Methods) > 0 { + // all methods defined here will go to the same batch + methodRoutesUnique := strings.Join(nodeUrl.Methods, ",") + var existing int + var ok bool + if existing, ok = methodRoutes[methodRoutesUnique]; !ok { + methodRoutes[methodRoutesUnique] = len(methodRoutes) + existing = len(methodRoutes) } - - // now change the router key to fit the websocket extension key. - nodeUrl.Addons = append(nodeUrl.Addons, WebSocketExtension) - routerKey = lavasession.NewRouterKey(nodeUrl.Addons) + routerKey = routerKey.ApplyMethodsRoute(existing) } + cri.parseNodeUrl(nodeUrl, returnedBatch, routerKey, rpcProviderEndpoint) + } + if len(returnedBatch) == 0 { + return nil, utils.LavaFormatError("invalid batch, routes are empty", nil, utils.LogAttr("endpoint", rpcProviderEndpoint)) + } + // validate all defined method routes have a regular route + for routerKey, valid := range routesToCheck { + if !valid { + return nil, utils.LavaFormatError("invalid batch, missing regular route for method route", nil, utils.LogAttr("routerKey", routerKey)) + } + } + return returnedBatch, nil +} - if existingEndpoint, ok := returnedBatch[routerKey]; !ok { +func (*chainRouterImpl) parseNodeUrl(nodeUrl common.NodeUrl, returnedBatch map[lavasession.RouterKey]lavasession.RPCProviderEndpoint, routerKey lavasession.RouterKey, rpcProviderEndpoint lavasession.RPCProviderEndpoint) { + u, err := url.Parse(nodeUrl.Url) + // Some parsing may fail because of gRPC + if err == nil && (u.Scheme == "ws" || u.Scheme == "wss") { + // if websocket, check if we have a router key for http already. if not add a websocket router key + // so in case we didn't get an http endpoint, we can use the ws one. + if _, ok := returnedBatch[routerKey]; !ok { returnedBatch[routerKey] = lavasession.RPCProviderEndpoint{ NetworkAddress: rpcProviderEndpoint.NetworkAddress, ChainID: rpcProviderEndpoint.ChainID, ApiInterface: rpcProviderEndpoint.ApiInterface, Geolocation: rpcProviderEndpoint.Geolocation, - NodeUrls: []common.NodeUrl{nodeUrl}, // add existing nodeUrl to the batch + NodeUrls: []common.NodeUrl{nodeUrl}, } - } else { - // setting the incoming url first as it might be http while existing is websocket. (we prioritize http over ws when possible) - existingEndpoint.NodeUrls = append([]common.NodeUrl{nodeUrl}, existingEndpoint.NodeUrls...) - returnedBatch[routerKey] = existingEndpoint } + // now change the router key to fit the websocket extension key. + nodeUrl.Addons = append(nodeUrl.Addons, WebSocketExtension) + routerKey = lavasession.NewRouterKey(nodeUrl.Addons) } - return returnedBatch + if existingEndpoint, ok := returnedBatch[routerKey]; !ok { + returnedBatch[routerKey] = lavasession.RPCProviderEndpoint{ + NetworkAddress: rpcProviderEndpoint.NetworkAddress, + ChainID: rpcProviderEndpoint.ChainID, + ApiInterface: rpcProviderEndpoint.ApiInterface, + Geolocation: rpcProviderEndpoint.Geolocation, + NodeUrls: []common.NodeUrl{nodeUrl}, + } + } else { + // setting the incoming url first as it might be http while existing is websocket. 
(we prioritize http over ws when possible) + existingEndpoint.NodeUrls = append([]common.NodeUrl{nodeUrl}, existingEndpoint.NodeUrls...) + returnedBatch[routerKey] = existingEndpoint + } } -func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavasession.RPCProviderEndpoint, chainParser ChainParser, proxyConstructor func(context.Context, uint, lavasession.RPCProviderEndpoint, ChainParser) (ChainProxy, error)) (ChainRouter, error) { +func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavasession.RPCProviderEndpoint, chainParser ChainParser, proxyConstructor func(context.Context, uint, lavasession.RPCProviderEndpoint, ChainParser) (ChainProxy, error)) (*chainRouterImpl, error) { chainProxyRouter := map[lavasession.RouterKey][]chainRouterEntry{} - + cri := chainRouterImpl{ + lock: &sync.RWMutex{}, + } requiredMap := map[requirementSt]struct{}{} supportedMap := map[requirementSt]struct{}{} - rpcProviderEndpointBatch := batchNodeUrlsByServices(rpcProviderEndpoint) + rpcProviderEndpointBatch, err := cri.BatchNodeUrlsByServices(rpcProviderEndpoint) + if err != nil { + return nil, err + } for _, rpcProviderEndpointEntry := range rpcProviderEndpointBatch { addons, extensions, err := chainParser.SeparateAddonsExtensions(append(rpcProviderEndpointEntry.NodeUrls[0].Addons, "")) if err != nil { @@ -151,6 +198,14 @@ func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavase return allExtensionsRouterKey } routerKey := updateRouteCombinations(extensions, addons) + methodsRouted := map[string]struct{}{} + methods := rpcProviderEndpointEntry.NodeUrls[0].Methods + if len(methods) > 0 { + for _, method := range methods { + methodsRouted[method] = struct{}{} + } + } + chainProxy, err := proxyConstructor(ctx, nConns, rpcProviderEndpointEntry, chainParser) if err != nil { // TODO: allow some urls to be down @@ -159,11 +214,17 @@ func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavase chainRouterEntryInst := chainRouterEntry{ ChainProxy: chainProxy, addonsSupported: addonsSupportedMap, + methodsRouted: methodsRouted, } if chainRouterEntries, ok := chainProxyRouter[routerKey]; !ok { chainProxyRouter[routerKey] = []chainRouterEntry{chainRouterEntryInst} } else { - chainProxyRouter[routerKey] = append(chainRouterEntries, chainRouterEntryInst) + if len(methodsRouted) > 0 { + // if there are routed methods we want this in the beginning to intercept them + chainProxyRouter[routerKey] = append([]chainRouterEntry{chainRouterEntryInst}, chainRouterEntries...) 
+ } else { + chainProxyRouter[routerKey] = append(chainRouterEntries, chainRouterEntryInst) + } } } if len(requiredMap) > len(supportedMap) { @@ -189,11 +250,18 @@ func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavase } } - cri := chainRouterImpl{ - lock: &sync.RWMutex{}, - chainProxyRouter: chainProxyRouter, + // make sure all chainProxyRouter entries have one without a method routing + for routerKey, chainRouterEntries := range chainProxyRouter { + // get the last entry, if it has methods routed, we need to error out + lastEntry := chainRouterEntries[len(chainRouterEntries)-1] + if len(lastEntry.methodsRouted) > 0 { + return nil, utils.LavaFormatError("last entry in chainProxyRouter has methods routed, this means no chainProxy supports all methods", nil, utils.LogAttr("routerKey", routerKey)) + } } - return cri, nil + + cri.chainProxyRouter = chainProxyRouter + + return &cri, nil } type requirementSt struct { diff --git a/protocol/chainlib/chain_router_test.go b/protocol/chainlib/chain_router_test.go index 56e650380e..c16c7e6c25 100644 --- a/protocol/chainlib/chain_router_test.go +++ b/protocol/chainlib/chain_router_test.go @@ -5,6 +5,7 @@ import ( "log" "net" "os" + "strings" "testing" "time" @@ -14,10 +15,12 @@ import ( "github.com/gofiber/fiber/v2/middleware/favicon" "github.com/gofiber/websocket/v2" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" + "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavasession" testcommon "github.com/lavanet/lava/v2/testutil/common" "github.com/lavanet/lava/v2/utils" + epochstoragetypes "github.com/lavanet/lava/v2/x/epochstorage/types" spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" ) @@ -751,6 +754,374 @@ func TestChainRouterWithEnabledWebSocketInSpec(t *testing.T) { } } +type chainProxyMock struct { + endpoint lavasession.RPCProviderEndpoint +} + +func (m *chainProxyMock) GetChainProxyInformation() (common.NodeUrl, string) { + urlStr := "" + if len(m.endpoint.NodeUrls) > 0 { + urlStr = m.endpoint.NodeUrls[0].UrlStr() + } + return common.NodeUrl{}, urlStr +} + +func (m *chainProxyMock) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend) (relayReply *RelayReplyWrapper, subscriptionID string, relayReplyServer *rpcclient.ClientSubscription, err error) { + return nil, "", nil, nil +} + +type PolicySt struct { + addons []string + extensions []string + apiInterface string +} + +func (a PolicySt) GetSupportedAddons(string) ([]string, error) { + return a.addons, nil +} + +func (a PolicySt) GetSupportedExtensions(string) ([]epochstoragetypes.EndpointService, error) { + ret := []epochstoragetypes.EndpointService{} + for _, ext := range a.extensions { + ret = append(ret, epochstoragetypes.EndpointService{Extension: ext, ApiInterface: a.apiInterface}) + } + return ret, nil +} + +func TestChainRouterWithMethodRoutes(t *testing.T) { + ctx := context.Background() + apiInterface := spectypes.APIInterfaceRest + chainParser, err := NewChainParser(apiInterface) + require.NoError(t, err) + + IgnoreSubscriptionNotConfiguredError = false + + addonsOptions := []string{"-addon-", "-addon2-"} + extensionsOptions := []string{"-test-", "-test2-", "-test3-"} + + spec := testcommon.CreateMockSpec() + spec.ApiCollections = []*spectypes.ApiCollection{ + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + 
Type: "", + AddOn: "", + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + ParseDirectives: []*spectypes.ParseDirective{{ + FunctionTag: spectypes.FUNCTION_TAG_SUBSCRIBE, + }}, + Apis: []*spectypes.Api{ + { + Enabled: true, + Name: "api-1", + }, + { + Enabled: true, + Name: "api-2", + }, + { + Enabled: true, + Name: "api-8", + }, + }, + }, + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: addonsOptions[0], + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + ParseDirectives: []*spectypes.ParseDirective{{ + FunctionTag: spectypes.FUNCTION_TAG_SUBSCRIBE, + }}, + Apis: []*spectypes.Api{ + { + Enabled: true, + Name: "api-3", + }, + { + Enabled: true, + Name: "api-4", + }, + }, + }, + } + chainParser.SetSpec(spec) + endpoint := &lavasession.RPCProviderEndpoint{ + NetworkAddress: lavasession.NetworkAddressData{}, + ChainID: spec.Index, + ApiInterface: apiInterface, + Geolocation: 1, + NodeUrls: []common.NodeUrl{}, + } + const extMarker = "::ext::" + playBook := []struct { + name string + nodeUrls []common.NodeUrl + success bool + apiToUrlMapping map[string]string + }{ + { + name: "addon routing", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + Addons: []string{addonsOptions[0]}, + }, + { + Url: "ws:-0-", + Addons: []string{addonsOptions[0]}, + }, + { + Url: "-1-", + Methods: []string{"api-2"}, + }, + }, + success: true, + apiToUrlMapping: map[string]string{ + "api-1": "-0-", + "api-2": "-1-", + "api-3": "-0-", + }, + }, + { + name: "basic method routing", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + }, + { + Url: "ws:-0-", + Methods: []string{}, + }, + { + Url: "-1-", + Methods: []string{"api-2"}, + }, + { + Url: "ws:-1-", + Methods: []string{}, + }, + }, + success: true, + apiToUrlMapping: map[string]string{ + "api-1": "-0-", + "api-2": "-1-", + }, + }, + { + name: "method routing with extension", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + }, + { + Url: "ws:-0-", + Methods: []string{}, + }, + { + Url: "-1-", + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "-2-", + Methods: []string{"api-2"}, + Addons: []string{extensionsOptions[0]}, + }, + }, + success: true, + apiToUrlMapping: map[string]string{ + "api-1": "-0-", + "api-2": "-0-", + "api-1" + extMarker + extensionsOptions[0]: "-1-", + "api-2" + extMarker + extensionsOptions[0]: "-2-", + }, + }, + { + name: "method routing with two extensions", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + }, + { + Url: "ws:-0-", + Methods: []string{}, + }, + { + Url: "-1-", + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "-2-", + Methods: []string{"api-2"}, + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "-3-", + Addons: []string{extensionsOptions[1]}, + }, + { + Url: "-4-", + Methods: []string{"api-8"}, + Addons: []string{extensionsOptions[1]}, + }, + }, + success: true, + apiToUrlMapping: map[string]string{ + "api-1": "-0-", + "api-2": "-0-", + "api-1" + extMarker + extensionsOptions[0]: "-1-", + "api-2" + extMarker + extensionsOptions[0]: "-2-", + "api-1" + extMarker + extensionsOptions[1]: "-3-", + 
"api-8" + extMarker + extensionsOptions[1]: "-4-", + }, + }, + { + name: "two method routings with extension", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + }, + { + Url: "ws:-0-", + Methods: []string{}, + }, + { + Url: "-1-", + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "-2-", + Methods: []string{"api-2"}, + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "-3-", + Methods: []string{"api-8"}, + Addons: []string{extensionsOptions[0]}, + }, + { + Url: "ws:-1-", + Methods: []string{}, + }, + }, + success: true, + apiToUrlMapping: map[string]string{ + "api-1": "-0-", + "api-2": "-0-", + "api-1" + extMarker + extensionsOptions[0]: "-1-", + "api-2" + extMarker + extensionsOptions[0]: "-2-", + "api-8" + extMarker + extensionsOptions[0]: "-3-", + }, + }, + { + name: "method routing without base", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{"api-1"}, + }, + { + Url: "ws:-0-", + Methods: []string{"api-1"}, + }, + }, + success: false, + }, + { + name: "method routing without base with extension", + nodeUrls: []common.NodeUrl{ + { + Url: "-0-", + Methods: []string{}, + }, + { + Url: "ws:-0-", + Methods: []string{}, + }, + { + Url: "-1-", + Addons: []string{extensionsOptions[0]}, + Methods: []string{"api-1"}, + }, + }, + success: false, + }, + } + mockProxyConstructor := func(_ context.Context, _ uint, endp lavasession.RPCProviderEndpoint, _ ChainParser) (ChainProxy, error) { + mockChainProxy := &chainProxyMock{endpoint: endp} + return mockChainProxy, nil + } + for _, play := range playBook { + t.Run(play.name, func(t *testing.T) { + endpoint.NodeUrls = play.nodeUrls + policy := PolicySt{ + addons: addonsOptions, + extensions: extensionsOptions, + apiInterface: apiInterface, + } + chainParser.SetPolicy(policy, spec.Index, apiInterface) + chainRouter, err := newChainRouter(ctx, 1, *endpoint, chainParser, mockProxyConstructor) + if play.success { + require.NoError(t, err) + for api, url := range play.apiToUrlMapping { + extension := extensionslib.ExtensionInfo{} + if strings.Contains(api, extMarker) { + splitted := strings.Split(api, extMarker) + api = splitted[0] + extension.ExtensionOverride = []string{splitted[1]} + } + chainMsg, err := chainParser.ParseMsg(api, nil, "", nil, extension) + require.NoError(t, err) + chainProxy, err := chainRouter.GetChainProxySupporting(ctx, chainMsg.GetApiCollection().CollectionData.AddOn, common.GetExtensionNames(chainMsg.GetExtensions()), api) + require.NoError(t, err) + _, urlFromProxy := chainProxy.GetChainProxyInformation() + require.Equal(t, url, urlFromProxy, "chainMsg: %+v, ---chainRouter: %+v", chainMsg, chainRouter) + } + } else { + require.Error(t, err) + } + }) + } +} + func createRPCServer() net.Listener { listener, err := net.Listen("tcp", listenerAddressTcp) if err != nil { diff --git a/protocol/chainlib/extensionslib/extension_parser.go b/protocol/chainlib/extensionslib/extension_parser.go index f9ddcbdea0..91af83f95d 100644 --- a/protocol/chainlib/extensionslib/extension_parser.go +++ b/protocol/chainlib/extensionslib/extension_parser.go @@ -69,7 +69,7 @@ func (ep *ExtensionParser) ExtensionParsing(addon string, extensionsChainMessage continue } extensionParserRule := NewExtensionParserRule(extension) - if extensionParserRule.isPassingRule(extensionsChainMessage, latestBlock) { + if extensionParserRule != nil && extensionParserRule.isPassingRule(extensionsChainMessage, latestBlock) { extensionsChainMessage.SetExtension(extension) } } diff --git a/protocol/common/endpoints.go 
b/protocol/common/endpoints.go index 6025f6bcb0..b91de7be8c 100644 --- a/protocol/common/endpoints.go +++ b/protocol/common/endpoints.go @@ -55,6 +55,7 @@ type NodeUrl struct { Timeout time.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty" mapstructure:"timeout"` Addons []string `yaml:"addons,omitempty" json:"addons,omitempty" mapstructure:"addons"` SkipVerifications []string `yaml:"skip-verifications,omitempty" json:"skip-verifications,omitempty" mapstructure:"skip-verifications"` + Methods []string `yaml:"methods,omitempty" json:"methods,omitempty" mapstructure:"methods"` } type ChainMessageGetApiInterface interface { diff --git a/protocol/lavasession/router_key.go b/protocol/lavasession/router_key.go index 441bdc6660..291e543235 100644 --- a/protocol/lavasession/router_key.go +++ b/protocol/lavasession/router_key.go @@ -2,15 +2,22 @@ package lavasession import ( "sort" + "strconv" "strings" ) const ( - sep = "|" + sep = "|" + methodRouteSep = "method-route:" ) type RouterKey string +func (rk *RouterKey) ApplyMethodsRoute(routeNum int) RouterKey { + additionalPath := strconv.FormatInt(int64(routeNum), 10) + return RouterKey(string(*rk) + methodRouteSep + additionalPath) +} + func NewRouterKey(extensions []string) RouterKey { // make sure addons have no repetitions uniqueExtensions := map[string]struct{}{} From 22bdee074b54ce7d596956cad22325ef5dcf0be5 Mon Sep 17 00:00:00 2001 From: Yaroms <103432884+Yaroms@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:58:35 +0300 Subject: [PATCH 09/12] fix (#1647) Co-authored-by: Yaroms Co-authored-by: Omer <100387053+omerlavanet@users.noreply.github.com> --- x/pairing/client/cli/tx_modify_provider.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x/pairing/client/cli/tx_modify_provider.go b/x/pairing/client/cli/tx_modify_provider.go index 8d8d03fb6d..393ee75723 100644 --- a/x/pairing/client/cli/tx_modify_provider.go +++ b/x/pairing/client/cli/tx_modify_provider.go @@ -111,7 +111,7 @@ func CmdModifyProvider() *cobra.Command { return utils.LavaFormatError("provider isn't staked on chainID, no address match", nil) } - var validator string + validator := getValidator(clientCtx, clientCtx.GetFromAddress().String()) newAmount, err := cmd.Flags().GetString(AmountFlagName) if err != nil { return err @@ -131,8 +131,6 @@ func CmdModifyProvider() *cobra.Command { } else { return fmt.Errorf("increasing or decreasing stake must be accompanied with validator flag") } - } else { - validator = getValidator(clientCtx, clientCtx.GetFromAddress().String()) } providerEntry.Stake = newStake } From 649e857834e8fb716479cf03d6efd20aec2c2b2b Mon Sep 17 00:00:00 2001 From: Valters Jansons Date: Sun, 25 Aug 2024 17:04:55 +0300 Subject: [PATCH 10/12] fix: Use correct Github repository archive link (#1645) There was a mass-replacement for Lava v2 in commit 5864b7a. This resulted in the auto-download link being broken for Lavavisor. With this commit, the archive downloads can be successful again. 
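Background sketch (not part of the change itself): Go module import paths carry the /v2 major-version suffix, while GitHub archive links use the bare repository path; mixing the two is what broke the download. The hypothetical archiveURL helper below illustrates the corrected form, assuming the usual v<version> tag naming:

package main

import "fmt"

// archiveURL builds a GitHub release-archive link from a tag version.
// Note the bare lavanet/lava repository path, with no /v2 module suffix.
func archiveURL(version string) string {
	return fmt.Sprintf("https://github.com/lavanet/lava/archive/refs/tags/v%s.zip", version)
}

func main() {
	fmt.Println(archiveURL("2.3.0")) // https://github.com/lavanet/lava/archive/refs/tags/v2.3.0.zip
}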
--- ecosystem/lavavisor/pkg/process/fetcher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ecosystem/lavavisor/pkg/process/fetcher.go b/ecosystem/lavavisor/pkg/process/fetcher.go index 1fb067777e..5fa50cd994 100644 --- a/ecosystem/lavavisor/pkg/process/fetcher.go +++ b/ecosystem/lavavisor/pkg/process/fetcher.go @@ -196,7 +196,7 @@ func (pbf *ProtocolBinaryFetcher) downloadAndBuildFromGithub(version, versionDir return utils.LavaFormatError("[Lavavisor] failed to clean up binary directory", err) } // URL might need to be updated based on the actual GitHub repository - url := fmt.Sprintf("https://github.com/lavanet/lava/v2/archive/refs/tags/v%s.zip", version) + url := fmt.Sprintf("https://github.com/lavanet/lava/archive/refs/tags/v%s.zip", version) utils.LavaFormatInfo("[Lavavisor] Fetching the source from: ", utils.Attribute{Key: "URL", Value: url}) // Send the request From 395a05414c490c1d1a8cc34f4349229c0c9b0f00 Mon Sep 17 00:00:00 2001 From: Omer <100387053+omerlavanet@users.noreply.github.com> Date: Sun, 25 Aug 2024 20:19:18 +0300 Subject: [PATCH 11/12] chore: added init so it supports contributor (#1649) * added init so it supports contributor * add sdk init on testMain * lint --- x/spec/keeper/spec_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/x/spec/keeper/spec_test.go b/x/spec/keeper/spec_test.go index 21d57257db..f8fdd05a33 100644 --- a/x/spec/keeper/spec_test.go +++ b/x/spec/keeper/spec_test.go @@ -9,6 +9,7 @@ import ( "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/lavanet/lava/v2/cmd/lavad/cmd" "github.com/lavanet/lava/v2/testutil/common" keepertest "github.com/lavanet/lava/v2/testutil/keeper" "github.com/lavanet/lava/v2/testutil/nullify" @@ -225,6 +226,14 @@ func TestSpecRemove(t *testing.T) { } } +func TestMain(m *testing.M) { + // This code will run once before any test cases are executed. 
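+	// cmd.InitSDKConfig is assumed (from its use in cmd/lavad) to register the
+	// chain's bech32 address prefixes on the global cosmos-sdk config, which
+	// the address handling in these keeper tests relies on when the package
+	// runs in isolation.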
+	cmd.InitSDKConfig() +	// Run the actual tests +	exitCode := m.Run() +	os.Exit(exitCode) +} + func TestSpecGetAll(t *testing.T) { ts := newTester(t) items := ts.createNSpec(10) From 3032b2b1e9a6f19728c06dd4c7988d253bb1863d Mon Sep 17 00:00:00 2001 From: Ran Mishael Date: Sun, 25 Aug 2024 20:16:39 +0200 Subject: [PATCH 12/12] feat: PRT - adding debug headers and fixing nil deref on response nil --- protocol/chainlib/protocol_message.go | 5 +- protocol/common/endpoints.go | 2 + .../lavasession/consumer_session_manager.go | 6 +++ protocol/rpcconsumer/rpcconsumer_server.go | 50 +++++++++++++++---- 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/protocol/chainlib/protocol_message.go b/protocol/chainlib/protocol_message.go index c9ed2ea01d..9a3313e07e 100644 --- a/protocol/chainlib/protocol_message.go +++ b/protocol/chainlib/protocol_message.go @@ -31,7 +31,10 @@ func (bpm *BaseProtocolMessage) GetBlockedProviders() []string { } blockedProviders, ok := bpm.directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] if ok { - return strings.Split(blockedProviders, ",") + blockProviders := strings.Split(blockedProviders, ",") + if len(blockProviders) <= 2 { + return blockProviders + } } return nil } diff --git a/protocol/common/endpoints.go b/protocol/common/endpoints.go index b91de7be8c..33de581e75 100644 --- a/protocol/common/endpoints.go +++ b/protocol/common/endpoints.go @@ -25,7 +25,9 @@ const ( PROVIDER_LATEST_BLOCK_HEADER_NAME = "Provider-Latest-Block" GUID_HEADER_NAME = "Lava-Guid" ERRORED_PROVIDERS_HEADER_NAME = "Lava-Errored-Providers" + NODE_ERRORS_PROVIDERS_HEADER_NAME = "Lava-Node-Errors-providers" REPORTED_PROVIDERS_HEADER_NAME = "Lava-Reported-Providers" + LAVAP_VERSION_HEADER_NAME = "Lavap-Version" LAVA_CONSUMER_PROCESS_GUID = "lava-consumer-process-guid" // these headers need to be lowercase BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME = "lava-providers-block" diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go index 0074c74508..9b806a4d69 100644 --- a/protocol/lavasession/consumer_session_manager.go +++ b/protocol/lavasession/consumer_session_manager.go @@ -63,6 +63,12 @@ type ConsumerSessionManager struct { activeSubscriptionProvidersStorage *ActiveSubscriptionProvidersStorage } +func (csm *ConsumerSessionManager) GetNumberOfValidProviders() int { + csm.lock.RLock() + defer csm.lock.RUnlock() + return len(csm.validAddresses) +} + // this is read in multiple locations but never changes, so there is no need to lock.
func (csm *ConsumerSessionManager) RPCEndpoint() RPCEndpoint { return *csm.rpcEndpoint diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index 595038b5a7..6824821a41 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -25,6 +25,7 @@ import ( "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/performance" + "github.com/lavanet/lava/v2/protocol/upgrade" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/protocopy" "github.com/lavanet/lava/v2/utils/rand" @@ -227,7 +228,15 @@ func (rpccs *RPCConsumerServer) sendRelayWithRetries(ctx context.Context, retrie success := false var err error relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(nil), 1, protocolMessage, rpccs.consumerConsistency, "-init-", "", rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) + usedProvidersResets := 1 for i := 0; i < retries; i++ { + // Check if we even have enough providers to communicate with all of them. + // If we have only one provider, we always reset the used providers + // instead of spamming "no pairing available" errors on bootstrap. + if ((i + 1) * usedProvidersResets) > rpccs.consumerSessionManager.GetNumberOfValidProviders() { + usedProvidersResets++ + relayProcessor.GetUsedProviders().ClearUnwanted() + } err = rpccs.sendRelayToProvider(ctx, protocolMessage, "-init-", "", relayProcessor, nil) if lavasession.PairingListEmptyError.Is(err) { // we don't have pairings anymore, could be related to unwanted providers @@ -1405,19 +1414,40 @@ func (rpccs *RPCConsumerServer) appendHeadersToRelayResult(ctx context.Context, relayResult.Reply.Metadata = append(relayResult.Reply.Metadata, erroredProvidersMD) } - currentReportedProviders := rpccs.consumerSessionManager.GetReportedProviders(uint64(relayResult.Request.RelaySession.Epoch)) - if len(currentReportedProviders) > 0 { - reportedProvidersArray := make([]string, len(currentReportedProviders)) - for idx, providerAddress := range currentReportedProviders { - reportedProvidersArray[idx] = providerAddress.Address + nodeErrors := relayProcessor.nodeErrors() + if len(nodeErrors) > 0 { + nodeErrorHeaderString := "" + for _, nodeError := range nodeErrors { + nodeErrorHeaderString += fmt.Sprintf("%s: %s,", nodeError.GetProvider(), string(nodeError.Reply.Data)) } - reportedProvidersString := fmt.Sprintf("%v", reportedProvidersArray) - reportedProvidersMD := pairingtypes.Metadata{ - Name: common.REPORTED_PROVIDERS_HEADER_NAME, - Value: reportedProvidersString, + relayResult.Reply.Metadata = append(relayResult.Reply.Metadata, + pairingtypes.Metadata{ + Name: common.NODE_ERRORS_PROVIDERS_HEADER_NAME, + Value: nodeErrorHeaderString, + }) + } + + if relayResult.Request != nil && relayResult.Request.RelaySession != nil { + currentReportedProviders := rpccs.consumerSessionManager.GetReportedProviders(uint64(relayResult.Request.RelaySession.Epoch)) + if len(currentReportedProviders) > 0 { + reportedProvidersArray := make([]string, len(currentReportedProviders)) + for idx, providerAddress := range currentReportedProviders { + reportedProvidersArray[idx] = providerAddress.Address + } + reportedProvidersString := fmt.Sprintf("%v", reportedProvidersArray) + reportedProvidersMD := pairingtypes.Metadata{ + Name: common.REPORTED_PROVIDERS_HEADER_NAME, + Value: reportedProvidersString, + } + relayResult.Reply.Metadata =
append(relayResult.Reply.Metadata, reportedProvidersMD) } - relayResult.Reply.Metadata = append(relayResult.Reply.Metadata, reportedProvidersMD) } + + version := pairingtypes.Metadata{ + Name: common.LAVAP_VERSION_HEADER_NAME, + Value: upgrade.GetCurrentVersion().ConsumerVersion, + } + relayResult.Reply.Metadata = append(relayResult.Reply.Metadata, version) } relayResult.Reply.Metadata = append(relayResult.Reply.Metadata, metadataReply...)
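With the patch above, diagnostics travel back to the caller as reply metadata. A minimal consumer-side sketch in Go, assuming (not shown in the patches) that reply metadata is surfaced as HTTP response headers on the consumer's REST endpoint from the example config, and that the queried path is served by the chain:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Address taken from the consumer example config; the path is illustrative.
	resp, err := http.Get("http://127.0.0.1:3360/cosmos/base/tendermint/v1beta1/blocks/latest")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Headers introduced or surfaced by this patch series.
	fmt.Println("lavap version:", resp.Header.Get("Lavap-Version"))
	fmt.Println("node errors:", resp.Header.Get("Lava-Node-Errors-providers"))
	fmt.Println("reported providers:", resp.Header.Get("Lava-Reported-Providers"))
}

The values mirror what appendHeadersToRelayResult attaches: the consumer's lavap version, the per-provider node errors accumulated by the relay processor, and the providers reported in the current epoch.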