Skip to content

Commit

Permalink
Refactor RedisDB's ReadPage function to use a PageToken struct for cu…
Browse files Browse the repository at this point in the history
…rsor and offset encoding, and update readAllKeys to support an offset parameter
  • Loading branch information
Radu Ifrim committed Jul 27, 2023
1 parent ff12b6e commit 0b28b1b
Showing 1 changed file with 75 additions and 36 deletions.
111 changes: 75 additions & 36 deletions pkg/storage/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package storage

import (
"context"
"encoding/base64"
"fmt"
"strconv"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/goccy/go-json"
"github.com/pkg/errors"
"github.com/redis/go-redis/extra/redisotel/v9"
goredislib "github.com/redis/go-redis/v9"
Expand All @@ -31,24 +32,59 @@ type RedisDB struct {
}

func (b *RedisDB) ReadPage(ctx context.Context, namespace string, pageToken string, pageSize int) (map[string][]byte, string, error) {
cursor := uint64(0)
token := new(PageToken)
if pageToken != "" {
var err error
cursor, err = strconv.ParseUint(pageToken, 10, 64)
token, err = parseToken(pageToken)
if err != nil {
return nil, "", errors.Wrap(err, "parsing page token")
return nil, "", err
}
}

keys, nextCursor, err := readAllKeys(ctx, namespace, b, pageSize, cursor)
keys, nextCursor, offsetFromCursor, err := readAllKeys(ctx, namespace, b, pageSize, token.Cursor, token.OffsetFromCursor)
if err != nil {
return nil, "", err
}
results, err := readAll(ctx, namespace, keys, b)
if err != nil {
return nil, "", err
}
return results, nextCursor, nil
nextPageToken := PageToken{
Cursor: nextCursor,
OffsetFromCursor: offsetFromCursor,
}
encodedToken, err := encodeToken(nextPageToken)
if err != nil {
return nil, "", err
}
return results, encodedToken, nil
}

type PageToken struct {
Cursor uint64
OffsetFromCursor int
}

func parseToken(pageToken string) (*PageToken, error) {
pageTokenData, err := base64.RawURLEncoding.DecodeString(pageToken)
if err != nil {
return nil, errors.Wrap(err, "decoding page token")
}

var token PageToken
if err := json.Unmarshal(pageTokenData, &token); err != nil {
return nil, errors.Wrap(err, "unmarshalling page token data")
}

return &token, nil
}

func encodeToken(token PageToken) (string, error) {
data, err := json.Marshal(token)
if err != nil {
return "", errors.Wrap(err, "marshalling page token")
}
return base64.RawURLEncoding.EncodeToString(data), nil
}

var _ ServiceStorage = (*RedisDB)(nil)
Expand Down Expand Up @@ -234,7 +270,7 @@ func (b *RedisDB) Read(ctx context.Context, namespace, key string) ([]byte, erro
func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map[string][]byte, error) {
namespacePrefix := getRedisKey(namespace, prefix)

keys, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0, 0)
if err != nil {
return nil, errors.Wrap(err, "read all keys")
}
Expand All @@ -243,7 +279,7 @@ func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map
}

func (b *RedisDB) ReadAll(ctx context.Context, namespace string) (map[string][]byte, error) {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return nil, errors.Wrap(err, "read all keys")
}
Expand Down Expand Up @@ -280,7 +316,7 @@ func readAll(ctx context.Context, namespace string, keys []string, b *RedisDB) (
}

func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string, error) {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return nil, err
}
Expand All @@ -300,43 +336,46 @@ func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string,

// NOTE: When passing pageSize == -1, **all** items are returns. Exercise caution regarding memory limits. Always
// prefer to set the pageSize.
func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64) ([]string, string, error) {

var allKeys []string
func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64, offset int) ([]string, uint64, int, error) {

var nextCursor uint64
var err error
var keys []string
scanCount := RedisScanBatchSize
if pageSize != -1 {
scanCount = min(RedisScanBatchSize, pageSize)
}
// Scan keys starting at cursor until the end or until we have enough keys
var scannedKeys []string
var err error
nextCursor := cursor

for {
keys, nextCursor, err = b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Result()
scannedKeys, nextCursor, err = b.db.Scan(ctx, nextCursor, namespace+"*", int64(RedisScanBatchSize)).Result()
if err != nil {
return nil, "", errors.Wrap(err, "scan error")
return keys, 0, 0, err
}
// Append keys one by one to ensure we don't exceed the page size
for _, key := range keys {
if pageSize != -1 && len(allKeys) >= pageSize {
break

if len(scannedKeys) == 0 {
break
}

// Apply offset
if offset > 0 {
if offset >= len(scannedKeys) {
// Offset past end of results
offset -= len(scannedKeys)
continue
}
allKeys = append(allKeys, key)

scannedKeys = scannedKeys[offset:]
offset = 0
}
// If we have enough keys or we reached the end, break
if nextCursor == 0 || (pageSize != -1 && len(allKeys) >= pageSize) {

// Append scanned keys
keys = append(keys, scannedKeys...)

// Break if we have enough keys
if len(keys) >= pageSize && pageSize != -1 {
keys = keys[:pageSize]
break
}

cursor = nextCursor
}

var nextCursorToReturn string
if nextCursor != 0 {
nextCursorToReturn = strconv.FormatUint(nextCursor, 10)
}
return allKeys, nextCursorToReturn, nil
return keys, nextCursor, offset, nil
}

func min(l int, r int) int {

Check failure on line 381 in pkg/storage/redis.go

View workflow job for this annotation

GitHub Actions / lint

func `min` is unused (unused)
Expand Down Expand Up @@ -365,7 +404,7 @@ func (b *RedisDB) Delete(ctx context.Context, namespace, key string) error {
}

func (b *RedisDB) DeleteNamespace(ctx context.Context, namespace string) error {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return errors.Wrap(err, "read all keys")
}
Expand Down

0 comments on commit 0b28b1b

Please sign in to comment.