diff --git a/go.mod b/go.mod index 26adbfc96c..809bb14386 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,9 @@ replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.0 // // This is needed until we release the next version of the master branch, as that branch already contains the redirect URI validation fix, which // may be breaking for some users. -replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4 +//replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241125094724-b6468e644902 + +replace github.com/ory/fosite => ../fosite require ( github.com/ThalesIgnite/crypto11 v1.2.5 diff --git a/go.sum b/go.sum index 5f7a2ba760..2e7f7a6fb0 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,6 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4 h1:1pEVHGC+Dx2xMPMgpRgG3lyejyK8iU9KKfSnLowLYd8= -github.com/ory/fosite v0.47.1-0.20241101073333-eab241e153a4/go.mod h1:AZyn1jrABUaGN12RHcWorRLbqLn52gTdHaIYY81m5J0= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 2a48a52f8e..292988b77c 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -16,7 +16,6 @@ import ( "github.com/ory/hydra/v2/internal" . "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" - "github.com/ory/x/networkx" "github.com/ory/x/sqlcon/dockertest" ) @@ -72,14 +71,11 @@ func TestManagers(t *testing.T) { require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{ID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers. - for k, store := range registries { - net := &networkx.Network{} - require.NoError(t, store.Persister().Connection(context.Background()).First(net)) - store.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) - store.WithContextualizer(&contextx.Static{NID: net.ID, C: store.Config().Source(ctx)}) - TestHelperRunner(t, store, k) + for k, _ := range registries { + reg := internal.NewRegistrySQLFromURL(t, registries[k].Config().DSN(), true, &contextx.Default{}) + reg.Config().MustSet(ctx, config.KeyEncryptSessionData, tc.enableSessionEncrypted) + TestHelperRunner(t, reg, k) } }) - } } diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 0d89e14ac9..19407ccf56 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/ory/hydra/v2/jwk" "io" "net/http" "net/http/httptest" @@ -62,6 +63,95 @@ type clientCreator interface { CreateClient(context.Context, *client.Client) error } +func getAuthorizeCode(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } + + state := uuid.New() + resp, err := c.Get(conf.AuthCodeURL(state, params...)) + require.NoError(t, err) + defer resp.Body.Close() + + q := resp.Request.URL.Query() + require.EqualValues(t, state, q.Get("state")) + return q.Get("code"), resp +} + +func acceptLoginHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b + } + } + + v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } +} + +func acceptConsentHandler(t *testing.T, c *client.Client, adminClient *hydra.APIClient, reg driver.Registry, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(r.Context()).String()) + if checkRequestPayload != nil { + checkRequestPayload(rr) + } + + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, + }, + }). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } +} + // TestAuthCodeWithDefaultStrategy runs proper integration tests against in-memory and database connectors, specifically // we test: // @@ -76,8 +166,14 @@ type clientCreator interface { // - [x] If `id_token_hint` is handled properly // - [x] What happens if `id_token_hint` does not match the value from the handled authentication request ("accept login") func TestAuthCodeWithDefaultStrategy(t *testing.T) { + setupRegistries(t) + ctx := context.Background() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg := internal.NewRegistrySQLFromURL(t, registries["cockroach"].Config().DSN(), true, &contextx.Default{}) + + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OpenIDConnectKeyName)) + require.NoError(t, jwk.EnsureAsymmetricKeypairExists(ctx, reg, string(jose.ES256), x.OAuth2JWTKeyName)) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) @@ -87,94 +183,6 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} - getAuthorizeCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { - if c == nil { - c = testhelpers.NewEmptyJarClient(t) - } - - state := uuid.New() - resp, err := c.Get(conf.AuthCodeURL(state, params...)) - require.NoError(t, err) - defer resp.Body.Close() - - q := resp.Request.URL.Query() - require.EqualValues(t, state, q.Get("state")) - return q.Get("code"), resp - } - - acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.Contains(t, rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - - acceptBody := hydra.AcceptOAuth2LoginRequest{ - Subject: subject, - Remember: pointerx.Ptr(!rr.Skip), - Acr: pointerx.Ptr("1"), - Amr: []string{"pwd"}, - Context: map[string]interface{}{"context": "bar"}, - } - if checkRequestPayload != nil { - if b := checkRequestPayload(rr); b != nil { - acceptBody = *b - } - } - - v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(context.Background()). - LoginChallenge(r.URL.Query().Get("login_challenge")). - AcceptOAuth2LoginRequest(acceptBody). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - } - } - - acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() - require.NoError(t, err) - - assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) - assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) - assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) - assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) - assert.EqualValues(t, c.RedirectURIs, rr.Client.RedirectUris) - assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) - assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) - assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) - assert.Contains(t, *rr.RequestUrl, reg.Config().OAuth2AuthURL(ctx).String()) - if checkRequestPayload != nil { - checkRequestPayload(rr) - } - - assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) - v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(context.Background()). - ConsentChallenge(r.URL.Query().Get("consent_challenge")). - AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ - GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), - GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, - Session: &hydra.AcceptOAuth2ConsentRequestSession{ - AccessToken: map[string]interface{}{"foo": "bar"}, - IdToken: map[string]interface{}{"bar": "baz", "email": "foo@bar.com"}, - }, - }). - Execute() - require.NoError(t, err) - require.NotEmpty(t, v.RedirectTo) - http.Redirect(w, r, v.RedirectTo, http.StatusFound) - } - } - assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) @@ -266,8 +274,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { run := func(t *testing.T, strategy string) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) @@ -335,15 +343,21 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=graceful token rotation", func(t *testing.T) { run := func(t *testing.T, strategy string) { - reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, time.Second*3) + reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, time.Minute) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) t.Cleanup(func() { reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + reg.Config().MustSet(ctx, config.KeyRefreshTokenLifespan, nil) + reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, nil) }) c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) issueTokens := func(t *testing.T) *oauth2.Token { @@ -362,7 +376,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { require.NotEmpty(t, token.RefreshToken) - token.Expiry = token.Expiry.Add(-time.Hour * 24) + token.Expiry = time.Now().Add(-time.Hour * 24) iat := time.Now() refreshedToken, err := conf.TokenSource(context.Background(), token).Token() require.NoError(t, err) @@ -378,40 +392,63 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { return refreshedToken } - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { - start := time.Now() + t.Run("followup=graceful token refresh can handle concurrent refreshing", func(t *testing.T) { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) - token := issueTokens(t) - var first, second *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - first = refreshTokens(t, token) - }) + var wg sync.WaitGroup + refresh := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = time.Now().Add(-time.Hour * 24) + tt, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + return tt + } - t.Run("followup=second refresh", func(t *testing.T) { - second = refreshTokens(t, token) - }) + refreshes := make([]*oauth2.Token, 5) + for k := range refreshes { + wg.Add(1) + //time.Sleep(time.Millisecond * 100) + go func(k int) { + defer wg.Done() + t.Logf("Refreshing token %d", k) + refreshes[k] = refresh(t, token) + }(k) + } - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + wg.Wait() - i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + for k, actual := range refreshes { + refresh := actual - i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + require.NotEmpty(t, refresh.RefreshToken) + require.NotEmpty(t, refresh.AccessToken) + require.NotEmpty(t, refresh.Extra("id_token")) - i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i := testhelpers.IntrospectToken(t, conf, refresh.AccessToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) - i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, refresh.RefreshToken, adminTS) + assert.Truef(t, i.Get("active").Bool(), "token %d:\ntoken:%+v\nresult:%s", k, refresh, i) + } + + t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + require.NoError(t, err) + + _, err = conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, actual := range refreshes { + i := testhelpers.IntrospectToken(t, conf, actual.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "token %d: %s", k, i) + } }) }) - t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + t.Run("followup=graceful token refresh with reuse detection", func(t *testing.T) { start := time.Now() token := issueTokens(t) @@ -424,82 +461,134 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { second = refreshTokens(t, token) }) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { - err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) - require.NoError(t, err) - - _, err = conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) - + t.Run("followup=all resulting tokens are valid", func(t *testing.T) { i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + assert.True(t, i.Get("active").Bool(), "%s", i) }) - }) - t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { - start := time.Now() - token := issueTokens(t) - var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token - t.Run("followup=first refresh", func(t *testing.T) { - a1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=second refresh", func(t *testing.T) { - b1Refresh = refreshTokens(t, token) - }) - - t.Run("followup=first refresh from first refresh", func(t *testing.T) { - a2RefreshA = refreshTokens(t, a1Refresh) - }) - - t.Run("followup=second refresh from first refresh", func(t *testing.T) { - a2RefreshB = refreshTokens(t, a1Refresh) - }) + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(4 * time.Second))) - t.Run("followup=first refresh from second refresh", func(t *testing.T) { - b2RefreshA = refreshTokens(t, b1Refresh) - }) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // Fetching the token again should cause an error because we are no longer in the grace period. + token.Expiry = time.Now().Add(-time.Hour * 24) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - t.Run("followup=second refresh from second refresh", func(t *testing.T) { - b2RefreshB = refreshTokens(t, b1Refresh) - }) + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - // Sleep until the grace period is over - time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) - t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { - _, err := conf.TokenSource(context.Background(), token).Token() - assert.Error(t, err) + result, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err, "%+v", result) - for k, token := range []*oauth2.Token{ - a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) - i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) - assert.False(t, i.Get("active").Bool(), "%s", i) - }) - } + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) }) }) + // + //t.Run("followup=graceful token refresh with reuse detection with consent revocation", func(t *testing.T) { + // token := issueTokens(t) + // var first, second *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // first = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // second = refreshTokens(t, token) + // }) + // + // t.Run("followup=revoking consent revokes all tokens", func(t *testing.T) { + // err := reg.ConsentManager().RevokeSubjectConsentSession(context.Background(), subject) + // require.NoError(t, err) + // + // _, err = conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + //}) + // + //t.Run("followup=graceful refresh tokens with multiple nested branches belong to the same request", func(t *testing.T) { + // start := time.Now() + // token := issueTokens(t) + // var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + // t.Run("followup=first refresh", func(t *testing.T) { + // a1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=second refresh", func(t *testing.T) { + // b1Refresh = refreshTokens(t, token) + // }) + // + // t.Run("followup=first refresh from first refresh", func(t *testing.T) { + // a2RefreshA = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=second refresh from first refresh", func(t *testing.T) { + // a2RefreshB = refreshTokens(t, a1Refresh) + // }) + // + // t.Run("followup=first refresh from second refresh", func(t *testing.T) { + // b2RefreshA = refreshTokens(t, b1Refresh) + // }) + // + // t.Run("followup=second refresh from second refresh", func(t *testing.T) { + // b2RefreshB = refreshTokens(t, b1Refresh) + // }) + // + // // Sleep until the grace period is over + // time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + // t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + // _, err := conf.TokenSource(context.Background(), token).Token() + // assert.Error(t, err) + // + // for k, token := range []*oauth2.Token{ + // a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + // } { + // t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + // i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // + // i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + // assert.False(t, i.Get("active").Bool(), "%s", i) + // }) + // } + // }) + //}) } t.Run("strategy=jwt", func(t *testing.T) { @@ -698,8 +787,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) otherClient, _ := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) withWrongClientAfterLogin := &http.Client{ @@ -813,13 +902,13 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=perform flow with prompt=registration", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - regUI := httptest.NewServer(acceptLoginHandler(t, c, subject, nil)) + regUI := httptest.NewServer(acceptLoginHandler(t, c, adminClient, reg, subject, nil)) t.Cleanup(regUI.Close) reg.Config().MustSet(ctx, config.KeyRegistrationURL, regUI.URL) testhelpers.NewLoginConsentUI(t, reg.Config(), nil, - acceptConsentHandler(t, c, subject, nil)) + acceptConsentHandler(t, c, adminClient, reg, subject, nil)) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("prompt", "registration"), @@ -836,12 +925,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -865,8 +954,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=respects client token lifespan configuration", func(t *testing.T) { run := func(t *testing.T, strategy string, c *client.Client, conf *oauth2.Config, expectedLifespans client.Lifespans) { testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) @@ -967,8 +1056,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=use remember feature and prompt=none", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -984,12 +1073,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { // Reset UI to check for skip values testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.True(t, r.Skip) require.EqualValues(t, subject, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { require.True(t, *r.Skip) require.EqualValues(t, subject, *r.Subject) }), @@ -1038,12 +1127,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("followup=passes and resets skip when prompt=login", func(t *testing.T) { testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { require.True(t, *r.Skip) require.EqualValues(t, subject, *r.Subject) }), @@ -1065,8 +1154,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=should fail if prompt=none but no auth session given", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -1079,12 +1168,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=requires re-authentication when id_token_hint is set to a user 'patrik-neu' but the session is 'aeneas-rekkas' and then fails because the user id from the log in endpoint is 'aeneas-rekkas'", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { require.False(t, r.Skip) require.Empty(t, r.Subject) return nil }), - acceptConsentHandler(t, c, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) oc := testhelpers.NewEmptyJarClient(t) @@ -1103,11 +1192,11 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=should not cause issues if max_age is very low and consent takes a long time", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { time.Sleep(time.Second * 2) return nil }), - acceptConsentHandler(t, c, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil) @@ -1117,8 +1206,8 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { t.Run("case=ensure consistent claims returned for userinfo", func(t *testing.T) { c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, nil), - acceptConsentHandler(t, c, subject, nil), + acceptLoginHandler(t, c, adminClient, reg, subject, nil), + acceptConsentHandler(t, c, adminClient, reg, subject, nil), ) code, _ := getAuthorizeCode(t, conf, nil) @@ -1203,12 +1292,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1252,12 +1341,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1292,12 +1381,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1332,12 +1421,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { expectAud := "https://api.ory.sh/" c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + acceptLoginHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { assert.False(t, r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) return nil }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + acceptConsentHandler(t, c, adminClient, reg, subject, func(r *hydra.OAuth2ConsentRequest) { assert.False(t, *r.Skip) assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) })) @@ -1497,7 +1586,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { consentStrategy := &consentMock{} router := x.NewRouterPublic() ts := httptest.NewServer(router) - defer ts.Close() + t.Cleanup(ts.Close) reg.WithConsentStrategy(consentStrategy) handler := reg.OAuth2Handler() @@ -1511,7 +1600,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) var mutex sync.Mutex - require.NoError(t, reg.ClientManager().CreateClient(context.TODO(), &client.Client{ + require.NoError(t, reg.ClientManager().CreateClient(ctx, &client.Client{ ID: "app-client", Secret: "secret", RedirectURIs: []string{ts.URL + "/callback"}, @@ -1873,10 +1962,14 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + }) } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Cleanup(func() { + conf.MustSet(ctx, config.KeyTokenHook, nil) + }) } res, err := testRefresh(t, &refreshedToken, ts.URL, false) @@ -1915,10 +2008,10 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) @@ -1950,10 +2043,10 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } res, err := testRefresh(t, &refreshedToken, ts.URL, false) @@ -1980,10 +2073,10 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } res, err := testRefresh(t, &refreshedToken, ts.URL, false) @@ -2010,10 +2103,10 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { if hookType == "legacy" { conf.MustSet(ctx, config.KeyRefreshTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyRefreshTokenHook, nil) }) } else { conf.MustSet(ctx, config.KeyTokenHook, hs.URL) - defer conf.MustSet(ctx, config.KeyTokenHook, nil) + t.Cleanup(func() { conf.MustSet(ctx, config.KeyTokenHook, nil) }) } res, err := testRefresh(t, &refreshedToken, ts.URL, false) @@ -2120,6 +2213,7 @@ func newOAuth2Client( return c, &oauth2.Config{ ClientID: c.GetID(), ClientSecret: secret, + RedirectURL: callbackURL, Endpoint: oauth2.Endpoint{ AuthURL: reg.Config().OAuth2AuthURL(ctx).String(), TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), diff --git a/oauth2/oauth2_refresh_token_test.go b/oauth2/oauth2_refresh_token_test.go index 849fae0646..ffabb0dd2a 100644 --- a/oauth2/oauth2_refresh_token_test.go +++ b/oauth2/oauth2_refresh_token_test.go @@ -172,8 +172,10 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { "RETRY_WRITE_TOO_OLD", // refresh token reuse detection "token_inactive", + // Failed to refresh token because of multiple concurrent requests using the same token which is not allowed. + "multiple concurrent requests", } { - if strings.Contains(e.DebugField, errSubstr) { + if strings.Contains(e.DebugField+e.HintField, errSubstr) { matched = true break } diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 083e67ac5d..091d64aab6 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -10,6 +10,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/dbal" + "go.opentelemetry.io/otel/attribute" "net/url" "strings" "time" @@ -60,7 +63,8 @@ type ( } OAuth2RefreshTable struct { OAuth2RequestSQL - FirstUsedAt sql.NullTime `db:"first_used_at"` + FirstUsedAt sql.NullTime `db:"first_used_at"` + AccessTokenSignature sqlxx.NullString `db:"access_token_signature"` } ) @@ -254,7 +258,7 @@ func (p *Persister) createSession(ctx context.Context, signature string, request } if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } else if err != nil { return err } @@ -293,7 +297,7 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri return errorsx.WithStack(fosite.ErrNotFound) } if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } return err } @@ -310,7 +314,7 @@ func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, tab } if err := sqlcon.HandleError(err); err != nil { if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) @@ -426,7 +430,7 @@ func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature stri } } if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + return fosite.ErrSerializationFailure.WithWrap(err) } return err } @@ -612,3 +616,144 @@ func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (er p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), ) } + +// ---- + +func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", + trace.WithAttributes( + attribute.String("request_id", requestID), + attribute.String("refresh_signature", refreshSignature), + attribute.String("network_id", p.NetworkID(ctx).String()), + attribute.String("grace_period", period.String()), + )) + defer otelx.End(span, &err) + + c := p.Connection(ctx) + + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var tokensToRevoke []OAuth2RefreshTable + if err := c. + Select("access_token_signature"). + Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)). + Limit(500). + All(&tokensToRevoke); err != nil { + return sqlcon.HandleError(err) + } + + } + + return nil +} + +func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string, refreshSignature string) (err error) { + c := p.Connection(ctx) + now := time.Now().UTC().Round(time.Millisecond) + + // Remove the rotated access token + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err + } + + // Disable the rotated refresh token. + _, err = c. + Where( + "signature = ? AND nid = ? AND active", + refreshSignature, + p.NetworkID(ctx), + ). + UpdateQuery(&OAuth2RefreshTable{ + OAuth2RequestSQL: OAuth2RequestSQL{Active: false}, + FirstUsedAt: sql.NullTime{Time: now, Valid: true}, + }, "active", "first_used_at") + return sqlcon.HandleError(err) +} + +func handleRetryError(err error) error { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return nil +} + +func (p *Persister) RotateRefreshToken(ctx context.Context, refreshSignature string) (requestID string, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") + defer otelx.End(span, &err) + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { + return handleRetryError(p.gracefulRefreshRotation(ctx, refreshSignature, gracePeriod)) + } + + return handleRetryError(p.strictRefreshRotation(ctx, refreshSignature)) + + return requestID, nil + + /* #nosec G201 table is static */ + + /* + err = p.QueryWithNetwork(ctx). + Where("request_id=?", id). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + + p.Transaction(ctx, func(ctx context.Context, c *sql.Tx) error { + type tokens []revokingRefreshToken + if p.conn.Dialect.Name() == dbal.DriverMySQL { + // MySQL does not support returning values from an update query, so we need to do two queries. + var t tokens + if err := p.Connection(ctx).Where("request_id=? AND nid = ? AND active", id, p.NetworkID(ctx)).Limit(500).All(&t); err != nil { + return sqlcon.HandleError(err) + } + + } else { + + } + + }) + + if err := p.Connection(ctx).RawQuery(` + SELECT access_token_signature, signature, first_used_at + FROM hydra_oauth2_refresh + WHERE request_id=? AND nid = ? AND active + ORDER BY signature LIMIT 500 + `).All(&tokens); err != nil { + return err + } + + p.Connection(ctx).Where("signature IN (?)", _).Limit(500).UpdateQuery(&OAuth2RequestSQL{ + + Table: sqlTableRefresh, + }, + "active", "first_used_at") + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE signature in = (?) + LIMIT 500 + `).All(&t) + + p.Connection(ctx).RawQuery(` + UPDATE hydra_oauth2_refresh + SET active=false, first_used_at = CURRENT_TIMESTAMP + WHERE request_id=? AND nid = ? AND active + RETURNING access_token_signature + LIMIT 500 + `).All(&t) + + // mysql: + // "GET access_token_signature, id (?) WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + // "UPDATE ... SET ..." + + // others: + // "UPDATE refresh SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active A RETURNING access_token_signature ORDER BY signature LIMIT 500" + + // "UPDATE access SET active=false WHERE request_id=? AND signature IN (?) LIMIT 500" + */ + /* #nosec G201 table is static */ +}