diff --git a/protocol/rpcconsumer/policies_map.go b/protocol/rpcconsumer/policies_map.go index 543b71330b..d70d2de3da 100644 --- a/protocol/rpcconsumer/policies_map.go +++ b/protocol/rpcconsumer/policies_map.go @@ -26,3 +26,22 @@ func (sm *syncMapPolicyUpdaters) Load(key string) (ret *updaters.PolicyUpdater, } return ret, true } + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +// The function returns the value that was loaded or stored. +func (sm *syncMapPolicyUpdaters) LoadOrStore(key string, value *updaters.PolicyUpdater) (ret *updaters.PolicyUpdater, loaded bool) { + actual, loaded := sm.localMap.LoadOrStore(key, value) + if loaded { + // loaded from map + ret, loaded = actual.(*updaters.PolicyUpdater) + if !loaded { + utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil) + } + return ret, loaded + } + + // stored in map + return value, false +} diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index cd5e9eb4e7..67f4a24461 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -205,14 +205,13 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt } chainID := rpcEndpoint.ChainID // create policyUpdaters per chain - if policyUpdater, ok := policyUpdaters.Load(rpcEndpoint.ChainID); ok { + newPolicyUpdater := updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint) + if policyUpdater, ok := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater); ok { err := policyUpdater.AddPolicySetter(chainParser, *rpcEndpoint) if err != nil { errCh <- err return utils.LavaFormatError("failed adding policy setter", err) } - } else { - policyUpdaters.Store(rpcEndpoint.ChainID, updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint)) } err = statetracker.RegisterForSpecUpdatesOrSetStaticSpec(ctx, chainParser, options.cmdFlags.StaticSpecPath, *rpcEndpoint, rpcc.consumerStateTracker)