Skip to content

Commit

Permalink
override sub with federated_claims.user_id when dex is used
Browse files Browse the repository at this point in the history
Signed-off-by: Atif Ali <[email protected]>
  • Loading branch information
aali309 committed Nov 6, 2024
1 parent 17c412e commit be4821c
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 23 deletions.
3 changes: 2 additions & 1 deletion cmd/argocd/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless"
"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient"
sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session"
settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings"
Expand Down Expand Up @@ -196,7 +197,7 @@ func userDisplayName(claims jwt.MapClaims) string {
if name := jwtutil.StringField(claims, "name"); name != "" {
return name
}
return jwtutil.StringField(claims, "sub")
return utils.GetUserIdentifier(claims)
}

// oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and
Expand Down
2 changes: 1 addition & 1 deletion cmd/argocd/commands/project_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ Create token succeeded for proj:test-project:test-role.
issuedAt, _ := jwt.IssuedAt(claims)
expiresAt := int64(jwt.Float64Field(claims, "exp"))
id := jwt.StringField(claims, "jti")
subject := jwt.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)

if !outputTokenOnly {
fmt.Printf("Create token succeeded for %s.\n", subject)
Expand Down
22 changes: 22 additions & 0 deletions cmd/argocd/commands/utils/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package utils

import (
"github.com/golang-jwt/jwt/v4"
)

// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use
func GetUserIdentifier(claims jwt.MapClaims) string {
// Check for federated_claims.user_id if Dex is used
if federatedClaims, ok := claims["federated_claims"].(map[string]interface{}); ok {
if userID, exists := federatedClaims["user_id"].(string); exists {
return userID
}
}

// Fallback to sub
if sub, ok := claims["sub"].(string); ok {
return sub
}
return ""

}
3 changes: 2 additions & 1 deletion server/rbacpolicy/rbacpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -114,7 +115,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...interface
return false
}

subject := jwtutil.StringField(mapClaims, "sub")
subject := utils.GetUserIdentifier(mapClaims)
// Check if the request is for an application resource. We have special enforcement which takes
// into consideration the project's token and group bindings
var runtimePolicy string
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/apiclient"
accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account"
Expand Down Expand Up @@ -1417,7 +1418,7 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error
log.Errorf("error fetching user info endpoint: %v", err)
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
}
if groupClaims["sub"] != userInfo["sub"] {
if utils.GetUserIdentifier(groupClaims) != utils.GetUserIdentifier(userInfo) {
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
}
groupClaims["groups"] = userInfo["groups"]
Expand Down
16 changes: 9 additions & 7 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/server/settings/oidc"
"github.com/argoproj/argo-cd/v2/util/cache"
Expand Down Expand Up @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
return
}
sub := jwtutil.StringField(claims, "sub")
err = a.clientCache.Set(&cache.Item{
Key: formatAccessTokenCacheKey(sub),
Key: formatAccessTokenCacheKey(claims),
Object: encToken,
CacheActionOpts: cache.CacheActionOpts{
Expiration: getTokenExpiration(claims),
Expand Down Expand Up @@ -552,12 +552,12 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc

// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
sub := jwtutil.StringField(actualClaims, "sub")
sub := utils.GetUserIdentifier(actualClaims)
var claims jwt.MapClaims
var encClaims []byte

// in case we got it in the cache, we just return the item
clientCacheKey := formatUserInfoResponseCacheKey(sub)
clientCacheKey := formatUserInfoResponseCacheKey(actualClaims)
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
if err != nil {
Expand All @@ -575,7 +575,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP

// check if the accessToken for the user is still present
var encAccessToken []byte
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken)
// without an accessToken we can't query the user info endpoint
// thus the user needs to reauthenticate for argocd to get a new accessToken
if errors.Is(err, cache.ErrCacheMiss) {
Expand Down Expand Up @@ -684,11 +684,13 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
}

// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
func formatUserInfoResponseCacheKey(sub string) string {
func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string {
sub := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
}

// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
func formatAccessTokenCacheKey(sub string) string {
func formatAccessTokenCacheKey(claims jwt.MapClaims) string {
sub := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
}
3 changes: 2 additions & 1 deletion util/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"time"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/util/assets"
"github.com/argoproj/argo-cd/v2/util/glob"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -255,7 +256,7 @@ func (e *Enforcer) EnforceErr(rvals ...interface{}) error {
if err != nil {
break
}
if sub := jwtutil.StringField(claims, "sub"); sub != "" {
if sub := utils.GetUserIdentifier(claims); sub != "" {
rvalsStrs = append(rvalsStrs, fmt.Sprintf("sub: %s", sub))
}
if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil {
Expand Down
22 changes: 13 additions & 9 deletions util/session/sessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
"github.com/argoproj/argo-cd/v2/server/rbacpolicy"
Expand Down Expand Up @@ -226,7 +227,7 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error)
return nil, "", err
}

subject := jwtutil.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)
id := jwtutil.StringField(claims, "jti")

if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok {
Expand Down Expand Up @@ -502,9 +503,17 @@ func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) h
return
}
ctx := r.Context()

// Assert that claims is of type jwt.MapClaims
mapClaims, ok := claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Invalid claims type", http.StatusUnauthorized)
return
}

// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
ctx = context.WithValue(ctx, "user_id", utils.GetUserIdentifier(mapClaims))
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -593,12 +602,7 @@ func Username(ctx context.Context) string {
if !ok {
return ""
}
switch jwtutil.StringField(mapClaims, "iss") {
case SessionManagerClaimsIssuer:
return jwtutil.StringField(mapClaims, "sub")
default:
return jwtutil.StringField(mapClaims, "email")
}
return utils.GetUserIdentifier(mapClaims)
}

func Iss(ctx context.Context) string {
Expand All @@ -622,7 +626,7 @@ func Sub(ctx context.Context) string {
if !ok {
return ""
}
return jwtutil.StringField(mapClaims, "sub")
return utils.GetUserIdentifier(mapClaims)
}

func Groups(ctx context.Context, scopes []string) []string {
Expand Down
5 changes: 3 additions & 2 deletions util/session/sessionmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
appv1 "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
apps "github.com/argoproj/argo-cd/v2/pkg/client/clientset/versioned/fake"
Expand Down Expand Up @@ -99,7 +100,7 @@ func TestSessionManager_AdminToken(t *testing.T) {
assert.Empty(t, newToken)

mapClaims := *(claims.(*jwt.MapClaims))
subject := mapClaims["sub"].(string)
subject := utils.GetUserIdentifier(mapClaims)
if subject != "admin" {
t.Errorf("Token claim subject \"%s\" does not match expected subject \"%s\".", subject, "admin")
}
Expand All @@ -126,7 +127,7 @@ func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) {
claims, _, err := mgr.Parse(newToken)
require.NoError(t, err)
mapClaims := *(claims.(*jwt.MapClaims))
subject := mapClaims["sub"].(string)
subject := utils.GetUserIdentifier(mapClaims)
assert.Equal(t, "admin", subject)
}

Expand Down

0 comments on commit be4821c

Please sign in to comment.