From 7499b9cb7ab6e468dfc212c21379db2db4ee4336 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Tue, 15 Mar 2022 16:12:45 -0700 Subject: [PATCH] go: Use JWKS for key rotation Instead of trying all possible pub keys one-by-one, build a map of kid (fingerprint) to pub key and use the kid header in jwt to select the key to use. If either the kid header is absent or we don't have any key with matching fingerprint, fallback to the first (usually current) key from the versioned secret. This also allows us to switch from reddit's forked version of jwt back to the upstream version. --- go.mod | 3 +- go.sum | 6 +- lib/go/edgecontext/token.go | 2 +- lib/go/edgecontext/validator.go | 89 +++++++-- .../edgecontext/validator_middleware_test.go | 181 ++++++++++++++++++ lib/go/edgecontext/validator_test.go | 32 ++++ 6 files changed, 289 insertions(+), 24 deletions(-) create mode 100644 lib/go/edgecontext/validator_middleware_test.go diff --git a/go.mod b/go.mod index 93d849a..601cec8 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/apache/thrift v0.14.1 github.com/gofrs/uuid v3.2.0+incompatible + github.com/golang-jwt/jwt/v4 v4.3.0 github.com/reddit/baseplate.go v0.8.0 - github.com/reddit/jwt-go/v3 v3.2.2 + golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 ) diff --git a/go.sum b/go.sum index 307094a..f0254a6 100644 --- a/go.sum +++ b/go.sum @@ -32,7 +32,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -69,6 +68,8 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -155,8 +156,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/reddit/baseplate.go v0.8.0 h1:sr3gyaeDYKJUBEmLY4xjT/n//rAj9CsN+0ZuTcud2Ng= github.com/reddit/baseplate.go v0.8.0/go.mod h1:4+EkX/w2lRCODOwFsIOXZni8/H7Vvz9gIesxYTlgchI= -github.com/reddit/jwt-go/v3 v3.2.2 h1:aAQBB/BEQCT6sxety7YUQ+8PZVZHUI15pJZQ9K0LUf4= -github.com/reddit/jwt-go/v3 v3.2.2/go.mod h1:FHWZE8ije8OtrVX1+upW5jBqKF8MirOcWgO6v6VE2Ak= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -211,6 +210,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= diff --git a/lib/go/edgecontext/token.go b/lib/go/edgecontext/token.go index 6508041..0faf3e8 100644 --- a/lib/go/edgecontext/token.go +++ b/lib/go/edgecontext/token.go @@ -1,8 +1,8 @@ package edgecontext import ( + "github.com/golang-jwt/jwt/v4" "github.com/reddit/baseplate.go/timebp" - jwt "github.com/reddit/jwt-go/v3" ) // AuthenticationToken defines the json format of the authentication token. diff --git a/lib/go/edgecontext/validator.go b/lib/go/edgecontext/validator.go index f90c741..062aa3e 100644 --- a/lib/go/edgecontext/validator.go +++ b/lib/go/edgecontext/validator.go @@ -6,17 +6,38 @@ import ( "errors" "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/reddit/baseplate.go/log" "github.com/reddit/baseplate.go/secrets" - jwt "github.com/reddit/jwt-go/v3" + "golang.org/x/crypto/ssh" ) -type keysType = []*rsa.PublicKey +type keysType struct { + // map of kid -> pub key. + m map[string]*rsa.PublicKey + + // when either kid header does not exist in the jwt token, + // or the kid is not present in the map, + // we fallback to the first (usually current) key. + first *rsa.PublicKey +} + +func (kt *keysType) getKey(kid string) *rsa.PublicKey { + if key := kt.m[kid]; key != nil { + return key + } + return kt.first +} const ( authenticationPubKeySecretPath = "secret/authentication/public-key" jwtAlg = "RS256" ) +// JWTHeaderKeyID is the JWT header for the key id, +// as defined in RFC 7517 section 4.5. +const JWTHeaderKeyID = "kid" + // ErrNoPublicKeysLoaded is an error returned by ValidateToken indicates that // the function is called before any public keys are loaded from secrets. var ErrNoPublicKeysLoaded = errors.New("edgecontext.ValidateToken: no public keys loaded") @@ -28,7 +49,7 @@ var ErrEmptyToken = errors.New("edgecontext.ValidateToken: empty JWT token") // ValidateToken parses and validates a jwt token, and return the decoded // AuthenticationToken. func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) { - keys, ok := impl.keysValue.Load().(keysType) + keys, ok := impl.keysValue.Load().(*keysType) if !ok { // This would only happen when all previous middleware parsing failed. return nil, ErrNoPublicKeysLoaded @@ -48,8 +69,9 @@ func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) { tok, err := jwt.ParseWithClaims( token, &AuthenticationToken{}, - func(_ *jwt.Token) (interface{}, error) { - return keys, nil + func(jt *jwt.Token) (interface{}, error) { + kid, _ := jt.Header[JWTHeaderKeyID].(string) + return keys.getKey(kid), nil }, ) if err != nil { @@ -85,26 +107,55 @@ func (impl *Impl) validatorMiddleware(next secrets.SecretHandlerFunc) secrets.Se return } - all := versioned.GetAll() - keys := make(keysType, 0, len(all)) - for i, v := range all { - key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v)) - if err != nil { - impl.logger.Log(context.Background(), fmt.Sprintf( - "Failed to parse key #%d: %v", + keys := parseVersionedKeys(context.Background(), versioned, impl.logger) + if keys != nil { + impl.keysValue.Store(keys) + } + } +} + +func parseVersionedKeys(ctx context.Context, versioned secrets.VersionedSecret, logger log.Wrapper) *keysType { + all := versioned.GetAll() + keys := &keysType{ + m: make(map[string]*rsa.PublicKey, len(all)), + } + for i, v := range all { + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v)) + if err != nil { + logger.Log(ctx, fmt.Sprintf( + "Failed to parse key #%d: %v", + i, + err, + )) + } else { + if keys.first == nil { + keys.first = key + } + if fingerprint, err := RSAPublicKeyFingerprint(key); err != nil { + logger.Log(ctx, fmt.Sprintf( + "Failed to get fingerprint of key #%d: %v", i, err, )) } else { - keys = append(keys, key) + keys.m[fingerprint] = key } } + } + if keys.first == nil { + logger.Log(ctx, "No valid keys in secrets store.") + return nil + } + return keys +} - if len(keys) == 0 { - impl.logger.Log(context.Background(), "No valid keys in secrets store.") - return - } - - impl.keysValue.Store(keys) +// RSAPublicKeyFingerprint calculates the fingerprint of an RSA public key, +// using ssh.FingerprintSHA256: +// https://pkg.go.dev/golang.org/x/crypto/ssh#FingerprintSHA256 +func RSAPublicKeyFingerprint(pubKey *rsa.PublicKey) (string, error) { + key, err := ssh.NewPublicKey(pubKey) + if err != nil { + return "", err } + return ssh.FingerprintSHA256(key), nil } diff --git a/lib/go/edgecontext/validator_middleware_test.go b/lib/go/edgecontext/validator_middleware_test.go new file mode 100644 index 0000000..af44a65 --- /dev/null +++ b/lib/go/edgecontext/validator_middleware_test.go @@ -0,0 +1,181 @@ +package edgecontext + +import ( + "context" + "testing" + + "github.com/reddit/baseplate.go/log" + "github.com/reddit/baseplate.go/secrets" +) + +const ( + validKey1 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +-----END PUBLIC KEY-----` + + fingerprint1 = "SHA256:lZ0hkWRsDpapeBu2ekX9WY2oYInHwdRaXTwtBecDicI" + + validKey2 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAycU1W/hMRWNLkaJPEwWg +j36URuSaRTV0BEvY+L0nRseCnEdlIsj8LCI+ydk3HlJqj3QicuCP9U0W5JAP4PYB +Xs+dV/J38fqdYfI1myXRG2wU5USziF3OC3YYZIXiPe41IltP7LSUmyRO/F6jAcUj +ZmRP2sxhIjY/77nQbx1F3ZMF2i91CRyaIfyd2pC8pwA4VElBTZaP9j3xXEsA8VIX +F/PSVcDsm3GoxVkwQbJTr54GedsRMoex574rvt8iujiNQ7Cb0uXWFIfnlD1thnne +4ws5ekuVhT6lq1KDB2z4e/pN2cOEzzSmfJJK1AWS79R4sAO8Fm/8cpWx6MRhlAbv +HwIDAQAB +-----END PUBLIC KEY-----` + + fingerprint2 = "SHA256:EM4Jt7RjoQIPqpRFTadBCQkdzu+G4tq1RWd3f+I6nRg" + + validKey3 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA06q+yHMtXDj3qa3qELcg +bS/48HWbylEi+smx+xa8yupMTMtne6WFvxiS3lU/+TXQj+hdHzwpLj+W24QCON1o +JqxYDLWVJ2YpmrwkU/IDbhoPKfpYchy6Zmg2bnr93FDcvc4oL2/UYaiG+3w8fS+D +BcHug7ILLmY5RnwqzdcYfQ5waX2QCK75kmtB+TBqtS3xAr2m2omdla91YeARSu3O +lVjB6h9QNfbR6KCZRalMWlNGpp0tG0faU9mEescY4zfqt2inQFAr+MuXjJhg0tW8 +kO6LskiW1+SbBlNrJeQDXUjC/vz6/8X1DvDeczd9tqbAxfV57yRjIxkfsDYxehai +6QIDAQAB +-----END PUBLIC KEY-----` + + fingerprint3 = "SHA256:DGsuFb8nHgtg88dwIsTnGL3J8Hx+yCksl0WEBCbm5Zc" + + invalidKey = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +` +) + +func compareUnorderedFingerprints(tb testing.TB, got, want []string) { + tb.Helper() + + tb.Logf("compareUnorderedFingerprints: got %v, want %v", got, want) + + if len(got) != len(want) { + tb.Errorf("len mismatch: got %d, want %d", len(got), len(want)) + } + + for _, s := range got { + var found bool + for _, t := range want { + if s == t { + found = true + break + } + } + if !found { + tb.Errorf("%q in got not found in want", s) + } + } + + for _, t := range want { + var found bool + for _, s := range got { + if s == t { + found = true + break + } + } + if !found { + tb.Errorf("%q in want not found in got", t) + } + } +} + +func TestParseVersionedKeys(t *testing.T) { + for _, c := range []struct { + label string + secret secrets.VersionedSecret + nopLogger bool + expectNil bool + firstFingerprint string + fingerprints []string + }{ + { + label: "all-valid", + secret: secrets.VersionedSecret{ + Current: []byte(validKey1), + Previous: []byte(validKey2), + Next: []byte(validKey3), + }, + firstFingerprint: fingerprint1, + fingerprints: []string{ + fingerprint1, + fingerprint2, + fingerprint3, + }, + }, + { + label: "invalid-current", + secret: secrets.VersionedSecret{ + Current: []byte(invalidKey), + Previous: []byte(validKey2), + Next: []byte(validKey3), + }, + nopLogger: true, + firstFingerprint: fingerprint2, + fingerprints: []string{ + fingerprint2, + fingerprint3, + }, + }, + { + label: "only-current", + secret: secrets.VersionedSecret{ + Current: []byte(validKey1), + }, + firstFingerprint: fingerprint1, + fingerprints: []string{ + fingerprint1, + }, + }, + } { + t.Run(c.label, func(t *testing.T) { + var logger log.Wrapper + if c.nopLogger { + logger = log.NopWrapper + } else { + logger = log.TestWrapper(t) + } + keys := parseVersionedKeys(context.Background(), c.secret, logger) + if c.expectNil { + if keys != nil { + t.Errorf("Expected nil result, got %v", keys) + return + } + } else { + if keys == nil { + t.Error("Unexpected nil result") + return + } + } + fingerprints := make([]string, 0, len(keys.m)) + for k := range keys.m { + fingerprints = append(fingerprints, k) + } + compareUnorderedFingerprints(t, fingerprints, c.fingerprints) + + fingerprint, err := RSAPublicKeyFingerprint(keys.first) + if err != nil { + t.Errorf("Unable to calculate fingerprint of keys.first: %v", err) + } + if fingerprint != c.firstFingerprint { + t.Errorf("keys.first fingerprint got %q, want %q", fingerprint, c.firstFingerprint) + } + }) + } +} diff --git a/lib/go/edgecontext/validator_test.go b/lib/go/edgecontext/validator_test.go index c6a80ba..44aceb7 100644 --- a/lib/go/edgecontext/validator_test.go +++ b/lib/go/edgecontext/validator_test.go @@ -2,6 +2,10 @@ package edgecontext_test import ( "testing" + + "github.com/golang-jwt/jwt/v4" + + "github.com/reddit/edgecontext/lib/go/edgecontext" ) // copied from https://github.com/reddit/edgecontext.py/blob/420e58728ee7085a2f91c5db45df233142b251f9/tests/edge_context_tests.py#L54 @@ -18,3 +22,31 @@ func TestValidToken(t *testing.T) { t.Errorf("subject expected %q, got %q", expected, actual) } } + +const ( + testPubKeyPEM = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +-----END PUBLIC KEY-----` + + expectedFingerprint = "SHA256:lZ0hkWRsDpapeBu2ekX9WY2oYInHwdRaXTwtBecDicI" +) + +func TestFingerprint(t *testing.T) { + pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPubKeyPEM)) + if err != nil { + t.Fatalf("Unable to parse pub key from PEM: %v", err) + } + fingerprint, err := edgecontext.RSAPublicKeyFingerprint(pubKey) + if err != nil { + t.Errorf("Unable to calculate fingerprint from pub key: %v", err) + } + if fingerprint != expectedFingerprint { + t.Errorf("Fingerprint got %q, want %q", fingerprint, expectedFingerprint) + } +}