Skip to content

Commit

Permalink
Refactor IsAuthenticated return values
Browse files Browse the repository at this point in the history
By generating the message ourselves (only relying on strings), we had
problems when the error messages invalidated the JSON parsing due to ""
in the message. Now that we have a type that implements json.Marshaler,
we can avoid doing the error messages manually and should rely on it
instead.
  • Loading branch information
denisonbarbosa committed Jul 1, 2024
1 parent e1ccf92 commit dce370a
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
44 changes: 26 additions & 18 deletions internal/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ func (b *Broker) IsAuthenticated(sessionID, authenticationData string) (string,
defer b.CancelIsAuthenticated(sessionID)

authDone := make(chan struct{})
var access, data string
var access string
var data json.Marshaler
go func() {
access, data = b.handleIsAuthenticated(ctx, &session, authData)
close(authDone)
Expand All @@ -372,15 +373,17 @@ func (b *Broker) IsAuthenticated(sessionID, authenticationData string) (string,
select {
case <-authDone:
case <-ctx.Done():
return AuthCancelled, `{"message": "authentication request cancelled"}`, ctx.Err()
// We can ignore the error here since the message is constant.
msg, _ := json.Marshal(errorMessage{Message: "authentication request cancelled"})
return AuthCancelled, string(msg), ctx.Err()
}

switch access {
case AuthRetry:
session.attemptsPerMode[session.selectedMode]++
if session.attemptsPerMode[session.selectedMode] == maxAuthAttempts {
access = AuthDenied
data = `{"message": "maximum number of attempts reached"}`
data = errorMessage{Message: "maximum number of attempts reached"}
}

case AuthNext:
Expand All @@ -390,14 +393,19 @@ func (b *Broker) IsAuthenticated(sessionID, authenticationData string) (string,
if err = b.updateSession(sessionID, session); err != nil {
return AuthDenied, "", err
}
return access, data, nil

encoded, err := data.MarshalJSON()
if err != nil {
return AuthDenied, "", fmt.Errorf("could not marshal data: %v", err)
}
return access, string(encoded), nil
}

func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo, authData map[string]string) (access, data string) {
func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo, authData map[string]string) (access string, data json.Marshaler) {
// Decrypt challenge if present.
challenge, err := decodeRawChallenge(b.privateKey, authData["challenge"])
if err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not decode challenge: %v"}`, err)
return AuthRetry, errorMessage{Message: fmt.Sprintf("could not decode challenge: %v", err)}
}

var authInfo authCachedInfo
Expand All @@ -406,62 +414,62 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo
case "device_auth":
response, ok := session.authInfo["response"].(*oauth2.DeviceAuthResponse)
if !ok {
return AuthDenied, `{"message": "could not get required response"}`
return AuthDenied, errorMessage{Message: "could not get required response"}
}

t, err := b.auth.oauthCfg.DeviceAccessToken(ctx, response, b.providerInfo.AuthOptions()...)
if err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not authenticate user: %v"}`, err)
return AuthRetry, errorMessage{Message: fmt.Sprintf("could not authenticate user: %v", err)}
}

rawIDToken, ok := t.Extra("id_token").(string)
if !ok {
return AuthDenied, `{"message": "could not get id_token"}`
return AuthDenied, errorMessage{Message: "could not get id_token"}
}

session.authInfo["auth_info"] = authCachedInfo{Token: t, RawIDToken: rawIDToken}
return AuthNext, ""
return AuthNext, errorMessage{}

case "password":
authInfo, offline, err = b.loadAuthInfo(session, challenge)
if err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not authenticate user: %v"}`, err)
return AuthRetry, errorMessage{Message: fmt.Sprintf("could not authenticate user: %v", err)}
}

if session.mode == "passwd" {
session.authInfo["auth_info"] = authInfo
return AuthNext, ""
return AuthNext, errorMessage{}
}

case "newpassword":
if challenge == "" {
return AuthRetry, `{"message": "challenge must not be empty"}`
return AuthRetry, errorMessage{Message: "challenge must not be empty"}
}

var ok bool
// This mode must always come after a authentication mode, so it has to have an auth_info.
authInfo, ok = session.authInfo["auth_info"].(authCachedInfo)
if !ok {
return AuthDenied, `{"message": "could not get required information"}`
return AuthDenied, errorMessage{Message: "could not get required information"}
}
}

if authInfo.UserInfo.Name == "" {
authInfo.UserInfo, err = b.fetchUserInfo(ctx, session, &authInfo)
if err != nil {
return AuthDenied, fmt.Sprintf(`{"message": "could not get user info: %v"}`, err)
return AuthDenied, errorMessage{Message: fmt.Sprintf("could not get user info: %v", err)}
}
}

if offline {
return AuthGranted, fmt.Sprintf(`{"userinfo": %s}`, authInfo.UserInfo)
return AuthGranted, authInfo.UserInfo
}

if err := b.cacheAuthInfo(session, authInfo, challenge); err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not update cached info: %v"}`, err)
return AuthRetry, errorMessage{Message: fmt.Sprintf("could not update cached info: %v", err)}
}

return AuthGranted, fmt.Sprintf(`{"userinfo": %s}`, authInfo.UserInfo)
return AuthGranted, authInfo.UserInfo
}

func (b *Broker) startAuthenticate(sessionID string) (context.Context, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ data: |-
"gecos": "saved-user",
"dir": "/home/[email protected]",
"shell": "/usr/bin/bash",
"groups": [{"name": "saved-remote-group", "gid": "12345"}, {"name": "saved-local-group", "gid": ""}]
"groups": [ {"name": "remote-group", "ugid": "12345"}, {"name": "linux-local-group", "ugid": ""} ]
}}
err: <nil>
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
access: denied
data: '{"message": "could not get user info: could not fetch user info: could not verify token: oidc: failed to unmarshal claims: invalid character ''\u008a'' looking for beginning of value"}'
data: '{"message": "could not get user info: could not fetch user info: could not verify token: oidc: failed to unmarshal claims: invalid character ''\\u008a'' looking for beginning of value"}'
err: <nil>
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
access: retry
data: |-
{"message": "could not authenticate user: could not load cached info: could not refresh token: oauth2: cannot fetch token: 400 Bad Request
Response: "}
data: '{"message": "could not authenticate user: could not load cached info: could not refresh token: oauth2: cannot fetch token: 400 Bad Request\nResponse: "}'
err: <nil>
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
access: retry
data: |-
{"message": "could not authenticate user: oauth2: cannot fetch token: 503 Service Unavailable
Response: "}
data: '{"message": "could not authenticate user: oauth2: cannot fetch token: 503 Service Unavailable\nResponse: "}'
err: <nil>

0 comments on commit dce370a

Please sign in to comment.