From 2cfc06a7568659cb7583e91c73022b026d41219f Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:13:59 +0100 Subject: [PATCH 1/5] feat: retry crdb serializable errors --- go.mod | 2 +- go.sum | 4 ++-- oauth2/oauth2_refresh_token_test.go | 4 +++- persistence/sql/persister_oauth2.go | 8 ++++---- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 26adbfc96ce..8092bf8a20a 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.0 // // This is needed until we release the next version of the master branch, as that branch already contains the redirect URI validation fix, which // may be breaking for some users. -replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4 +replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 require ( github.com/ThalesIgnite/crypto11 v1.2.5 diff --git a/go.sum b/go.sum index 5f7a2ba7603..a81e9e76d2b 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,8 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4 h1:1pEVHGC+Dx2xMPMgpRgG3lyejyK8iU9KKfSnLowLYd8= -github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4/go.mod h1:AZyn1jrABUaGN12RHcWorRLbqLn52gTdHaIYY81m5J0= +github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 h1:X0ngo+uPWCw90ueY3Kh6q8IyF2fbwkJ8bf9RvAmD71U= +github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902/go.mod h1:AZyn1jrABUaGN12RHcWorRLbqLn52gTdHaIYY81m5J0= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/oauth2/oauth2_refresh_token_test.go b/oauth2/oauth2_refresh_token_test.go index 849fae06460..ffabb0dd2a0 100644 --- a/oauth2/oauth2_refresh_token_test.go +++ b/oauth2/oauth2_refresh_token_test.go @@ -172,8 +172,10 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { "RETRY_WRITE_TOO_OLD", // refresh token reuse detection "token_inactive", + // Failed to refresh token because of multiple concurrent requests using the same token which is not allowed. + "multiple concurrent requests", } { - if strings.Contains(e.DebugField, errSubstr) { + if strings.Contains(e.DebugField+e.HintField, errSubstr) { matched = true break } diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 083e67ac5da..b67b6ae17ed 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -254,7 +254,7 @@ func (p *Persister) createSession(ctx context.Context, signature string, request } if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } else if err != nil { return err } @@ -293,7 +293,7 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri return errorsx.WithStack(fosite.ErrNotFound) } if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } return err } @@ -310,7 +310,7 @@ func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, tab } if err := sqlcon.HandleError(err); err != nil { if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) @@ -426,7 +426,7 @@ func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature stri } } if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } return err } From f7e6d948089c3d4e16203fbe910a044a54c91252 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:52:48 +0100 Subject: [PATCH 2/5] chore: synchronize workspaces --- oauth2/fosite_store_test.go | 12 +- oauth2/oauth2_auth_code_test.go | 1267 ++++++++++++++++--------------- 2 files changed, 668 insertions(+), 611 deletions(-) diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 2a48a52f8e7..292988b77c4 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -16,7 +16,6 @@ import ( "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" - "github.com/ory/x/networkx" "github.com/ory/x/sqlcon/dockertest" ) @@ -72,14 +71,11 @@ func TestManagers(t *testing.T) { require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{ID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers. - for k, store := range registries { - net := &networkx.Network{} - require.NoError(t, store.Persister().Connection(context.Background()).First(net)) - store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - store.WithContextualizer(&contextx.Static{NID: net.ID, C: store.Config().Source(ctx)}) - TestHelperRunner(t, store, k) + for k, _ := range registries { + reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) + reg.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) + TestHelperRunner(t, reg, k) } }) - } } diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 0d89e14ac9b..10e629c6cba 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/ory/hydra/v2/jwk" "io" "net/http" "net/http/httptest" @@ -62,6 +63,95 @@ type clientCreator interface { CreateClient(context.Context, *client.Client) error } +func getAuthorizeCode(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } + + state := uuid.New() + resp, err := c.Get(conf.AuthCodeURL(state, params...)) + require.NoError(t, err) + defer resp.Body.Close() + + q := resp.Request.URL.Query() + require.EqualValues(t, state, q.Get("state")) + return q.Get("code"), resp +} + +func acceptLoginHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b + } + } + + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } +} + +func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(r.Context()).String()) + if checkRequestPayload != nil { + checkRequestPayload(rr) + } + + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, + }, + }). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } +} + // TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically // we test: // @@ -87,94 +177,6 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} - getAuthorizeCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { - if c == nil { - c = testhelpers.NewEmptyJarClient(t) - } - - state := uuid.New() - resp, err := c.Get(conf.AuthCodeURL(state, params...)) - require.NoError(t, err) - defer resp.Body.Close() - - q := resp.Request.URL.Query() - require.EqualValues(t, state, q.Get("state")) - return q.Get("code"), resp - } - - acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Remember: pointerx.Ptr(!rr.Skip), - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - if checkRequestPayload != nil { - if b := checkRequestPayload(rr); b != nil { - acceptBody = *b - } - } - - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - } - } - - acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) - assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - if checkRequestPayload != nil { - checkRequestPayload(rr) - } - - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - } - } - assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) @@ -266,8 +268,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { run := func(t *testing.T, strategy string) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) @@ -342,8 +344,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) issueTokens := func(t *testing.T) *oauth2.Token { @@ -378,7 +380,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { return refreshedToken } - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { start := time.Now() token := issueTokens(t) @@ -411,7 +413,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { start := time.Now() token := issueTokens(t) @@ -447,7 +449,57 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) - t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { + t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + var wg sync.WaitGroup + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + go func(k int) { + defer wg.Done() + t.Logf("Refreshing token %d", k) + refreshes[k] = refreshTokens(t, token) + }(k) + } + + wg.Wait() + for k, refresh := range refreshes { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + iat := time.Now() + introspectAccessToken(t, conf, refresh, subject) + assertJWTAccessToken(t, strategy, conf, refresh, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refresh, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refresh, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) + } + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { start := time.Now() token := issueTokens(t) var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token @@ -698,8 +750,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) withWrongClientAfterLogin := &http.Client{ @@ -813,13 +865,13 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=perform flow with prompt=registration", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - regUI := httptest.NewServer(acceptLoginHandler(t, c, subject, nil)) + regUI := httptest.NewServer(acceptLoginHandler(t, c, adminClient, reg, subject, nil)) t.Cleanup(regUI.Close) reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) testhelpers.NewLoginConsentUI(t, reg.Config(), nil, - acceptConsentHandler(t, c, subject, nil)) + acceptConsentHandler(t, c, adminClient, reg, subject, nil)) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("prompt", "registration"), @@ -836,12 +888,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -865,8 +917,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=respects client token lifespan configuration", func(t *testing.T) { run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) @@ -967,8 +1019,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=use remember feature and prompt=none", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -984,12 +1036,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { // Reset UI to check for skip values testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.True(t, r.Skip) require.EqualValues(t, subject, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { require.True(t, *r.Skip) require.EqualValues(t, subject, *r.Subject) }), @@ -1038,12 +1090,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { require.True(t, *r.Skip) require.EqualValues(t, subject, *r.Subject) }), @@ -1065,8 +1117,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -1079,12 +1131,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -1103,11 +1155,11 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { time.Sleep(time.Second * 2) return nil }), - acceptConsentHandler(t, c, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil) @@ -1117,8 +1169,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil) @@ -1203,12 +1255,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1252,12 +1304,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1292,12 +1344,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1332,12 +1384,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1483,21 +1535,23 @@ func createVCProofJWT(t *testing.T, pubKey *jose.JSONWebKey, privKey any, nonce // - [x] should pass with prompt=login when authentication time is recent // - [x] should fail with prompt=login when authentication time is in the past func TestAuthCodeWithMockStrategy(t *testing.T) { - ctx := context.Background() - for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { - t.Run("strategy="+strat.d, func(t *testing.T) { - conf := internal.NewConfigurationWithDefaults() + setupRegistries(t) + + for k := range registries { + t.Run("registry="+k, func(t *testing.T) { + ctx := context.Background() + reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) + + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) + + conf := reg.Config() conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") - conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) - reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) - consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) - defer ts.Close() + t.Cleanup(ts.Close) reg.WithConsentStrategy(consentStrategy) handler := reg.OAuth2Handler() @@ -1511,7 +1565,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) var mutex sync.Mutex - require.NoError(t, reg.ClientManager().CreateClient(context.TODO(), &client.Client{ + require.NoError(t, reg.ClientManager().CreateClient(ctx, &client.Client{ ID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, @@ -1531,524 +1585,531 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { Scopes: []string{"hydra.*", "offline", "openid"}, } - var code string - for k, tc := range []struct { - cj http.CookieJar - d string - cb func(t *testing.T) httprouter.Handle - authURL string - shouldPassConsentStrategy bool - expectOAuthAuthError bool - expectOAuthTokenError bool - checkExpiry bool - authTime time.Time - requestTime time.Time - assertAccessToken func(*testing.T, string) - }{ - { - d: "should pass request if strategy passes", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - shouldPassConsentStrategy: true, - checkExpiry: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - assertAccessToken: func(t *testing.T, token string) { - if strat.d != "jwt" { - return - } - - body, err := x.DecodeSegment(strings.Split(token, ".")[1]) - require.NoError(t, err) - - data := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &data)) - - assert.EqualValues(t, "app-client", data["client_id"]) - assert.EqualValues(t, "foo", data["sub"]) - assert.NotEmpty(t, data["iss"]) - assert.NotEmpty(t, data["jti"]) - assert.NotEmpty(t, data["exp"]) - assert.NotEmpty(t, data["iat"]) - assert.NotEmpty(t, data["nbf"]) - assert.EqualValues(t, data["nbf"], data["iat"]) - assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) - }, - }, - { - d: "should fail because prompt=none and max_age > auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should pass because prompt=none and max_age is less than auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail because prompt=none but auth_time suggests recent authentication", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC().Add(-time.Hour), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should fail because consent strategy fails", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - expectOAuthAuthError: true, - shouldPassConsentStrategy: false, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - require.Empty(t, r.URL.Query().Get("code")) - assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) - } - }, - }, - { - d: "should pass with prompt=login when authentication time is recent", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Second), - requestTime: time.Now().UTC().Add(-time.Minute), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail with prompt=login when authentication time is in the past", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - expectOAuthAuthError: true, - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.Empty(t, code) - assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) - } - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - if tc.cb == nil { - tc.cb = noopHandler - } - - consentStrategy.deny = !tc.shouldPassConsentStrategy - consentStrategy.authTime = tc.authTime - consentStrategy.requestTime = tc.requestTime - - cb := tc.cb(t) - callbackHandler = &cb - - req, err := http.NewRequest("GET", tc.authURL, nil) - require.NoError(t, err) - - if tc.cj == nil { - tc.cj = testhelpers.NewEmptyCookieJar(t) - } + for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { + conf := reg.Config() + conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) + + t.Run("strategy="+strat.d, func(t *testing.T) { + var code string + for k, tc := range []struct { + cj http.CookieJar + d string + cb func(t *testing.T) httprouter.Handle + authURL string + shouldPassConsentStrategy bool + expectOAuthAuthError bool + expectOAuthTokenError bool + checkExpiry bool + authTime time.Time + requestTime time.Time + assertAccessToken func(*testing.T, string) + }{ + { + d: "should pass request if strategy passes", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + shouldPassConsentStrategy: true, + checkExpiry: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + assertAccessToken: func(t *testing.T, token string) { + if strat.d != "jwt" { + return + } - resp, err := (&http.Client{Jar: tc.cj}).Do(req) - require.NoError(t, err, tc.authURL, ts.URL) - defer resp.Body.Close() + body, err := x.DecodeSegment(strings.Split(token, ".")[1]) + require.NoError(t, err) - if tc.expectOAuthAuthError { - require.Empty(t, code) - return - } + data := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &data)) + + assert.EqualValues(t, "app-client", data["client_id"]) + assert.EqualValues(t, "foo", data["sub"]) + assert.NotEmpty(t, data["iss"]) + assert.NotEmpty(t, data["jti"]) + assert.NotEmpty(t, data["exp"]) + assert.NotEmpty(t, data["iat"]) + assert.NotEmpty(t, data["nbf"]) + assert.EqualValues(t, data["nbf"], data["iat"]) + assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) + }, + }, + { + d: "should fail because prompt=none and max_age > auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should pass because prompt=none and max_age is less than auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail because prompt=none but auth_time suggests recent authentication", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC().Add(-time.Hour), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should fail because consent strategy fails", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + expectOAuthAuthError: true, + shouldPassConsentStrategy: false, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + require.Empty(t, r.URL.Query().Get("code")) + assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) + } + }, + }, + { + d: "should pass with prompt=login when authentication time is recent", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Second), + requestTime: time.Now().UTC().Add(-time.Minute), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail with prompt=login when authentication time is in the past", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + expectOAuthAuthError: true, + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.Empty(t, code) + assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) + } + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + if tc.cb == nil { + tc.cb = noopHandler + } - require.NotEmpty(t, code) + consentStrategy.deny = !tc.shouldPassConsentStrategy + consentStrategy.authTime = tc.authTime + consentStrategy.requestTime = tc.requestTime - token, err := oauthConfig.Exchange(context.TODO(), code) - if tc.expectOAuthTokenError { - require.Error(t, err) - return - } + cb := tc.cb(t) + callbackHandler = &cb - require.NoError(t, err, code) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, token.AccessToken) - } - - t.Run("case=userinfo", func(t *testing.T) { - var makeRequest = func(req *http.Request) *http.Response { - resp, err = http.DefaultClient.Do(req) + req, err := http.NewRequest("GET", tc.authURL, nil) require.NoError(t, err) - return resp - } - - var testSuccess = func(response *http.Response) { - defer resp.Body.Close() - - require.Equal(t, http.StatusOK, resp.StatusCode) - - var claims map[string]interface{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) - assert.Equal(t, "foo", claims["sub"]) - } - - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) - - req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) - - req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - testSuccess(makeRequest(req)) - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer asdfg") - resp := makeRequest(req) - require.Equal(t, http.StatusUnauthorized, resp.StatusCode) - }) + if tc.cj == nil { + tc.cj = testhelpers.NewEmptyCookieJar(t) + } - res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + resp, err := (&http.Client{Jar: tc.cj}).Do(req) + require.NoError(t, err, tc.authURL, ts.URL) + defer resp.Body.Close() - body, err := io.ReadAll(res.Body) - require.NoError(t, err) + if tc.expectOAuthAuthError { + require.Empty(t, code) + return + } - var refreshedToken oauth2.Token - require.NoError(t, json.Unmarshal(body, &refreshedToken)) + require.NotEmpty(t, code) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, refreshedToken.AccessToken) - } + token, err := oauthConfig.Exchange(context.TODO(), code) + if tc.expectOAuthTokenError { + require.Error(t, err) + return + } - t.Run("the tokens should be different", func(t *testing.T) { - if strat.d != "jwt" { - t.Skip() - } + require.NoError(t, err, code) + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, token.AccessToken) + } - body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) - require.NoError(t, err) + t.Run("case=userinfo", func(t *testing.T) { + var makeRequest = func(req *http.Request) *http.Response { + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + return resp + } - origPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &origPayload)) + var testSuccess = func(response *http.Response) { + defer resp.Body.Close() - body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) - require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) - refreshedPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &refreshedPayload)) + var claims map[string]interface{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) + assert.Equal(t, "foo", claims["sub"]) + } - if tc.checkExpiry { - assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) - assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) - assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) - } - assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) - assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) - }) + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - t.Run("old token should no longer be usable", func(t *testing.T) { - req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) - require.NoError(t, err) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) - assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) - }) + req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + testSuccess(makeRequest(req)) - t.Run("refreshing new refresh token should work", func(t *testing.T) { - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer asdfg") + resp := makeRequest(req) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - }) - - t.Run("should call refresh token hook if configured", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - - expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} - expectedSubject := "foo" - - exceptKeys := []string{ - "session.kid", - "session.id_token.expires_at", - "session.id_token.headers.extra.kid", - "session.id_token.id_token_claims.iat", - "session.id_token.id_token_claims.exp", - "session.id_token.id_token_claims.rat", - "session.id_token.id_token_claims.auth_time", - } + res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - if hookType == "legacy" { - var hookReq hydraoauth2.RefreshTokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.Equal(t, hookReq.Subject, expectedSubject) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) - require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } else { - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Request) - require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) - require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } + body, err := io.ReadAll(res.Body) + require.NoError(t, err) - claims := map[string]interface{}{ - "hooked": hookType, - } + var refreshedToken oauth2.Token + require.NoError(t, json.Unmarshal(body, &refreshedToken)) - hookResp := hydraoauth2.TokenHookResponse{ - Session: flow.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, - } + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, refreshedToken.AccessToken) + } - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Run("the tokens should be different", func(t *testing.T) { + if strat.d != "jwt" { + t.Skip() } - res, err := testRefresh(t, &refreshedToken, ts.URL, false) + body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) + origPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &origPayload)) - idTokenBody, err := x.DecodeSegment( - strings.Split( - gjson.GetBytes(body, "id_token").String(), - ".", - )[1], - ) + body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) require.NoError(t, err) - require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) + refreshedPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &refreshedPayload)) - t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + if tc.checkExpiry { + assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) + assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) + assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) } + assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) + assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) + }) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + t.Run("old token should no longer be usable", func(t *testing.T) { + req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) + require.NoError(t, err) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) + }) + t.Run("refreshing new refresh token should work", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - body, err = io.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + }) + + t.Run("should call refresh token hook if configured", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} + expectedSubject := "foo" + + exceptKeys := []string{ + "session.kid", + "session.id_token.expires_at", + "session.id_token.headers.extra.kid", + "session.id_token.id_token_claims.iat", + "session.id_token.id_token_claims.exp", + "session.id_token.id_token_claims.rat", + "session.id_token.id_token_claims.auth_time", + } + + if hookType == "legacy" { + var hookReq hydraoauth2.RefreshTokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.Equal(t, hookReq.Subject, expectedSubject) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) + require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } else { + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Request) + require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } + + claims := map[string]interface{}{ + "hooked": hookType, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHook, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) + + idTokenBody, err := x.DecodeSegment( + strings.Split( + gjson.GetBytes(body, "id_token").String(), + ".", + )[1], + ) + require.NoError(t, err) + + require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) + } } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHook, nil) + } + + origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHook, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) + } } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHook, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHook, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) - res, err := testRefresh(t, &refreshedToken, ts.URL, false) + t.Run("refreshing old token should no longer work", func(t *testing.T) { + res, err := testRefresh(t, token, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusForbidden, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error.", errBody.Description) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("refreshing old token should no longer work", func(t *testing.T) { - res, err := testRefresh(t, token, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) - t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + t.Run("duplicate code exchange fails", func(t *testing.T) { + token, err := oauthConfig.Exchange(context.TODO(), code) + require.Error(t, err) + require.Nil(t, token) + }) - t.Run("duplicate code exchange fails", func(t *testing.T) { - token, err := oauthConfig.Exchange(context.TODO(), code) - require.Error(t, err) - require.Nil(t, token) - }) - - code = "" + code = "" + }) + } }) } }) From f403a207c05c1c0c05db7f0e67ce0965776d72c9 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 27 Nov 2024 22:05:58 +0100 Subject: [PATCH 3/5] chore: synchronize workspaces --- oauth2/oauth2_auth_code_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 10e629c6cba..a106006df73 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -1537,7 +1537,9 @@ func createVCProofJWT(t *testing.T, pubKey *jose.JSONWebKey, privKey any, nonce func TestAuthCodeWithMockStrategy(t *testing.T) { setupRegistries(t) - for k := range registries { + for k := range map[string]driver.Registry{ + "cockroach": registries["cockroach"], + } { t.Run("registry="+k, func(t *testing.T) { ctx := context.Background() reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) From 01fba867963c721b70909cf272a34b69d59b2ef5 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 29 Nov 2024 09:39:56 +0100 Subject: [PATCH 4/5] chore: synchronize workspaces --- go.mod | 4 +- go.sum | 2 - oauth2/oauth2_auth_code_test.go | 1261 ++++++++++++++------------- persistence/sql/persister_oauth2.go | 144 ++- 4 files changed, 790 insertions(+), 621 deletions(-) diff --git a/go.mod b/go.mod index 8092bf8a20a..809bb143869 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,9 @@ replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.0 // // This is needed until we release the next version of the master branch, as that branch already contains the redirect URI validation fix, which // may be breaking for some users. -replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 +//replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 + +replace github.com/ory/fosite => ../fosite require ( github.com/ThalesIgnite/crypto11 v1.2.5 diff --git a/go.sum b/go.sum index a81e9e76d2b..2e7f7a6fb05 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,6 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 h1:X0ngo+uPWCw90ueY3Kh6q8IyF2fbwkJ8bf9RvAmD71U= -github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902/go.mod h1:AZyn1jrABUaGN12RHcWorRLbqLn52gTdHaIYY81m5J0= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index a106006df73..19407ccf56b 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -166,8 +166,14 @@ func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.API // - [x] If `id_token_hint` is handled properly // - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") func TestAuthCodeWithDefaultStrategy(t *testing.T) { + setupRegistries(t) + ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := internal.NewRegistrySQLFromURL(t, registries["cockroach"].Config().DSN(), true, &contextx.Default{}) + + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) @@ -337,9 +343,15 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=graceful token rotation", func(t *testing.T) { run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, time.Second*3) + reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, time.Minute) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) t.Cleanup(func() { reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, nil) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, nil) }) c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) @@ -364,7 +376,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) + token.Expiry = time.Now().Add(-time.Hour * 24) iat := time.Now() refreshedToken, err := conf.TokenSource(context.Background(), token).Token() require.NoError(t, err) @@ -380,40 +392,63 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { return refreshedToken } - t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { - start := time.Now() + t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) + var wg sync.WaitGroup + refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + tt, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + return tt + } - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + //time.Sleep(time.Millisecond * 100) + go func(k int) { + defer wg.Done() + t.Logf("Refreshing token %d", k) + refreshes[k] = refresh(t, token) + }(k) + } - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + wg.Wait() - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + for k, actual := range refreshes { + refresh := actual - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + require.NotEmpty(t, refresh.RefreshToken) + require.NotEmpty(t, refresh.AccessToken) + require.NotEmpty(t, refresh.Extra("id_token")) - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + } + + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, actual := range refreshes { + i := testhelpers.IntrospectToken(t, conf, actual.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "token %d: %s", k, i) + } }) }) - t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { + t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { start := time.Now() token := issueTokens(t) @@ -426,66 +461,36 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { second = refreshTokens(t, token) }) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) - - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - + t.Run("followup=all resulting tokens are valid", func(t *testing.T) { i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) }) - }) - - t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var first, second *oauth2.Token - var wg sync.WaitGroup - refreshes := make([]*oauth2.Token, 5) - for k := range refreshes { - wg.Add(1) - go func(k int) { - defer wg.Done() - t.Logf("Refreshing token %d", k) - refreshes[k] = refreshTokens(t, token) - }(k) - } + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(4 * time.Second))) - wg.Wait() - for k, refresh := range refreshes { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - iat := time.Now() - introspectAccessToken(t, conf, refresh, subject) - assertJWTAccessToken(t, strategy, conf, refresh, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) - assertIDToken(t, refresh, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) - assertRefreshToken(t, refresh, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) - }) - } + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // Fetching the token again should cause an error because we are no longer in the grace period. + token.Expiry = time.Now().Add(-time.Hour * 24) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + result, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err, "%+v", result) - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + i = testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) assert.False(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) @@ -498,60 +503,92 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.False(t, i.Get("active").Bool(), "%s", i) }) }) - - t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - a1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - b1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=first refresh from first refresh", func(t *testing.T) { - a2RefreshA = refreshTokens(t, a1Refresh) - }) - - t.Run("followup=second refresh from first refresh", func(t *testing.T) { - a2RefreshB = refreshTokens(t, a1Refresh) - }) - - t.Run("followup=first refresh from second refresh", func(t *testing.T) { - b2RefreshA = refreshTokens(t, b1Refresh) - }) - - t.Run("followup=second refresh from second refresh", func(t *testing.T) { - b2RefreshB = refreshTokens(t, b1Refresh) - }) - - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - - for k, token := range []*oauth2.Token{ - a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - } - }) - }) + // + //t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { + // token := issueTokens(t) + // var first, second *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // first = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // second = refreshTokens(t, token) + // }) + // + // t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + // err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + // require.NoError(t, err) + // + // _, err = conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + //}) + // + //t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { + // start := time.Now() + // token := issueTokens(t) + // var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // a1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // b1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=first refresh from first refresh", func(t *testing.T) { + // a2RefreshA = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=second refresh from first refresh", func(t *testing.T) { + // a2RefreshB = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=first refresh from second refresh", func(t *testing.T) { + // b2RefreshA = refreshTokens(t, b1Refresh) + // }) + // + // t.Run("followup=second refresh from second refresh", func(t *testing.T) { + // b2RefreshB = refreshTokens(t, b1Refresh) + // }) + // + // // Sleep until the grace period is over + // time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + // t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // _, err := conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // for k, token := range []*oauth2.Token{ + // a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + // } { + // t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + // i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + // } + // }) + //}) } t.Run("strategy=jwt", func(t *testing.T) { @@ -1535,21 +1572,17 @@ func createVCProofJWT(t *testing.T, pubKey *jose.JSONWebKey, privKey any, nonce // - [x] should pass with prompt=login when authentication time is recent // - [x] should fail with prompt=login when authentication time is in the past func TestAuthCodeWithMockStrategy(t *testing.T) { - setupRegistries(t) - - for k := range map[string]driver.Registry{ - "cockroach": registries["cockroach"], - } { - t.Run("registry="+k, func(t *testing.T) { - ctx := context.Background() - reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) - - require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) - require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) - - conf := reg.Config() + ctx := context.Background() + for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { + t.Run("strategy="+strat.d, func(t *testing.T) { + conf := internal.NewConfigurationWithDefaults() conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") + conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) + reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) @@ -1587,531 +1620,528 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { Scopes: []string{"hydra.*", "offline", "openid"}, } - for _, strat := range []struct{ d string }{{d: "opaque"}, {d: "jwt"}} { - conf := reg.Config() - conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) - - t.Run("strategy="+strat.d, func(t *testing.T) { - var code string - for k, tc := range []struct { - cj http.CookieJar - d string - cb func(t *testing.T) httprouter.Handle - authURL string - shouldPassConsentStrategy bool - expectOAuthAuthError bool - expectOAuthTokenError bool - checkExpiry bool - authTime time.Time - requestTime time.Time - assertAccessToken func(*testing.T, string) - }{ - { - d: "should pass request if strategy passes", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - shouldPassConsentStrategy: true, - checkExpiry: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - assertAccessToken: func(t *testing.T, token string) { - if strat.d != "jwt" { - return - } + var code string + for k, tc := range []struct { + cj http.CookieJar + d string + cb func(t *testing.T) httprouter.Handle + authURL string + shouldPassConsentStrategy bool + expectOAuthAuthError bool + expectOAuthTokenError bool + checkExpiry bool + authTime time.Time + requestTime time.Time + assertAccessToken func(*testing.T, string) + }{ + { + d: "should pass request if strategy passes", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + shouldPassConsentStrategy: true, + checkExpiry: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + assertAccessToken: func(t *testing.T, token string) { + if strat.d != "jwt" { + return + } - body, err := x.DecodeSegment(strings.Split(token, ".")[1]) - require.NoError(t, err) + body, err := x.DecodeSegment(strings.Split(token, ".")[1]) + require.NoError(t, err) - data := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &data)) - - assert.EqualValues(t, "app-client", data["client_id"]) - assert.EqualValues(t, "foo", data["sub"]) - assert.NotEmpty(t, data["iss"]) - assert.NotEmpty(t, data["jti"]) - assert.NotEmpty(t, data["exp"]) - assert.NotEmpty(t, data["iat"]) - assert.NotEmpty(t, data["nbf"]) - assert.EqualValues(t, data["nbf"], data["iat"]) - assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) - }, - }, - { - d: "should fail because prompt=none and max_age > auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should pass because prompt=none and max_age is less than auth_time", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail because prompt=none but auth_time suggests recent authentication", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC().Add(-time.Hour), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - err := r.URL.Query().Get("error") - require.Empty(t, code) - require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) - } - }, - expectOAuthAuthError: true, - }, - { - d: "should fail because consent strategy fails", - authURL: oauthConfig.AuthCodeURL("some-foo-state"), - expectOAuthAuthError: true, - shouldPassConsentStrategy: false, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - require.Empty(t, r.URL.Query().Get("code")) - assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) - } - }, - }, - { - d: "should pass with prompt=login when authentication time is recent", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Second), - requestTime: time.Now().UTC().Add(-time.Minute), - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.NotEmpty(t, code) - _, _ = w.Write([]byte(r.URL.Query().Get("code"))) - } - }, - }, - { - d: "should fail with prompt=login when authentication time is in the past", - authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", - authTime: time.Now().UTC().Add(-time.Minute), - requestTime: time.Now().UTC(), - expectOAuthAuthError: true, - shouldPassConsentStrategy: true, - cb: func(t *testing.T) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - code = r.URL.Query().Get("code") - require.Empty(t, code) - assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) - } - }, - }, - } { - t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - if tc.cb == nil { - tc.cb = noopHandler - } + data := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &data)) + + assert.EqualValues(t, "app-client", data["client_id"]) + assert.EqualValues(t, "foo", data["sub"]) + assert.NotEmpty(t, data["iss"]) + assert.NotEmpty(t, data["jti"]) + assert.NotEmpty(t, data["exp"]) + assert.NotEmpty(t, data["iat"]) + assert.NotEmpty(t, data["nbf"]) + assert.EqualValues(t, data["nbf"], data["iat"]) + assert.EqualValues(t, []interface{}{"offline", "openid", "hydra.*"}, data["scp"]) + }, + }, + { + d: "should fail because prompt=none and max_age > auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=1", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should pass because prompt=none and max_age is less than auth_time", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none&max_age=3600", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail because prompt=none but auth_time suggests recent authentication", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=none", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC().Add(-time.Hour), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + err := r.URL.Query().Get("error") + require.Empty(t, code) + require.EqualValues(t, fosite.ErrLoginRequired.Error(), err) + } + }, + expectOAuthAuthError: true, + }, + { + d: "should fail because consent strategy fails", + authURL: oauthConfig.AuthCodeURL("some-foo-state"), + expectOAuthAuthError: true, + shouldPassConsentStrategy: false, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + require.Empty(t, r.URL.Query().Get("code")) + assert.Equal(t, fosite.ErrRequestForbidden.Error(), r.URL.Query().Get("error")) + } + }, + }, + { + d: "should pass with prompt=login when authentication time is recent", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Second), + requestTime: time.Now().UTC().Add(-time.Minute), + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.NotEmpty(t, code) + _, _ = w.Write([]byte(r.URL.Query().Get("code"))) + } + }, + }, + { + d: "should fail with prompt=login when authentication time is in the past", + authURL: oauthConfig.AuthCodeURL("some-foo-state") + "&prompt=login", + authTime: time.Now().UTC().Add(-time.Minute), + requestTime: time.Now().UTC(), + expectOAuthAuthError: true, + shouldPassConsentStrategy: true, + cb: func(t *testing.T) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + code = r.URL.Query().Get("code") + require.Empty(t, code) + assert.Equal(t, fosite.ErrLoginRequired.Error(), r.URL.Query().Get("error")) + } + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + if tc.cb == nil { + tc.cb = noopHandler + } - consentStrategy.deny = !tc.shouldPassConsentStrategy - consentStrategy.authTime = tc.authTime - consentStrategy.requestTime = tc.requestTime + consentStrategy.deny = !tc.shouldPassConsentStrategy + consentStrategy.authTime = tc.authTime + consentStrategy.requestTime = tc.requestTime - cb := tc.cb(t) - callbackHandler = &cb + cb := tc.cb(t) + callbackHandler = &cb - req, err := http.NewRequest("GET", tc.authURL, nil) - require.NoError(t, err) + req, err := http.NewRequest("GET", tc.authURL, nil) + require.NoError(t, err) - if tc.cj == nil { - tc.cj = testhelpers.NewEmptyCookieJar(t) - } + if tc.cj == nil { + tc.cj = testhelpers.NewEmptyCookieJar(t) + } + + resp, err := (&http.Client{Jar: tc.cj}).Do(req) + require.NoError(t, err, tc.authURL, ts.URL) + defer resp.Body.Close() + + if tc.expectOAuthAuthError { + require.Empty(t, code) + return + } + + require.NotEmpty(t, code) + + token, err := oauthConfig.Exchange(context.TODO(), code) + if tc.expectOAuthTokenError { + require.Error(t, err) + return + } + + require.NoError(t, err, code) + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, token.AccessToken) + } + + t.Run("case=userinfo", func(t *testing.T) { + var makeRequest = func(req *http.Request) *http.Response { + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + return resp + } - resp, err := (&http.Client{Jar: tc.cj}).Do(req) - require.NoError(t, err, tc.authURL, ts.URL) + var testSuccess = func(response *http.Response) { defer resp.Body.Close() - if tc.expectOAuthAuthError { - require.Empty(t, code) - return - } + require.Equal(t, http.StatusOK, resp.StatusCode) - require.NotEmpty(t, code) + var claims map[string]interface{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) + assert.Equal(t, "foo", claims["sub"]) + } - token, err := oauthConfig.Exchange(context.TODO(), code) - if tc.expectOAuthTokenError { - require.Error(t, err) - return - } + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - require.NoError(t, err, code) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, token.AccessToken) - } + req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + testSuccess(makeRequest(req)) - t.Run("case=userinfo", func(t *testing.T) { - var makeRequest = func(req *http.Request) *http.Response { - resp, err = http.DefaultClient.Do(req) - require.NoError(t, err) - return resp - } + req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + testSuccess(makeRequest(req)) - var testSuccess = func(response *http.Response) { - defer resp.Body.Close() + req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) + req.Header.Add("Authorization", "bearer asdfg") + resp := makeRequest(req) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) - require.Equal(t, http.StatusOK, resp.StatusCode) + res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - var claims map[string]interface{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&claims)) - assert.Equal(t, "foo", claims["sub"]) - } + body, err := io.ReadAll(res.Body) + require.NoError(t, err) - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) + var refreshedToken oauth2.Token + require.NoError(t, json.Unmarshal(body, &refreshedToken)) - req, err = http.NewRequest("POST", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - testSuccess(makeRequest(req)) + if tc.assertAccessToken != nil { + tc.assertAccessToken(t, refreshedToken.AccessToken) + } - req, err = http.NewRequest("POST", ts.URL+"/userinfo", bytes.NewBuffer([]byte("access_token="+token.AccessToken))) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - testSuccess(makeRequest(req)) + t.Run("the tokens should be different", func(t *testing.T) { + if strat.d != "jwt" { + t.Skip() + } - req, err = http.NewRequest("GET", ts.URL+"/userinfo", nil) - req.Header.Add("Authorization", "bearer asdfg") - resp := makeRequest(req) - require.Equal(t, http.StatusUnauthorized, resp.StatusCode) - }) + body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) + require.NoError(t, err) - res, err := testRefresh(t, token, ts.URL, tc.checkExpiry) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + origPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &origPayload)) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) + body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) + require.NoError(t, err) - var refreshedToken oauth2.Token - require.NoError(t, json.Unmarshal(body, &refreshedToken)) + refreshedPayload := map[string]interface{}{} + require.NoError(t, json.Unmarshal(body, &refreshedPayload)) - if tc.assertAccessToken != nil { - tc.assertAccessToken(t, refreshedToken.AccessToken) - } + if tc.checkExpiry { + assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) + assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) + assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) + } + assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) + assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) + }) - t.Run("the tokens should be different", func(t *testing.T) { - if strat.d != "jwt" { - t.Skip() - } + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) - body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) - require.NoError(t, err) + t.Run("old token should no longer be usable", func(t *testing.T) { + req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) + require.NoError(t, err) + req.Header.Add("Authorization", "bearer "+token.AccessToken) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) + }) - origPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &origPayload)) + t.Run("refreshing new refresh token should work", func(t *testing.T) { + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1]) - require.NoError(t, err) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + }) - refreshedPayload := map[string]interface{}{} - require.NoError(t, json.Unmarshal(body, &refreshedPayload)) + t.Run("should call refresh token hook if configured", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} + expectedSubject := "foo" + + exceptKeys := []string{ + "session.kid", + "session.id_token.expires_at", + "session.id_token.headers.extra.kid", + "session.id_token.id_token_claims.iat", + "session.id_token.id_token_claims.exp", + "session.id_token.id_token_claims.rat", + "session.id_token.id_token_claims.auth_time", + } + + if hookType == "legacy" { + var hookReq hydraoauth2.RefreshTokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.Equal(t, hookReq.Subject, expectedSubject) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) + require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } else { + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Request) + require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) + + snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) + } + + claims := map[string]interface{}{ + "hooked": hookType, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } - if tc.checkExpiry { - assert.NotEqual(t, refreshedPayload["exp"], origPayload["exp"]) - assert.NotEqual(t, refreshedPayload["iat"], origPayload["iat"]) - assert.NotEqual(t, refreshedPayload["nbf"], origPayload["nbf"]) + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyTokenHook, nil) + }) } - assert.NotEqual(t, refreshedPayload["jti"], origPayload["jti"]) - assert.Equal(t, refreshedPayload["client_id"], origPayload["client_id"]) - }) - require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - t.Run("old token should no longer be usable", func(t *testing.T) { - req, err := http.NewRequest("GET", ts.URL+"/userinfo", nil) + body, err := io.ReadAll(res.Body) require.NoError(t, err) - req.Header.Add("Authorization", "bearer "+token.AccessToken) - res, err := http.DefaultClient.Do(req) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) + + idTokenBody, err := x.DecodeSegment( + strings.Split( + gjson.GetBytes(body, "id_token").String(), + ".", + )[1], + ) require.NoError(t, err) - assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode) - }) - t.Run("refreshing new refresh token should work", func(t *testing.T) { + require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) + } + + origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - body, err := io.ReadAll(res.Body) + body, err = io.ReadAll(res.Body) require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) - }) - - t.Run("should call refresh token hook if configured", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - - expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} - expectedSubject := "foo" - - exceptKeys := []string{ - "session.kid", - "session.id_token.expires_at", - "session.id_token.headers.extra.kid", - "session.id_token.id_token_claims.iat", - "session.id_token.id_token_claims.exp", - "session.id_token.id_token_claims.rat", - "session.id_token.id_token_claims.auth_time", - } - - if hookType == "legacy" { - var hookReq hydraoauth2.RefreshTokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.Equal(t, hookReq.Subject, expectedSubject) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) - require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } else { - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Request) - require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) - require.Equal(t, hookReq.Request.Payload, map[string][]string{"grant_type": {"refresh_token"}}) - - snapshotx.SnapshotT(t, hookReq, snapshotx.ExceptPaths(exceptKeys...)) - } - - claims := map[string]interface{}{ - "hooked": hookType, - } - - hookResp := hydraoauth2.TokenHookResponse{ - Session: flow.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, - } - - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) - - idTokenBody, err := x.DecodeSegment( - strings.Split( - gjson.GetBytes(body, "id_token").String(), - ".", - )[1], - ) - require.NoError(t, err) - - require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err = io.ReadAll(res.Body) - require.NoError(t, err) - - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) - } - } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusForbidden, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) - } + + refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - - t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { - run := func(hookType string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() - - if hookType == "legacy" { - conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) - } else { - conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) - } - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "The token hook target responded with an error.", errBody.Description) - } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } - t.Run("hook=legacy", run("legacy")) - t.Run("hook=new", run("new")) - }) - t.Run("refreshing old token should no longer work", func(t *testing.T) { - res, err := testRefresh(t, token, ts.URL, false) + res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) + + t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) + } else { + conf.MustSet(ctx, config.KeyTokenHook, hs.URL) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) + } - t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { res, err := testRefresh(t, &refreshedToken, ts.URL, false) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, res.StatusCode) - }) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - t.Run("duplicate code exchange fails", func(t *testing.T) { - token, err := oauthConfig.Exchange(context.TODO(), code) - require.Error(t, err) - require.Nil(t, token) - }) + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) + }) - code = "" - }) - } + t.Run("refreshing old token should no longer work", func(t *testing.T) { + res, err := testRefresh(t, token, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("attempt to refresh old token should revoke new token", func(t *testing.T) { + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("duplicate code exchange fails", func(t *testing.T) { + token, err := oauthConfig.Exchange(context.TODO(), code) + require.Error(t, err) + require.Nil(t, token) + }) + + code = "" }) } }) @@ -2183,6 +2213,7 @@ func newOAuth2Client( return c, &oauth2.Config{ ClientID: c.GetID(), ClientSecret: secret, + RedirectURL: callbackURL, Endpoint: oauth2.Endpoint{ AuthURL: reg.Config().OAuth2AuthURL(ctx).String(), TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index b67b6ae17ed..071bbe05501 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -10,6 +10,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/dbal" + "go.opentelemetry.io/otel/attribute" "net/url" "strings" "time" @@ -60,7 +63,8 @@ type ( } OAuth2RefreshTable struct { OAuth2RequestSQL - FirstUsedAt sql.NullTime `db:"first_used_at"` + FirstUsedAt sql.NullTime `db:"first_used_at"` + AccessTokenSignature sqlxx.NullString `db:"access_token_signature"` } ) @@ -452,7 +456,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) } -func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { +func (p *Persister) RotateRefreshToken(ctx context.Context, refreshTokenSignature string) (requestID string, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) @@ -536,6 +540,140 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } +func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("refresh_signature", refreshSignature), + attribute.String("network_id", p.NetworkID(ctx).String()), + attribute.String("grace_period", period.String()), + )) + defer otelx.End(span, &err) + + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var tokensToRevoke []OAuth2RefreshTable + if err := c. + Select("access_token_signature"). + Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)). + Limit(500). + All(&tokensToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + } + + return nil +} + +func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature string) (fosite.Requester, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRotatedTokens") + defer otelx.End(span, &err) + + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + + if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { + return p.gracefulRefreshRotation(ctx, c, requestID, refreshSignature, gracePeriod) + } + + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + + _, err := c.Where("signature = ? AND nid = ? AND active", refreshSignature, p.NetworkID(ctx)).UpdateQuery(&OAuth2RefreshTable{ + OAuth2RequestSQL: OAuth2RequestSQL{ + Active: false, + }, + FirstUsedAt: sql.NullTime{ + Time: time.Now().UTC().Round(time.Millisecond), + Valid: true, + }, + }, "active", "first_used_at") + return sqlcon.HandleError(err) + }); err != nil { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err + } + + return nil + + /* #nosec G201 table is static */ + + /* + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + + p.Transaction(ctx, func(ctx context.Context, c *sql.Tx) error { + type tokens []revokingRefreshToken + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var t tokens + if err := p.Connection(ctx).Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)).Limit(500).All(&t); err != nil { + return sqlcon.HandleError(err) + } + + } else { + + } + + }) + + if err := p.Connection(ctx).RawQuery(` + SELECT access_token_signature, signature, first_used_at + FROM hydra_oauth2_refresh + WHERE request_id=? AND nid = ? AND active + ORDER BY signature LIMIT 500 + `).All(&tokens); err != nil { + return err + } + + p.Connection(ctx).Where("signature IN (?)", _).Limit(500).UpdateQuery(&OAuth2RequestSQL{ + + Table: sqlTableRefresh, + }, + "active", "first_used_at") + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE signature in = (?) + LIMIT 500 + `).All(&t) + + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE request_id=? AND nid = ? AND active + RETURNING access_token_signature + LIMIT 500 + `).All(&t) + + // mysql: + // "GET access_token_signature, id (?) WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + // "UPDATE ... SET ..." + + // others: + // "UPDATE refresh SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + + // "UPDATE access SET active=false WHERE request_id=? AND signature IN (?) LIMIT 500" + */ + /* #nosec G201 table is static */ +} + func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) @@ -544,7 +682,7 @@ func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id s return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active LIMIT 500", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), id, p.NetworkID(ctx), ). From 21e9a9d3a4e27f6792ba2fcf61578a5fc0be7f85 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 29 Nov 2024 10:16:57 +0100 Subject: [PATCH 5/5] chore: synchronize workspaces --- persistence/sql/persister_oauth2.go | 231 ++++++++++++++-------------- 1 file changed, 119 insertions(+), 112 deletions(-) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 071bbe05501..091d64aab6a 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -456,7 +456,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) } -func (p *Persister) RotateRefreshToken(ctx context.Context, refreshTokenSignature string) (requestID string, err error) { +func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) @@ -540,6 +540,85 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } +func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") + defer otelx.End(span, &err) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) +} + +func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken") + defer otelx.End(span, &err) + return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) +} + +func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { + /* #nosec G201 table is static */ + // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age + requestMaxExpire := time.Now().Add(-lifespan) + if requestMaxExpire.Before(notAfter) { + notAfter = requestMaxExpire + } + + totalDeletedCount := 0 + for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { + d := batchSize + if limit-totalDeletedCount < batchSize { + d = limit - totalDeletedCount + } + // Delete in batches + // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery + deletedRecords, err = p.Connection(ctx).RawQuery( + fmt.Sprintf(`DELETE FROM %s WHERE signature in ( + SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s + )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), + notAfter, + p.NetworkID(ctx), + ).ExecWithCount() + totalDeletedCount += deletedRecords + + if err != nil { + break + } + p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) + } + p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) + return sqlcon.HandleError(err) +} + +func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) +} + +func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) +} + +func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens") + defer otelx.End(span, &err) + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), + ) +} + +// ---- + func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", trace.WithAttributes( @@ -550,6 +629,8 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti )) defer otelx.End(span, &err) + c := p.Connection(ctx) + if p.conn.Dialect.Name() == dbal.DriverMySQL { // MySQL does not support returning values from an update query, so we need to do two queries. var tokensToRevoke []OAuth2RefreshTable @@ -566,47 +647,50 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti return nil } -func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature string) (fosite.Requester, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRotatedTokens") - defer otelx.End(span, &err) +func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string, refreshSignature string) (err error) { + c := p.Connection(ctx) + now := time.Now().UTC().Round(time.Millisecond) - err = p.QueryWithNetwork(ctx). - Where("request_id=?", id). - Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(fosite.ErrNotFound) + // Remove the rotated access token + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err } - if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { - return p.gracefulRefreshRotation(ctx, c, requestID, refreshSignature, gracePeriod) - } + // Disable the rotated refresh token. + _, err = c. + Where( + "signature = ? AND nid = ? AND active", + refreshSignature, + p.NetworkID(ctx), + ). + UpdateQuery(&OAuth2RefreshTable{ + OAuth2RequestSQL: OAuth2RequestSQL{Active: false}, + FirstUsedAt: sql.NullTime{Time: now, Valid: true}, + }, "active", "first_used_at") + return sqlcon.HandleError(err) +} - if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { - return err - } +func handleRetryError(err error) error { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return nil +} - _, err := c.Where("signature = ? AND nid = ? AND active", refreshSignature, p.NetworkID(ctx)).UpdateQuery(&OAuth2RefreshTable{ - OAuth2RequestSQL: OAuth2RequestSQL{ - Active: false, - }, - FirstUsedAt: sql.NullTime{ - Time: time.Now().UTC().Round(time.Millisecond), - Valid: true, - }, - }, "active", "first_used_at") - return sqlcon.HandleError(err) - }); err != nil { - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } - return err +func (p *Persister) RotateRefreshToken(ctx context.Context, refreshSignature string) (requestID string, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") + defer otelx.End(span, &err) + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { + return handleRetryError(p.gracefulRefreshRotation(ctx, refreshSignature, gracePeriod)) } - return nil + return handleRetryError(p.strictRefreshRotation(ctx, refreshSignature)) + + return requestID, nil /* #nosec G201 table is static */ @@ -673,80 +757,3 @@ func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature st */ /* #nosec G201 table is static */ } - -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active LIMIT 500", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), - id, - p.NetworkID(ctx), - ). - Exec(), - ) -} - -func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken") - defer otelx.End(span, &err) - return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) -} - -func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { - /* #nosec G201 table is static */ - // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age - requestMaxExpire := time.Now().Add(-lifespan) - if requestMaxExpire.Before(notAfter) { - notAfter = requestMaxExpire - } - - totalDeletedCount := 0 - for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { - d := batchSize - if limit-totalDeletedCount < batchSize { - d = limit - totalDeletedCount - } - // Delete in batches - // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery - deletedRecords, err = p.Connection(ctx).RawQuery( - fmt.Sprintf(`DELETE FROM %s WHERE signature in ( - SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s - )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), - notAfter, - p.NetworkID(ctx), - ).ExecWithCount() - totalDeletedCount += deletedRecords - - if err != nil { - break - } - p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) - } - p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) - return sqlcon.HandleError(err) -} - -func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) -} - -func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) -} - -func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens") - defer otelx.End(span, &err) - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), - ) -}