diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index 88b3712602a0..d9920a780ef7 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -811,8 +811,12 @@ continueLogin: sess = session.NewInactiveSession() } - method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR) - sess.CompletedLoginForMethod(method) + method, err := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR, nil) + if err != nil { + h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, group, err) + return + } + sess.CompletedLoginForMethod(*method) i = interim break } diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index f0e06ccfc934..9a04c847ac78 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -366,8 +366,11 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S return err } - method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR) - sess.CompletedLoginForMethod(method) + method, err := strategy.CompletedAuthenticationMethod(ctx, sess.AMR, lc.CredentialsConfig) + if err != nil { + return err + } + sess.CompletedLoginForMethod(*method) return nil } diff --git a/selfservice/flow/login/strategy.go b/selfservice/flow/login/strategy.go index c70ad9cc8684..4393b8dbafc2 100644 --- a/selfservice/flow/login/strategy.go +++ b/selfservice/flow/login/strategy.go @@ -22,7 +22,7 @@ type Strategy interface { RegisterLoginRoutes(*x.RouterPublic) PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error Login(w http.ResponseWriter, r *http.Request, f *Flow, sess *session.Session) (i *identity.Identity, err error) - CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod + CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) } type Strategies []Strategy diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index a9d7459f5c56..3969064689d0 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -10,6 +10,8 @@ import ( "net/http" "strings" + "github.com/ory/x/sqlxx" + "github.com/ory/x/sqlcon" "github.com/pkg/errors" @@ -67,20 +69,20 @@ type updateLoginFlowWithCodeMethod struct { func (s *Strategy) RegisterLoginRoutes(*x.RouterPublic) {} -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { aal1Satisfied := lo.ContainsBy(amr, func(am session.AuthenticationMethod) bool { return am.Method != identity.CredentialsTypeCodeAuth && am.AAL == identity.AuthenticatorAssuranceLevel1 }) if aal1Satisfied { - return session.AuthenticationMethod{ + return &session.AuthenticationMethod{ Method: identity.CredentialsTypeCodeAuth, AAL: identity.AuthenticatorAssuranceLevel2, - } + }, nil } - return session.AuthenticationMethod{ + return &session.AuthenticationMethod{ Method: identity.CredentialsTypeCodeAuth, AAL: identity.AuthenticatorAssuranceLevel1, - } + }, nil } func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *updateLoginFlowWithCodeMethod, err error) error { diff --git a/selfservice/strategy/lookup/strategy.go b/selfservice/strategy/lookup/strategy.go index e8f12cac9948..ee4fcaa9e745 100644 --- a/selfservice/strategy/lookup/strategy.go +++ b/selfservice/strategy/lookup/strategy.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" + "github.com/ory/x/sqlxx" + "github.com/pkg/errors" "github.com/ory/kratos/continuity" @@ -106,9 +108,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.LookupGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { - return session.AuthenticationMethod{ +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { + return &session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, - } + }, nil } diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 6515d06367ee..5e227faf488d 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -16,6 +16,8 @@ import ( "strings" "time" + "github.com/ory/x/sqlxx" + "golang.org/x/exp/maps" "github.com/ory/x/urlx" @@ -719,11 +721,17 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.OpenIDConnectGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { - return session.AuthenticationMethod{ - Method: s.ID(), - AAL: identity.AuthenticatorAssuranceLevel1, +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { + credentialsOIDCProvider, err := s.getProvider(credentialsConfig) + if err != nil { + return nil, err } + + return &session.AuthenticationMethod{ + Method: s.ID(), + AAL: identity.AuthenticatorAssuranceLevel1, + Provider: credentialsOIDCProvider.Provider, + }, nil } func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) { @@ -840,3 +848,15 @@ func (s *Strategy) encryptOAuth2Tokens(ctx context.Context, token *oauth2.Token) return et, nil } + +func (s *Strategy) getProvider(credentialsConfig sqlxx.JSONRawMessage) (identity.CredentialsOIDCProvider, error) { + var credentialsOIDCConfig identity.CredentialsOIDC + if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil { + return identity.CredentialsOIDCProvider{}, err + } + if len(credentialsOIDCConfig.Providers) != 1 { + return identity.CredentialsOIDCProvider{}, errors.New("No oidc provider was set") + } + credentialsOIDCProvider := credentialsOIDCConfig.Providers[0] + return credentialsOIDCProvider, nil +} diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index 4fde3a457548..47c7ed159e13 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -514,14 +514,10 @@ func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, c } func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsConfig sqlxx.JSONRawMessage) error { - var credentialsOIDCConfig identity.CredentialsOIDC - if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil { + credentialsOIDCProvider, err := s.getProvider(credentialsConfig) + if err != nil { return err } - if len(credentialsOIDCConfig.Providers) != 1 { - return errors.New("No oidc provider was set") - } - credentialsOIDCProvider := credentialsOIDCConfig.Providers[0] if err := s.linkCredentials( ctx, diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 65c8f09b2e06..5fbcbefe6a01 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -826,6 +826,7 @@ func TestStrategy(t *testing.T) { }`, expect: func(t *testing.T, res *http.Response, body []byte) { require.NotEmpty(t, gjson.GetBytes(body, "session_token").String(), "%s", body) + require.Equal(t, "test-provider", gjson.GetBytes(body, "session.authentication_methods.0.provider").String(), "%s", body) }, }, { @@ -1273,6 +1274,7 @@ func TestStrategy(t *testing.T) { assert.Equal(t, provider, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.0.provider").String(), "%s", string(i.Credentials["oidc"].Config[:])) assert.Contains(t, gjson.GetBytes(body, "authentication_methods").String(), "oidc", "%s", body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.1.provider").String(), "%s", body) } t.Run("case=second login is password", func(t *testing.T) { diff --git a/selfservice/strategy/passkey/passkey_strategy.go b/selfservice/strategy/passkey/passkey_strategy.go index b590a7e93b6d..45a4efde7007 100644 --- a/selfservice/strategy/passkey/passkey_strategy.go +++ b/selfservice/strategy/passkey/passkey_strategy.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" + "github.com/ory/x/sqlxx" + "github.com/pkg/errors" "github.com/ory/kratos/continuity" @@ -88,11 +90,11 @@ func (*Strategy) NodeGroup() node.UiNodeGroup { return node.PasskeyGroup } -func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods) session.AuthenticationMethod { - return session.AuthenticationMethod{ +func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods, sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { + return &session.AuthenticationMethod{ Method: identity.CredentialsTypePasskey, AAL: identity.AuthenticatorAssuranceLevel1, - } + }, nil } func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { diff --git a/selfservice/strategy/password/strategy.go b/selfservice/strategy/password/strategy.go index 911ad619cd15..f56076e2042e 100644 --- a/selfservice/strategy/password/strategy.go +++ b/selfservice/strategy/password/strategy.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" + "github.com/ory/x/sqlxx" + "github.com/ory/kratos/ui/node" "github.com/go-playground/validator/v10" @@ -109,11 +111,11 @@ func (s *Strategy) ID() identity.CredentialsType { return identity.CredentialsTypePassword } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { - return session.AuthenticationMethod{ +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { + return &session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, - } + }, nil } func (s *Strategy) NodeGroup() node.UiNodeGroup { diff --git a/selfservice/strategy/totp/strategy.go b/selfservice/strategy/totp/strategy.go index 6c3205abd9ac..f3e7f6a04115 100644 --- a/selfservice/strategy/totp/strategy.go +++ b/selfservice/strategy/totp/strategy.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" + "github.com/ory/x/sqlxx" + "github.com/pkg/errors" "github.com/pquerna/otp" @@ -109,9 +111,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.TOTPGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { - return session.AuthenticationMethod{ +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { + return &session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, - } + }, nil } diff --git a/selfservice/strategy/webauthn/strategy.go b/selfservice/strategy/webauthn/strategy.go index 998490055996..a9d71d4e45e5 100644 --- a/selfservice/strategy/webauthn/strategy.go +++ b/selfservice/strategy/webauthn/strategy.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" + "github.com/ory/x/sqlxx" + "github.com/pkg/errors" "github.com/ory/kratos/continuity" @@ -114,13 +116,13 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.WebAuthnGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) { aal := identity.AuthenticatorAssuranceLevel1 if !s.d.Config().WebAuthnForPasswordless(ctx) { aal = identity.AuthenticatorAssuranceLevel2 } - return session.AuthenticationMethod{ + return &session.AuthenticationMethod{ Method: s.ID(), AAL: aal, - } + }, nil } diff --git a/selfservice/strategy/webauthn/strategy_test.go b/selfservice/strategy/webauthn/strategy_test.go index cc5f6fafb475..58c35d1a27c5 100644 --- a/selfservice/strategy/webauthn/strategy_test.go +++ b/selfservice/strategy/webauthn/strategy_test.go @@ -23,16 +23,20 @@ func TestCompletedAuthenticationMethod(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) strategy := webauthn.NewStrategy(reg) + method, err := strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil) + assert.NoError(t, err) assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel2, - }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) + }, *method) conf.MustSet(ctx, config.ViperKeyWebAuthnPasswordless, true) + method, err = strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil) + assert.NoError(t, err) assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel1, - }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) + }, *method) } func TestCountActiveFirstFactorCredentials(t *testing.T) {