Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Override sub with federated_claims.user_id when dex is used #20683

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/argocd/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless"
"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient"
sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session"
settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings"
Expand Down Expand Up @@ -196,7 +197,7 @@ func userDisplayName(claims jwt.MapClaims) string {
if name := jwtutil.StringField(claims, "name"); name != "" {
return name
}
return jwtutil.StringField(claims, "sub")
return utils.GetUserIdentifier(claims)
}

// oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and
Expand Down
2 changes: 1 addition & 1 deletion cmd/argocd/commands/project_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ Create token succeeded for proj:test-project:test-role.
issuedAt, _ := jwt.IssuedAt(claims)
expiresAt := int64(jwt.Float64Field(claims, "exp"))
id := jwt.StringField(claims, "jti")
subject := jwt.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)

if !outputTokenOnly {
fmt.Printf("Create token succeeded for %s.\n", subject)
Expand Down
19 changes: 19 additions & 0 deletions cmd/argocd/commands/utils/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package utils

import (
"github.com/golang-jwt/jwt/v4"
)

// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use
func GetUserIdentifier(claims jwt.MapClaims) string {
if federatedClaims, ok := claims["federated_claims"].(map[string]interface{}); ok {
if userID, exists := federatedClaims["user_id"].(string); exists && userID != "" {
return userID
}
}
// Fallback to sub
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
return ""
}
3 changes: 2 additions & 1 deletion server/rbacpolicy/rbacpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -114,7 +115,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...interface
return false
}

subject := jwtutil.StringField(mapClaims, "sub")
subject := utils.GetUserIdentifier(mapClaims)
// Check if the request is for an application resource. We have special enforcement which takes
// into consideration the project's token and group bindings
var runtimePolicy string
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/apiclient"
accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account"
Expand Down Expand Up @@ -1417,7 +1418,7 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error
log.Errorf("error fetching user info endpoint: %v", err)
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
}
if groupClaims["sub"] != userInfo["sub"] {
if utils.GetUserIdentifier(groupClaims) != utils.GetUserIdentifier(userInfo) {
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
}
groupClaims["groups"] = userInfo["groups"]
Expand Down
23 changes: 14 additions & 9 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/server/settings/oidc"
"github.com/argoproj/argo-cd/v2/util/cache"
Expand Down Expand Up @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
return
}
sub := jwtutil.StringField(claims, "sub")
err = a.clientCache.Set(&cache.Item{
Key: formatAccessTokenCacheKey(sub),
Key: formatAccessTokenCacheKey(claims),
Object: encToken,
CacheActionOpts: cache.CacheActionOpts{
Expiration: getTokenExpiration(claims),
Expand Down Expand Up @@ -552,12 +552,12 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc

// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
sub := jwtutil.StringField(actualClaims, "sub")
sub := utils.GetUserIdentifier(actualClaims)
var claims jwt.MapClaims
var encClaims []byte

// in case we got it in the cache, we just return the item
clientCacheKey := formatUserInfoResponseCacheKey(sub)
clientCacheKey := formatUserInfoResponseCacheKey(actualClaims)
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
if err != nil {
Expand All @@ -575,7 +575,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP

// check if the accessToken for the user is still present
var encAccessToken []byte
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken)
// without an accessToken we can't query the user info endpoint
// thus the user needs to reauthenticate for argocd to get a new accessToken
if errors.Is(err, cache.ErrCacheMiss) {
Expand Down Expand Up @@ -607,6 +607,9 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
if response.StatusCode == http.StatusUnauthorized {
return claims, true, err
}
if response.StatusCode == http.StatusNotFound {
return jwt.MapClaims{}, true, fmt.Errorf("user info path not found: %s", userInfoPath)
}

// according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation
// the response should be validated
Expand Down Expand Up @@ -684,11 +687,13 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
}

// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
func formatUserInfoResponseCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string {
userID := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, userID)
}

// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
func formatAccessTokenCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
func formatAccessTokenCacheKey(claims jwt.MapClaims) string {
userID := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, userID)
}
Loading
Loading