Skip to content

Commit

Permalink
chore: synchronize workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Nov 29, 2024
1 parent 01fba86 commit 21e9a9d
Showing 1 changed file with 119 additions and 112 deletions.
231 changes: 119 additions & 112 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str
return p.createSession(ctx, signature, requester, sqlTableRefresh, requester.GetSession().GetExpiresAt(fosite.RefreshToken).UTC())
}

func (p *Persister) RotateRefreshToken(ctx context.Context, refreshTokenSignature string) (requestID string, err error) {
func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession")
defer otelx.End(span, &err)

Expand Down Expand Up @@ -540,6 +540,85 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro
return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
}

func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod")
defer otelx.End(span, &err)

/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()),
id,
p.NetworkID(ctx),
).
Exec(),
)
}

func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken")
defer otelx.End(span, &err)
return p.deleteSessionByRequestID(ctx, id, sqlTableAccess)
}

func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) {
/* #nosec G201 table is static */
// The value of notAfter should be the minimum between input parameter and token max expire based on its configured age
requestMaxExpire := time.Now().Add(-lifespan)
if requestMaxExpire.Before(notAfter) {
notAfter = requestMaxExpire
}

totalDeletedCount := 0
for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; {
d := batchSize
if limit-totalDeletedCount < batchSize {
d = limit - totalDeletedCount
}
// Delete in batches
// The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery
deletedRecords, err = p.Connection(ctx).RawQuery(
fmt.Sprintf(`DELETE FROM %s WHERE signature in (
SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s
)`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d),
notAfter,
p.NetworkID(ctx),
).ExecWithCount()
totalDeletedCount += deletedRecords

if err != nil {
break
}
p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit)
}
p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount)
return sqlcon.HandleError(err)
}

func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens")
defer otelx.End(span, &err)
return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx))
}

func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens")
defer otelx.End(span, &err)
return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx))
}

func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens")
defer otelx.End(span, &err)
/* #nosec G201 table is static */
return sqlcon.HandleError(
p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}),
)
}

// ----

func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connection, requestID string, refreshSignature string, period time.Duration) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.gracefulRefreshRotation",
trace.WithAttributes(
Expand All @@ -550,6 +629,8 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti
))
defer otelx.End(span, &err)

c := p.Connection(ctx)

if p.conn.Dialect.Name() == dbal.DriverMySQL {
// MySQL does not support returning values from an update query, so we need to do two queries.
var tokensToRevoke []OAuth2RefreshTable
Expand All @@ -566,47 +647,50 @@ func (p *Persister) gracefulRefreshRotation(ctx context.Context, c *pop.Connecti
return nil
}

func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature string) (fosite.Requester, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRotatedTokens")
defer otelx.End(span, &err)
func (p *Persister) strictRefreshRotation(ctx context.Context, requestID string, refreshSignature string) (err error) {
c := p.Connection(ctx)
now := time.Now().UTC().Round(time.Millisecond)

err = p.QueryWithNetwork(ctx).
Where("request_id=?", id).
Delete(&OAuth2RequestSQL{Table: sqlTableAccess})
if errors.Is(err, sql.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
// Remove the rotated access token
if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil {
return err
}

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 {
return p.gracefulRefreshRotation(ctx, c, requestID, refreshSignature, gracePeriod)
}
// Disable the rotated refresh token.
_, err = c.
Where(
"signature = ? AND nid = ? AND active",
refreshSignature,
p.NetworkID(ctx),
).
UpdateQuery(&OAuth2RefreshTable{
OAuth2RequestSQL: OAuth2RequestSQL{Active: false},
FirstUsedAt: sql.NullTime{Time: now, Valid: true},
}, "active", "first_used_at")
return sqlcon.HandleError(err)
}

if err := p.deleteSessionByRequestID(ctx, requestID, sqlTableAccess); err != nil {
return err
}
func handleRetryError(err error) error {
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return fosite.ErrSerializationFailure.WithWrap(err)
}
if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return nil
}

_, err := c.Where("signature = ? AND nid = ? AND active", refreshSignature, p.NetworkID(ctx)).UpdateQuery(&OAuth2RefreshTable{
OAuth2RequestSQL: OAuth2RequestSQL{
Active: false,
},
FirstUsedAt: sql.NullTime{
Time: time.Now().UTC().Round(time.Millisecond),
Valid: true,
},
}, "active", "first_used_at")
return sqlcon.HandleError(err)
}); err != nil {
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return fosite.ErrSerializationFailure.WithWrap(err)
}
if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
func (p *Persister) RotateRefreshToken(ctx context.Context, refreshSignature string) (requestID string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RotateRefreshToken")
defer otelx.End(span, &err)

if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 {
return handleRetryError(p.gracefulRefreshRotation(ctx, refreshSignature, gracePeriod))
}

return nil
return handleRetryError(p.strictRefreshRotation(ctx, refreshSignature))

return requestID, nil

/* #nosec G201 table is static */

Expand Down Expand Up @@ -673,80 +757,3 @@ func (p *Persister) RevokeRotatedTokens(ctx context.Context, refreshSignature st
*/
/* #nosec G201 table is static */
}

func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod")
defer otelx.End(span, &err)

/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active LIMIT 500", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()),
id,
p.NetworkID(ctx),
).
Exec(),
)
}

func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeAccessToken")
defer otelx.End(span, &err)
return p.deleteSessionByRequestID(ctx, id, sqlTableAccess)
}

func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) {
/* #nosec G201 table is static */
// The value of notAfter should be the minimum between input parameter and token max expire based on its configured age
requestMaxExpire := time.Now().Add(-lifespan)
if requestMaxExpire.Before(notAfter) {
notAfter = requestMaxExpire
}

totalDeletedCount := 0
for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; {
d := batchSize
if limit-totalDeletedCount < batchSize {
d = limit - totalDeletedCount
}
// Delete in batches
// The outer SELECT is necessary because our version of MySQL doesn't yet support 'LIMIT & IN/ALL/ANY/SOME subquery
deletedRecords, err = p.Connection(ctx).RawQuery(
fmt.Sprintf(`DELETE FROM %s WHERE signature in (
SELECT signature FROM (SELECT signature FROM %s hoa WHERE requested_at < ? and nid = ? ORDER BY requested_at LIMIT %d ) as s
)`, OAuth2RequestSQL{Table: table}.TableName(), OAuth2RequestSQL{Table: table}.TableName(), d),
notAfter,
p.NetworkID(ctx),
).ExecWithCount()
totalDeletedCount += deletedRecords

if err != nil {
break
}
p.l.Debugf("Flushing tokens...: %d/%d", totalDeletedCount, limit)
}
p.l.Debugf("Flush Refresh Tokens flushed_records: %d", totalDeletedCount)
return sqlcon.HandleError(err)
}

func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveAccessTokens")
defer otelx.End(span, &err)
return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx))
}

func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveRefreshTokens")
defer otelx.End(span, &err)
return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx))
}

func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokens")
defer otelx.End(span, &err)
/* #nosec G201 table is static */
return sqlcon.HandleError(
p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}),
)
}

0 comments on commit 21e9a9d

Please sign in to comment.