From 21e9a9d3a4e27f6792ba2fcf61578a5fc0be7f85 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 29 Nov 2024 10:16:57 +0100 Subject: [PATCH] chore: synchronize workspaces --- persistence/sql/persister_oauth2.go | 231 ++++++++++++++-------------- 1 file changed, 119 insertions(+), 112 deletions(-) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 071bbe0550..091d64aab6 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -456,7 +456,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC()) } -func (p *Persister) RotateRefreshToken(ctx context.Context, refreshTokenSignature string) (requestID string, err error) { +func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) @@ -540,6 +540,85 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } +func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") + defer otelx.End(span, &err) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) +} + +func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken") + defer otelx.End(span, &err) + return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) +} + +func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { + /* #nosec G201 table is static */ + // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age + requestMaxExpire := time.Now().Add(-lifespan) + if requestMaxExpire.Before(notAfter) { + notAfter = requestMaxExpire + } + + totalDeletedCount := 0 + for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { + d := batchSize + if limit-totalDeletedCount < batchSize { + d = limit - totalDeletedCount + } + // Delete in batches + // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery + deletedRecords, err = p.Connection(ctx).RawQuery( + fmt.Sprintf(`DELETE FROM %s WHERE signature in ( + SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s + )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), + notAfter, + p.NetworkID(ctx), + ).ExecWithCount() + totalDeletedCount += deletedRecords + + if err != nil { + break + } + p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) + } + p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) + return sqlcon.HandleError(err) +} + +func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) +} + +func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") + defer otelx.End(span, &err) + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) +} + +func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens") + defer otelx.End(span, &err) + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), + ) +} + +// ---- + func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation", trace.WithAttributes( @@ -550,6 +629,8 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti )) defer otelx.End(span, &err) + c := p.Connection(ctx) + if p.conn.Dialect.Name() == dbal.DriverMySQL { // MySQL does not support returning values from an update query, so we need to do two queries. var tokensToRevoke []OAuth2RefreshTable @@ -566,47 +647,50 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti return nil } -func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature string) (fosite.Requester, error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRotatedTokens") - defer otelx.End(span, &err) +func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string, refreshSignature string) (err error) { + c := p.Connection(ctx) + now := time.Now().UTC().Round(time.Millisecond) - err = p.QueryWithNetwork(ctx). - Where("request_id=?", id). - Delete(&OAuth2RequestSQL{Table: sqlTableAccess}) - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(fosite.ErrNotFound) + // Remove the rotated access token + if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { + return err } - if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { - return p.gracefulRefreshRotation(ctx, c, requestID, refreshSignature, gracePeriod) - } + // Disable the rotated refresh token. + _, err = c. + Where( + "signature = ? AND nid = ? AND active", + refreshSignature, + p.NetworkID(ctx), + ). + UpdateQuery(&OAuth2RefreshTable{ + OAuth2RequestSQL: OAuth2RequestSQL{Active: false}, + FirstUsedAt: sql.NullTime{Time: now, Valid: true}, + }, "active", "first_used_at") + return sqlcon.HandleError(err) +} - if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil { - return err - } +func handleRetryError(err error) error { + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return fosite.ErrSerializationFailure.WithWrap(err) + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return nil +} - _, err := c.Where("signature = ? AND nid = ? AND active", refreshSignature, p.NetworkID(ctx)).UpdateQuery(&OAuth2RefreshTable{ - OAuth2RequestSQL: OAuth2RequestSQL{ - Active: false, - }, - FirstUsedAt: sql.NullTime{ - Time: time.Now().UTC().Round(time.Millisecond), - Valid: true, - }, - }, "active", "first_used_at") - return sqlcon.HandleError(err) - }); err != nil { - if errors.Is(err, sqlcon.ErrConcurrentUpdate) { - return fosite.ErrSerializationFailure.WithWrap(err) - } - if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock - return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } - return err +func (p *Persister) RotateRefreshToken(ctx context.Context, refreshSignature string) (requestID string, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken") + defer otelx.End(span, &err) + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 { + return handleRetryError(p.gracefulRefreshRotation(ctx, refreshSignature, gracePeriod)) } - return nil + return handleRetryError(p.strictRefreshRotation(ctx, refreshSignature)) + + return requestID, nil /* #nosec G201 table is static */ @@ -673,80 +757,3 @@ func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature st */ /* #nosec G201 table is static */ } - -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active LIMIT 500", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), - id, - p.NetworkID(ctx), - ). - Exec(), - ) -} - -func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken") - defer otelx.End(span, &err) - return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) -} - -func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { - /* #nosec G201 table is static */ - // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age - requestMaxExpire := time.Now().Add(-lifespan) - if requestMaxExpire.Before(notAfter) { - notAfter = requestMaxExpire - } - - totalDeletedCount := 0 - for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { - d := batchSize - if limit-totalDeletedCount < batchSize { - d = limit - totalDeletedCount - } - // Delete in batches - // The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery - deletedRecords, err = p.Connection(ctx).RawQuery( - fmt.Sprintf(`DELETE FROM %s WHERE signature in ( - SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s - )`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d), - notAfter, - p.NetworkID(ctx), - ).ExecWithCount() - totalDeletedCount += deletedRecords - - if err != nil { - break - } - p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit) - } - p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount) - return sqlcon.HandleError(err) -} - -func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) -} - -func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens") - defer otelx.End(span, &err) - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) -} - -func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens") - defer otelx.End(span, &err) - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), - ) -}