Skip to content

Commit

Permalink
go: Use JWKS for key rotation
Browse files Browse the repository at this point in the history
Instead of trying all possible pub keys one-by-one, build a map of kid
(fingerprint) to pub key and use the kid header in jwt to select the key
to use. If either the kid header is absent or we don't have any key with
matching fingerprint, fallback to the first (usually current) key from
the versioned secret.

This also allows us to switch from reddit's forked version of jwt back
to the upstream version.
  • Loading branch information
fishy committed Mar 16, 2022
1 parent 4b9526c commit 7499b9c
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 24 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.16
require (
github.com/apache/thrift v0.14.1
github.com/gofrs/uuid v3.2.0+incompatible
github.com/golang-jwt/jwt/v4 v4.3.0
github.com/reddit/baseplate.go v0.8.0
github.com/reddit/jwt-go/v3 v3.2.2
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37
)
6 changes: 3 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
Expand Down Expand Up @@ -69,6 +68,8 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE=
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog=
github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down Expand Up @@ -155,8 +156,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/reddit/baseplate.go v0.8.0 h1:sr3gyaeDYKJUBEmLY4xjT/n//rAj9CsN+0ZuTcud2Ng=
github.com/reddit/baseplate.go v0.8.0/go.mod h1:4+EkX/w2lRCODOwFsIOXZni8/H7Vvz9gIesxYTlgchI=
github.com/reddit/jwt-go/v3 v3.2.2 h1:aAQBB/BEQCT6sxety7YUQ+8PZVZHUI15pJZQ9K0LUf4=
github.com/reddit/jwt-go/v3 v3.2.2/go.mod h1:FHWZE8ije8OtrVX1+upW5jBqKF8MirOcWgO6v6VE2Ak=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
Expand Down Expand Up @@ -211,6 +210,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
Expand Down
2 changes: 1 addition & 1 deletion lib/go/edgecontext/token.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package edgecontext

import (
"github.com/golang-jwt/jwt/v4"
"github.com/reddit/baseplate.go/timebp"
jwt "github.com/reddit/jwt-go/v3"
)

// AuthenticationToken defines the json format of the authentication token.
Expand Down
89 changes: 70 additions & 19 deletions lib/go/edgecontext/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,38 @@ import (
"errors"
"fmt"

"github.com/golang-jwt/jwt/v4"
"github.com/reddit/baseplate.go/log"
"github.com/reddit/baseplate.go/secrets"
jwt "github.com/reddit/jwt-go/v3"
"golang.org/x/crypto/ssh"
)

type keysType = []*rsa.PublicKey
type keysType struct {
// map of kid -> pub key.
m map[string]*rsa.PublicKey

// when either kid header does not exist in the jwt token,
// or the kid is not present in the map,
// we fallback to the first (usually current) key.
first *rsa.PublicKey
}

func (kt *keysType) getKey(kid string) *rsa.PublicKey {
if key := kt.m[kid]; key != nil {
return key
}
return kt.first
}

const (
authenticationPubKeySecretPath = "secret/authentication/public-key"
jwtAlg = "RS256"
)

// JWTHeaderKeyID is the JWT header for the key id,
// as defined in RFC 7517 section 4.5.
const JWTHeaderKeyID = "kid"

// ErrNoPublicKeysLoaded is an error returned by ValidateToken indicates that
// the function is called before any public keys are loaded from secrets.
var ErrNoPublicKeysLoaded = errors.New("edgecontext.ValidateToken: no public keys loaded")
Expand All @@ -28,7 +49,7 @@ var ErrEmptyToken = errors.New("edgecontext.ValidateToken: empty JWT token")
// ValidateToken parses and validates a jwt token, and return the decoded
// AuthenticationToken.
func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) {
keys, ok := impl.keysValue.Load().(keysType)
keys, ok := impl.keysValue.Load().(*keysType)
if !ok {
// This would only happen when all previous middleware parsing failed.
return nil, ErrNoPublicKeysLoaded
Expand All @@ -48,8 +69,9 @@ func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) {
tok, err := jwt.ParseWithClaims(
token,
&AuthenticationToken{},
func(_ *jwt.Token) (interface{}, error) {
return keys, nil
func(jt *jwt.Token) (interface{}, error) {
kid, _ := jt.Header[JWTHeaderKeyID].(string)
return keys.getKey(kid), nil
},
)
if err != nil {
Expand Down Expand Up @@ -85,26 +107,55 @@ func (impl *Impl) validatorMiddleware(next secrets.SecretHandlerFunc) secrets.Se
return
}

all := versioned.GetAll()
keys := make(keysType, 0, len(all))
for i, v := range all {
key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v))
if err != nil {
impl.logger.Log(context.Background(), fmt.Sprintf(
"Failed to parse key #%d: %v",
keys := parseVersionedKeys(context.Background(), versioned, impl.logger)
if keys != nil {
impl.keysValue.Store(keys)
}
}
}

func parseVersionedKeys(ctx context.Context, versioned secrets.VersionedSecret, logger log.Wrapper) *keysType {
all := versioned.GetAll()
keys := &keysType{
m: make(map[string]*rsa.PublicKey, len(all)),
}
for i, v := range all {
key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v))
if err != nil {
logger.Log(ctx, fmt.Sprintf(
"Failed to parse key #%d: %v",
i,
err,
))
} else {
if keys.first == nil {
keys.first = key
}
if fingerprint, err := RSAPublicKeyFingerprint(key); err != nil {
logger.Log(ctx, fmt.Sprintf(
"Failed to get fingerprint of key #%d: %v",
i,
err,
))
} else {
keys = append(keys, key)
keys.m[fingerprint] = key
}
}
}
if keys.first == nil {
logger.Log(ctx, "No valid keys in secrets store.")
return nil
}
return keys
}

if len(keys) == 0 {
impl.logger.Log(context.Background(), "No valid keys in secrets store.")
return
}

impl.keysValue.Store(keys)
// RSAPublicKeyFingerprint calculates the fingerprint of an RSA public key,
// using ssh.FingerprintSHA256:
// https://pkg.go.dev/golang.org/x/crypto/ssh#FingerprintSHA256
func RSAPublicKeyFingerprint(pubKey *rsa.PublicKey) (string, error) {
key, err := ssh.NewPublicKey(pubKey)
if err != nil {
return "", err
}
return ssh.FingerprintSHA256(key), nil
}
181 changes: 181 additions & 0 deletions lib/go/edgecontext/validator_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package edgecontext

import (
"context"
"testing"

"github.com/reddit/baseplate.go/log"
"github.com/reddit/baseplate.go/secrets"
)

const (
validKey1 = `
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB
NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66
Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx
KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H
TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8
tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt
TwIDAQAB
-----END PUBLIC KEY-----`

fingerprint1 = "SHA256:lZ0hkWRsDpapeBu2ekX9WY2oYInHwdRaXTwtBecDicI"

validKey2 = `
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAycU1W/hMRWNLkaJPEwWg
j36URuSaRTV0BEvY+L0nRseCnEdlIsj8LCI+ydk3HlJqj3QicuCP9U0W5JAP4PYB
Xs+dV/J38fqdYfI1myXRG2wU5USziF3OC3YYZIXiPe41IltP7LSUmyRO/F6jAcUj
ZmRP2sxhIjY/77nQbx1F3ZMF2i91CRyaIfyd2pC8pwA4VElBTZaP9j3xXEsA8VIX
F/PSVcDsm3GoxVkwQbJTr54GedsRMoex574rvt8iujiNQ7Cb0uXWFIfnlD1thnne
4ws5ekuVhT6lq1KDB2z4e/pN2cOEzzSmfJJK1AWS79R4sAO8Fm/8cpWx6MRhlAbv
HwIDAQAB
-----END PUBLIC KEY-----`

fingerprint2 = "SHA256:EM4Jt7RjoQIPqpRFTadBCQkdzu+G4tq1RWd3f+I6nRg"

validKey3 = `
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA06q+yHMtXDj3qa3qELcg
bS/48HWbylEi+smx+xa8yupMTMtne6WFvxiS3lU/+TXQj+hdHzwpLj+W24QCON1o
JqxYDLWVJ2YpmrwkU/IDbhoPKfpYchy6Zmg2bnr93FDcvc4oL2/UYaiG+3w8fS+D
BcHug7ILLmY5RnwqzdcYfQ5waX2QCK75kmtB+TBqtS3xAr2m2omdla91YeARSu3O
lVjB6h9QNfbR6KCZRalMWlNGpp0tG0faU9mEescY4zfqt2inQFAr+MuXjJhg0tW8
kO6LskiW1+SbBlNrJeQDXUjC/vz6/8X1DvDeczd9tqbAxfV57yRjIxkfsDYxehai
6QIDAQAB
-----END PUBLIC KEY-----`

fingerprint3 = "SHA256:DGsuFb8nHgtg88dwIsTnGL3J8Hx+yCksl0WEBCbm5Zc"

invalidKey = `
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB
NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66
Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx
KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H
TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8
tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt
TwIDAQAB
`
)

func compareUnorderedFingerprints(tb testing.TB, got, want []string) {
tb.Helper()

tb.Logf("compareUnorderedFingerprints: got %v, want %v", got, want)

if len(got) != len(want) {
tb.Errorf("len mismatch: got %d, want %d", len(got), len(want))
}

for _, s := range got {
var found bool
for _, t := range want {
if s == t {
found = true
break
}
}
if !found {
tb.Errorf("%q in got not found in want", s)
}
}

for _, t := range want {
var found bool
for _, s := range got {
if s == t {
found = true
break
}
}
if !found {
tb.Errorf("%q in want not found in got", t)
}
}
}

func TestParseVersionedKeys(t *testing.T) {
for _, c := range []struct {
label string
secret secrets.VersionedSecret
nopLogger bool
expectNil bool
firstFingerprint string
fingerprints []string
}{
{
label: "all-valid",
secret: secrets.VersionedSecret{
Current: []byte(validKey1),
Previous: []byte(validKey2),
Next: []byte(validKey3),
},
firstFingerprint: fingerprint1,
fingerprints: []string{
fingerprint1,
fingerprint2,
fingerprint3,
},
},
{
label: "invalid-current",
secret: secrets.VersionedSecret{
Current: []byte(invalidKey),
Previous: []byte(validKey2),
Next: []byte(validKey3),
},
nopLogger: true,
firstFingerprint: fingerprint2,
fingerprints: []string{
fingerprint2,
fingerprint3,
},
},
{
label: "only-current",
secret: secrets.VersionedSecret{
Current: []byte(validKey1),
},
firstFingerprint: fingerprint1,
fingerprints: []string{
fingerprint1,
},
},
} {
t.Run(c.label, func(t *testing.T) {
var logger log.Wrapper
if c.nopLogger {
logger = log.NopWrapper
} else {
logger = log.TestWrapper(t)
}
keys := parseVersionedKeys(context.Background(), c.secret, logger)
if c.expectNil {
if keys != nil {
t.Errorf("Expected nil result, got %v", keys)
return
}
} else {
if keys == nil {
t.Error("Unexpected nil result")
return
}
}
fingerprints := make([]string, 0, len(keys.m))
for k := range keys.m {
fingerprints = append(fingerprints, k)
}
compareUnorderedFingerprints(t, fingerprints, c.fingerprints)

fingerprint, err := RSAPublicKeyFingerprint(keys.first)
if err != nil {
t.Errorf("Unable to calculate fingerprint of keys.first: %v", err)
}
if fingerprint != c.firstFingerprint {
t.Errorf("keys.first fingerprint got %q, want %q", fingerprint, c.firstFingerprint)
}
})
}
}
Loading

0 comments on commit 7499b9c

Please sign in to comment.