-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tighter sync strategy for the broker
We were not locking the mutexes when reading from the maps and this could result in races as well. Now we also check if the session still exists before modifying the map, otherwise we would just recreate it accidentaly Some variables were renamed for more consistency between the functions and their meaning
- Loading branch information
1 parent
b9c9dde
commit 0ff9466
Showing
1 changed file
with
55 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ type isAuthorizedCtx struct { | |
|
||
type exampleBroker struct { | ||
currentSessions map[string]sessionInfo | ||
currentSessionsMu sync.Mutex | ||
currentSessionsMu sync.RWMutex | ||
userLastSelectedMode map[string]string | ||
userLastSelectedModeMu sync.Mutex | ||
isAuthorizedCalls map[string]isAuthorizedCtx | ||
|
@@ -77,7 +77,7 @@ const ( | |
func newExampleBroker(name string) (b *exampleBroker, fullName, brandIcon string) { | ||
return &exampleBroker{ | ||
currentSessions: make(map[string]sessionInfo), | ||
currentSessionsMu: sync.Mutex{}, | ||
currentSessionsMu: sync.RWMutex{}, | ||
userLastSelectedMode: make(map[string]string), | ||
userLastSelectedModeMu: sync.Mutex{}, | ||
isAuthorizedCalls: make(map[string]isAuthorizedCtx), | ||
|
@@ -99,9 +99,9 @@ func (b *exampleBroker) NewSession(ctx context.Context, username, lang string) ( | |
|
||
// GetAuthenticationModes returns the list of supported authentication modes for the selected broker depending on session info. | ||
func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID string, supportedUILayouts []map[string]string) (authenticationModes []map[string]string, err error) { | ||
session, ok := b.currentSessions[sessionID] | ||
if !ok { | ||
return nil, fmt.Errorf("%q is not a current transaction", sessionID) | ||
sessionInfo, err := b.sessionInfo(sessionID) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
//var candidatesAuthenticationModes []map[string]string | ||
|
@@ -132,12 +132,12 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st | |
} | ||
} | ||
if slices.Contains(supportedEntries, "chars") && layout["wait"] != "" { | ||
allModes[fmt.Sprintf("entry_or_wait_for_%s_gmail.com", session.username)] = map[string]string{ | ||
"selection_label": fmt.Sprintf("Send URL to %[email protected]", session.username), | ||
"email": fmt.Sprintf("%[email protected]", session.username), | ||
allModes[fmt.Sprintf("entry_or_wait_for_%s_gmail.com", sessionInfo.username)] = map[string]string{ | ||
"selection_label": fmt.Sprintf("Send URL to %[email protected]", sessionInfo.username), | ||
"email": fmt.Sprintf("%[email protected]", sessionInfo.username), | ||
"ui": mapToJSON(map[string]string{ | ||
"type": "form", | ||
"label": fmt.Sprintf("Click on the link received at %[email protected] or enter the code:", session.username), | ||
"label": fmt.Sprintf("Click on the link received at %[email protected] or enter the code:", sessionInfo.username), | ||
"entry": "chars", | ||
"wait": "true", | ||
}), | ||
|
@@ -216,10 +216,12 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st | |
} | ||
|
||
// Sort in preference order. We want by default password as first and potentially last selection too. | ||
lastSelection := b.userLastSelectedMode[session.username] | ||
b.userLastSelectedModeMu.Lock() | ||
lastSelection := b.userLastSelectedMode[sessionInfo.username] | ||
if _, exists := allModes[lastSelection]; !exists { | ||
lastSelection = "" | ||
} | ||
b.userLastSelectedModeMu.Unlock() | ||
|
||
var allNames []string | ||
for n := range allModes { | ||
|
@@ -242,20 +244,25 @@ func (b *exampleBroker) GetAuthenticationModes(ctx context.Context, sessionID st | |
"label": authMode["selection_label"], | ||
}) | ||
} | ||
session.allModes = allModes | ||
sessionInfo.allModes = allModes | ||
|
||
// Checks if the session was ended in the meantime, otherwise we would just accidentally create a new one. | ||
if _, err := b.sessionInfo(sessionID); err != nil { | ||
return nil, err | ||
} | ||
|
||
b.currentSessionsMu.Lock() | ||
b.currentSessions[sessionID] = session | ||
b.currentSessionsMu.Unlock() | ||
defer b.currentSessionsMu.Unlock() | ||
b.currentSessions[sessionID] = sessionInfo | ||
|
||
return authenticationModes, nil | ||
} | ||
|
||
func (b *exampleBroker) SelectAuthenticationMode(ctx context.Context, sessionID, authenticationModeName string) (uiLayoutInfo map[string]string, err error) { | ||
// Ensure session ID is an active one. | ||
sessionInfo, inprogress := b.currentSessions[sessionID] | ||
if !inprogress { | ||
return nil, fmt.Errorf("%s is not a current transaction", sessionID) | ||
sessionInfo, err := b.sessionInfo(sessionID) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
authenticationMode, exists := sessionInfo.allModes[authenticationModeName] | ||
|
@@ -282,18 +289,24 @@ func (b *exampleBroker) SelectAuthenticationMode(ctx context.Context, sessionID, | |
|
||
// Store selected mode | ||
sessionInfo.selectedMode = authenticationModeName | ||
|
||
// Checks if the session was ended in the meantime, otherwise we would just accidentally create a new one. | ||
if _, err = b.sessionInfo(sessionID); err != nil { | ||
return nil, err | ||
} | ||
|
||
b.currentSessionsMu.Lock() | ||
defer b.currentSessionsMu.Unlock() | ||
b.currentSessions[sessionID] = sessionInfo | ||
b.currentSessionsMu.Unlock() | ||
|
||
return uiLayoutInfo, nil | ||
} | ||
|
||
// IsAuthorized evaluates the provided authenticationData and returns the authorisation level of the user. | ||
func (b *exampleBroker) IsAuthorized(ctx context.Context, sessionID, authenticationData string) (access, infoUser string, err error) { | ||
sessionInfo, inprogress := b.currentSessions[sessionID] | ||
if !inprogress { | ||
return "", "", fmt.Errorf("%s is not a current transaction", sessionID) | ||
sessionInfo, err := b.sessionInfo(sessionID) | ||
if err != nil { | ||
return "", "", err | ||
} | ||
|
||
//authenticationData = decryptAES([]byte(brokerEncryptionKey), authenticationData) | ||
|
@@ -433,26 +446,31 @@ func (b *exampleBroker) handleIsAuthorized(ctx context.Context, sessionInfo sess | |
|
||
// EndSession ends the requested session and triggers the necessary clean up steps, if any. | ||
func (b *exampleBroker) EndSession(ctx context.Context, sessionID string) error { | ||
if _, exists := b.currentSessions[sessionID]; !exists { | ||
return fmt.Errorf("%q is not an active session", sessionID) | ||
if _, err := b.sessionInfo(sessionID); err != nil { | ||
return err | ||
} | ||
|
||
// Checks if there is a isAuthorizedCall running for this session and cancels it before ending the session. | ||
if _, exists := b.isAuthorizedCalls[sessionID]; exists { | ||
b.CancelIsAuthorized(ctx, sessionID) | ||
} | ||
|
||
b.currentSessionsMu.Lock() | ||
defer b.currentSessionsMu.Unlock() | ||
delete(b.currentSessions, sessionID) | ||
b.currentSessionsMu.Unlock() | ||
return nil | ||
} | ||
|
||
// CancelIsAuthorized cancels the IsAuthorized request for the specified session. | ||
// If there is no pending IsAuthorized call for the session, this is a no-op. | ||
func (b *exampleBroker) CancelIsAuthorized(ctx context.Context, sessionID string) { | ||
b.isAuthorizedCallsMu.Lock() | ||
defer b.isAuthorizedCallsMu.Unlock() | ||
if _, exists := b.isAuthorizedCalls[sessionID]; !exists { | ||
return | ||
} | ||
b.isAuthorizedCalls[sessionID].cancelFunc() | ||
|
||
b.isAuthorizedCallsMu.Lock() | ||
delete(b.isAuthorizedCalls, sessionID) | ||
b.isAuthorizedCallsMu.Unlock() | ||
} | ||
|
||
func mapToJSON(input map[string]string) string { | ||
|
@@ -514,3 +532,14 @@ func decryptAES(key []byte, ct string) string { | |
|
||
return string(pt[:]) | ||
} | ||
|
||
// sessionInfo returns the session information for the specified session ID or an error if the session is not active. | ||
func (b *exampleBroker) sessionInfo(sessionID string) (sessionInfo, error) { | ||
b.currentSessionsMu.RLock() | ||
defer b.currentSessionsMu.RUnlock() | ||
session, active := b.currentSessions[sessionID] | ||
if !active { | ||
return sessionInfo{}, fmt.Errorf("%s is not a current transaction", sessionID) | ||
} | ||
return session, nil | ||
} |