Skip to content

Commit

Permalink
Tighter sync strategy for the broker
Browse files Browse the repository at this point in the history
We were not locking the mutexes when reading from the maps and this
could result in races as well.

Now we also check if the session still exists before modifying the map,
otherwise we would just recreate it accidentaly

Some variables were renamed for more consistency between the functions
and their meaning
  • Loading branch information
denisonbarbosa committed Aug 1, 2023
1 parent b9c9dde commit 0ff9466
Showing 1 changed file with 55 additions and 26 deletions.
81 changes: 55 additions & 26 deletions internal/brokers/examplebroker.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type isAuthorizedCtx struct {

type exampleBroker struct {
currentSessions map[string]sessionInfo
currentSessionsMu sync.Mutex
currentSessionsMu sync.RWMutex
userLastSelectedMode map[string]string
userLastSelectedModeMu sync.Mutex
isAuthorizedCalls map[string]isAuthorizedCtx
Expand Down Expand Up @@ -77,7 +77,7 @@ const (
func newExampleBroker(name string) (b *exampleBroker, fullName, brandIcon string) {
return &exampleBroker{
currentSessions: make(map[string]sessionInfo),
currentSessionsMu: sync.Mutex{},
currentSessionsMu: sync.RWMutex{},
userLastSelectedMode: make(map[string]string),
userLastSelectedModeMu: sync.Mutex{},
isAuthorizedCalls: make(map[string]isAuthorizedCtx),
Expand All @@ -99,9 +99,9 @@ func (b *exampleBroker) NewSession(ctx context.Context, username, lang string) (

// GetAuthenticationModes returns the list of supported authentication modes for the selected broker depending on session info.
func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID string, supportedUILayouts []map[string]string) (authenticationModes []map[string]string, err error) {
session, ok := b.currentSessions[sessionID]
if !ok {
return nil, fmt.Errorf("%q is not a current transaction", sessionID)
sessionInfo, err := b.sessionInfo(sessionID)
if err != nil {
return nil, err
}

//var candidatesAuthenticationModes []map[string]string
Expand Down Expand Up @@ -132,12 +132,12 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st
}
}
if slices.Contains(supportedEntries, "chars") && layout["wait"] != "" {
allModes[fmt.Sprintf("entry_or_wait_for_%s_gmail.com", session.username)] = map[string]string{
"selection_label": fmt.Sprintf("Send URL to %[email protected]", session.username),
"email": fmt.Sprintf("%[email protected]", session.username),
allModes[fmt.Sprintf("entry_or_wait_for_%s_gmail.com", sessionInfo.username)] = map[string]string{
"selection_label": fmt.Sprintf("Send URL to %[email protected]", sessionInfo.username),
"email": fmt.Sprintf("%[email protected]", sessionInfo.username),
"ui": mapToJSON(map[string]string{
"type": "form",
"label": fmt.Sprintf("Click on the link received at %[email protected] or enter the code:", session.username),
"label": fmt.Sprintf("Click on the link received at %[email protected] or enter the code:", sessionInfo.username),
"entry": "chars",
"wait": "true",
}),
Expand Down Expand Up @@ -216,10 +216,12 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st
}

// Sort in preference order. We want by default password as first and potentially last selection too.
lastSelection := b.userLastSelectedMode[session.username]
b.userLastSelectedModeMu.Lock()
lastSelection := b.userLastSelectedMode[sessionInfo.username]
if _, exists := allModes[lastSelection]; !exists {
lastSelection = ""
}
b.userLastSelectedModeMu.Unlock()

var allNames []string
for n := range allModes {
Expand All @@ -242,20 +244,25 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st
"label": authMode["selection_label"],
})
}
session.allModes = allModes
sessionInfo.allModes = allModes

// Checks if the session was ended in the meantime, otherwise we would just accidentally create a new one.
if _, err := b.sessionInfo(sessionID); err != nil {
return nil, err
}

b.currentSessionsMu.Lock()
b.currentSessions[sessionID] = session
b.currentSessionsMu.Unlock()
defer b.currentSessionsMu.Unlock()
b.currentSessions[sessionID] = sessionInfo

return authenticationModes, nil
}

func (b *exampleBroker) SelectAuthenticationMode(ctx context.Context, sessionID, authenticationModeName string) (uiLayoutInfo map[string]string, err error) {
// Ensure session ID is an active one.
sessionInfo, inprogress := b.currentSessions[sessionID]
if !inprogress {
return nil, fmt.Errorf("%s is not a current transaction", sessionID)
sessionInfo, err := b.sessionInfo(sessionID)
if err != nil {
return nil, err
}

authenticationMode, exists := sessionInfo.allModes[authenticationModeName]
Expand All @@ -282,18 +289,24 @@ func (b *exampleBroker) SelectAuthenticationMode(ctx context.Context, sessionID,

// Store selected mode
sessionInfo.selectedMode = authenticationModeName

// Checks if the session was ended in the meantime, otherwise we would just accidentally create a new one.
if _, err = b.sessionInfo(sessionID); err != nil {
return nil, err
}

b.currentSessionsMu.Lock()
defer b.currentSessionsMu.Unlock()
b.currentSessions[sessionID] = sessionInfo
b.currentSessionsMu.Unlock()

return uiLayoutInfo, nil
}

// IsAuthorized evaluates the provided authenticationData and returns the authorisation level of the user.
func (b *exampleBroker) IsAuthorized(ctx context.Context, sessionID, authenticationData string) (access, infoUser string, err error) {
sessionInfo, inprogress := b.currentSessions[sessionID]
if !inprogress {
return "", "", fmt.Errorf("%s is not a current transaction", sessionID)
sessionInfo, err := b.sessionInfo(sessionID)
if err != nil {
return "", "", err
}

//authenticationData = decryptAES([]byte(brokerEncryptionKey), authenticationData)
Expand Down Expand Up @@ -433,26 +446,31 @@ func (b *exampleBroker) handleIsAuthorized(ctx context.Context, sessionInfo sess

// EndSession ends the requested session and triggers the necessary clean up steps, if any.
func (b *exampleBroker) EndSession(ctx context.Context, sessionID string) error {
if _, exists := b.currentSessions[sessionID]; !exists {
return fmt.Errorf("%q is not an active session", sessionID)
if _, err := b.sessionInfo(sessionID); err != nil {
return err
}

// Checks if there is a isAuthorizedCall running for this session and cancels it before ending the session.
if _, exists := b.isAuthorizedCalls[sessionID]; exists {
b.CancelIsAuthorized(ctx, sessionID)
}

b.currentSessionsMu.Lock()
defer b.currentSessionsMu.Unlock()
delete(b.currentSessions, sessionID)
b.currentSessionsMu.Unlock()
return nil
}

// CancelIsAuthorized cancels the IsAuthorized request for the specified session.
// If there is no pending IsAuthorized call for the session, this is a no-op.
func (b *exampleBroker) CancelIsAuthorized(ctx context.Context, sessionID string) {
b.isAuthorizedCallsMu.Lock()
defer b.isAuthorizedCallsMu.Unlock()
if _, exists := b.isAuthorizedCalls[sessionID]; !exists {
return
}
b.isAuthorizedCalls[sessionID].cancelFunc()

b.isAuthorizedCallsMu.Lock()
delete(b.isAuthorizedCalls, sessionID)
b.isAuthorizedCallsMu.Unlock()
}

func mapToJSON(input map[string]string) string {
Expand Down Expand Up @@ -514,3 +532,14 @@ func decryptAES(key []byte, ct string) string {

return string(pt[:])
}

// sessionInfo returns the session information for the specified session ID or an error if the session is not active.
func (b *exampleBroker) sessionInfo(sessionID string) (sessionInfo, error) {
b.currentSessionsMu.RLock()
defer b.currentSessionsMu.RUnlock()
session, active := b.currentSessions[sessionID]
if !active {
return sessionInfo{}, fmt.Errorf("%s is not a current transaction", sessionID)
}
return session, nil
}

0 comments on commit 0ff9466

Please sign in to comment.