Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: PRT - adding stickiness header #1942

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions protocol/common/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
FORCE_CACHE_REFRESH_HEADER_NAME = "lava-force-cache-refresh"
LAVA_DEBUG_RELAY = "lava-debug-relay"
LAVA_LB_UNIQUE_ID_HEADER = "lava-lb-unique-id"
STICKINESS_HEADER_NAME = "lava-stickiness"
// send http request to /lava/health to see if the process is up - (ret code 200)
DEFAULT_HEALTH_PATH = "/lava/health"
MAXIMUM_ALLOWED_TIMEOUT_EXTEND_MULTIPLIER_BY_THE_CONSUMER = 4
Expand All @@ -51,6 +52,7 @@ var SPECIAL_LAVA_DIRECTIVE_HEADERS = map[string]struct{}{
EXTENSION_OVERRIDE_HEADER_NAME: {},
FORCE_CACHE_REFRESH_HEADER_NAME: {},
LAVA_DEBUG_RELAY: {},
STICKINESS_HEADER_NAME: {},
}

type UserData struct {
Expand Down
55 changes: 44 additions & 11 deletions protocol/lavasession/consumer_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lavasession
import (
"context"
"fmt"
"slices"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -36,6 +37,7 @@ type ConsumerSessionManager struct {
rpcEndpoint *RPCEndpoint // used to filter out endpoints
lock sync.RWMutex
pairing map[string]*ConsumerSessionsWithProvider // key == provider address
stickySessions *StickySessionStore
currentEpoch uint64
numberOfResets uint64

Expand Down Expand Up @@ -79,6 +81,7 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
pairingListLength := len(pairingList)
// TODO: we can block updating until some of the probing is done, this can prevent failed attempts on epoch change when we have no information on the providers,
// and all of them are new (less effective on big pairing lists or a process that runs for a few epochs)

defer func() {
// run this after done updating pairing
time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) // sleep up to 500ms in order to scatter different chains probe triggers
Expand All @@ -88,7 +91,8 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
csm.lock.Lock() // start by locking the class lock.
defer csm.lock.Unlock() // we defer here so in case we return an error it will unlock automatically.

if epoch <= csm.atomicReadCurrentEpoch() { // sentry shouldn't update an old epoch or current epoch
previousEpoch := csm.atomicReadCurrentEpoch()
if epoch <= previousEpoch { // sentry shouldn't update an old epoch or current epoch
return utils.LavaFormatError("trying to update provider list for older epoch", nil, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "currentEpoch", Value: csm.atomicReadCurrentEpoch()})
}
// Update Epoch.
Expand Down Expand Up @@ -117,6 +121,9 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
go csm.consumerMetricsManager.ResetSessionRelatedMetrics()
go csm.providerOptimizer.UpdateWeights(CalcWeightsByStake(pairingList), epoch)

// Clean up expired sticky sessions
csm.stickySessions.DeleteOldSessions(previousEpoch)

utils.LavaFormatDebug("updated providers", utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "spec", Value: csm.rpcEndpoint.Key()})
return nil
}
Expand Down Expand Up @@ -402,8 +409,8 @@ func (csm *ConsumerSessionManager) getValidAddressesLengthForExtensionOrAddon(ad
return len(csm.getValidAddresses(addon, extensions))
}

func (csm *ConsumerSessionManager) getSessionWithProviderOrError(usedProviders UsedProvidersInf, tempIgnoredProviders *ignoredProviders, cuNeededForSession uint64, requestedBlock int64, addon string, extensionNames []string, stateful uint32, virtualEpoch uint64) (sessionWithProviderMap SessionWithProviderMap, err error) {
sessionWithProviderMap, err = csm.getValidConsumerSessionsWithProvider(tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch)
func (csm *ConsumerSessionManager) getSessionWithProviderOrError(usedProviders UsedProvidersInf, tempIgnoredProviders *ignoredProviders, cuNeededForSession uint64, requestedBlock int64, addon string, extensionNames []string, stateful uint32, virtualEpoch uint64, stickiness string) (sessionWithProviderMap SessionWithProviderMap, err error) {
sessionWithProviderMap, err = csm.getValidConsumerSessionsWithProvider(tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch, stickiness)
if err != nil {
if PairingListEmptyError.Is(err) {
// got no pairing available, try to recover a session from the currently banned providers
Expand All @@ -422,7 +429,7 @@ func (csm *ConsumerSessionManager) getSessionWithProviderOrError(usedProviders U

// GetSessions will return a ConsumerSession, given cu needed for that session.
// The user can also request specific providers to not be included in the search for a session.
func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForSession uint64, usedProviders UsedProvidersInf, requestedBlock int64, addon string, extensions []*spectypes.Extension, stateful uint32, virtualEpoch uint64) (
func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForSession uint64, usedProviders UsedProvidersInf, requestedBlock int64, addon string, extensions []*spectypes.Extension, stateful uint32, virtualEpoch uint64, stickiness string) (
consumerSessionMap ConsumerSessionsMap, errRet error,
) {
// set usedProviders if they were chosen for this relay
Expand Down Expand Up @@ -450,7 +457,7 @@ func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForS
}

// Get a valid consumerSessionsWithProvider
sessionWithProviderMap, err := csm.getSessionWithProviderOrError(usedProviders, tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch)
sessionWithProviderMap, err := csm.getSessionWithProviderOrError(usedProviders, tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch, stickiness)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -582,7 +589,7 @@ func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForS
}

// If we do not have enough, fetch more
sessionWithProviderMap, err = csm.getSessionWithProviderOrError(usedProviders, tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch)
sessionWithProviderMap, err = csm.getSessionWithProviderOrError(usedProviders, tempIgnoredProviders, cuNeededForSession, requestedBlock, addon, extensionNames, stateful, virtualEpoch, stickiness)
// If error exists but we have sessions, return them
if err != nil && len(sessions) != 0 {
return sessions, nil
Expand Down Expand Up @@ -620,12 +627,26 @@ func (csm *ConsumerSessionManager) getTopTenProvidersForStatefulCalls(validAddre
}

// Get a valid provider address.
func (csm *ConsumerSessionManager) getValidProviderAddresses(ignoredProvidersList map[string]struct{}, cu uint64, requestedBlock int64, addon string, extensions []string, stateful uint32) (addresses []string, err error) {
func (csm *ConsumerSessionManager) getValidProviderAddresses(ignoredProvidersList map[string]struct{}, cu uint64, requestedBlock int64, addon string, extensions []string, stateful uint32, stickiness string) (addresses []string, err error) {
// csm.lock must be RLocked here.
ignoredProvidersListLength := len(ignoredProvidersList)
validAddresses := csm.getValidAddresses(addon, extensions)
validAddressesLength := len(validAddresses)
totalValidLength := validAddressesLength - ignoredProvidersListLength

if stickysession, ok := csm.stickySessions.Get(stickiness); ok {
// Check if sticky session provider is still valid
providerValid := slices.Contains(validAddresses, stickysession.Provider)
if providerValid {
addresses = []string{stickysession.Provider}
utils.LavaFormatTrace("returning sticky session", utils.LogAttr("provider", stickysession.Provider), utils.LogAttr("id", stickiness))
return addresses, nil
} else {
utils.LavaFormatTrace("sticky session provider is no longer valid, deleting", utils.LogAttr("provider", stickysession.Provider), utils.LogAttr("id", stickiness))
csm.stickySessions.Delete(stickiness)
}
}

if totalValidLength <= 0 {
// check all ignored are actually valid addresses
ignoredProvidersListLength = 0
Expand All @@ -643,6 +664,8 @@ func (csm *ConsumerSessionManager) getValidProviderAddresses(ignoredProvidersLis
var providers []string
if stateful == common.CONSISTENCY_SELECT_ALL_PROVIDERS && csm.providerOptimizer.Strategy() != provideroptimizer.StrategyCost {
providers = csm.getTopTenProvidersForStatefulCalls(validAddresses, ignoredProvidersList)
} else if stickiness != "" {
providers = csm.providerOptimizer.ChooseProviderFromTopTier(validAddresses, ignoredProvidersList, cu, requestedBlock)
} else {
providers, _ = csm.providerOptimizer.ChooseProvider(validAddresses, ignoredProvidersList, cu, requestedBlock)
}
Expand All @@ -663,6 +686,15 @@ func (csm *ConsumerSessionManager) getValidProviderAddresses(ignoredProvidersLis
return addresses, err
}

// If stickiness is requested, store the first provider for future use
if stickiness != "" {
utils.LavaFormatTrace("setting sticky session", utils.LogAttr("provider", providers[0]), utils.LogAttr("id", stickiness))
csm.stickySessions.Set(stickiness, &StickySession{
Provider: providers[0],
Epoch: csm.atomicReadCurrentEpoch(),
})
return []string{providers[0]}, nil
}
return providers, nil
}

Expand All @@ -685,7 +717,7 @@ func (csm *ConsumerSessionManager) tryGetConsumerSessionWithProviderFromBlockedP
utils.LavaFormatDebug("Epoch changed between getValidConsumerSessionsWithProvider to tryGetConsumerSessionWithProviderFromBlockedProviderList getting pairing from new epoch list")
}
csm.lock.RUnlock() // unlock because getValidConsumerSessionsWithProvider is locking.
return csm.getValidConsumerSessionsWithProvider(ignoredProviders, cuNeededForSession, requestedBlock, addon, extensions, stateful, virtualEpoch)
return csm.getValidConsumerSessionsWithProvider(ignoredProviders, cuNeededForSession, requestedBlock, addon, extensions, stateful, virtualEpoch, "")
}

// if we got here we validated the epoch is still the same epoch as we expected and we need to fetch a session from the blocked provider list.
Expand Down Expand Up @@ -729,7 +761,7 @@ func (csm *ConsumerSessionManager) tryGetConsumerSessionWithProviderFromBlockedP
return nil, utils.LavaFormatError(csm.rpcEndpoint.ChainID+" could not get a provider address from blocked provider list", PairingListEmptyError, utils.LogAttr("csm.currentlyBlockedProviderAddresses", csm.currentlyBlockedProviderAddresses), utils.LogAttr("addons", addon), utils.LogAttr("extensions", extensions), utils.LogAttr("ignoredProviders", ignoredProviders.providers))
}

func (csm *ConsumerSessionManager) getValidConsumerSessionsWithProvider(ignoredProviders *ignoredProviders, cuNeededForSession uint64, requestedBlock int64, addon string, extensions []string, stateful uint32, virtualEpoch uint64) (sessionWithProviderMap SessionWithProviderMap, err error) {
func (csm *ConsumerSessionManager) getValidConsumerSessionsWithProvider(ignoredProviders *ignoredProviders, cuNeededForSession uint64, requestedBlock int64, addon string, extensions []string, stateful uint32, virtualEpoch uint64, stickiness string) (sessionWithProviderMap SessionWithProviderMap, err error) {
csm.lock.RLock()
defer csm.lock.RUnlock()

Expand All @@ -743,7 +775,7 @@ func (csm *ConsumerSessionManager) getValidConsumerSessionsWithProvider(ignoredP
}

// Fetch provider addresses
providerAddresses, err := csm.getValidProviderAddresses(ignoredProviders.providers, cuNeededForSession, requestedBlock, addon, extensions, stateful)
providerAddresses, err := csm.getValidProviderAddresses(ignoredProviders.providers, cuNeededForSession, requestedBlock, addon, extensions, stateful, stickiness)
if err != nil {
utils.LavaFormatDebug(csm.rpcEndpoint.ChainID+" could not get a provider addresses", utils.LogAttr("error", err))
return nil, err
Expand Down Expand Up @@ -792,7 +824,7 @@ func (csm *ConsumerSessionManager) getValidConsumerSessionsWithProvider(ignoredP
}

// If we do not have enough, fetch more
providerAddresses, err = csm.getValidProviderAddresses(ignoredProviders.providers, cuNeededForSession, requestedBlock, addon, extensions, stateful)
providerAddresses, err = csm.getValidProviderAddresses(ignoredProviders.providers, cuNeededForSession, requestedBlock, addon, extensions, stateful, stickiness)

// If error exists but we have providers, return them
if err != nil && len(sessionWithProviderMap) != 0 {
Expand Down Expand Up @@ -1158,5 +1190,6 @@ func NewConsumerSessionManager(
csm.rpcEndpoint = rpcEndpoint
csm.providerOptimizer = providerOptimizer
csm.activeSubscriptionProvidersStorage = activeSubscriptionProvidersStorage
csm.stickySessions = NewStickySessionStore()
return csm
}
Loading
Loading