diff --git a/go.mod b/go.mod index 93d849a..601cec8 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/apache/thrift v0.14.1 github.com/gofrs/uuid v3.2.0+incompatible + github.com/golang-jwt/jwt/v4 v4.3.0 github.com/reddit/baseplate.go v0.8.0 - github.com/reddit/jwt-go/v3 v3.2.2 + golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 ) diff --git a/go.sum b/go.sum index 307094a..f0254a6 100644 --- a/go.sum +++ b/go.sum @@ -32,7 +32,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -69,6 +68,8 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -155,8 +156,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/reddit/baseplate.go v0.8.0 h1:sr3gyaeDYKJUBEmLY4xjT/n//rAj9CsN+0ZuTcud2Ng= github.com/reddit/baseplate.go v0.8.0/go.mod h1:4+EkX/w2lRCODOwFsIOXZni8/H7Vvz9gIesxYTlgchI= -github.com/reddit/jwt-go/v3 v3.2.2 h1:aAQBB/BEQCT6sxety7YUQ+8PZVZHUI15pJZQ9K0LUf4= -github.com/reddit/jwt-go/v3 v3.2.2/go.mod h1:FHWZE8ije8OtrVX1+upW5jBqKF8MirOcWgO6v6VE2Ak= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -211,6 +210,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= diff --git a/lib/go/edgecontext/token.go b/lib/go/edgecontext/token.go index 6508041..0faf3e8 100644 --- a/lib/go/edgecontext/token.go +++ b/lib/go/edgecontext/token.go @@ -1,8 +1,8 @@ package edgecontext import ( + "github.com/golang-jwt/jwt/v4" "github.com/reddit/baseplate.go/timebp" - jwt "github.com/reddit/jwt-go/v3" ) // AuthenticationToken defines the json format of the authentication token. diff --git a/lib/go/edgecontext/validator.go b/lib/go/edgecontext/validator.go index f90c741..062aa3e 100644 --- a/lib/go/edgecontext/validator.go +++ b/lib/go/edgecontext/validator.go @@ -6,17 +6,38 @@ import ( "errors" "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/reddit/baseplate.go/log" "github.com/reddit/baseplate.go/secrets" - jwt "github.com/reddit/jwt-go/v3" + "golang.org/x/crypto/ssh" ) -type keysType = []*rsa.PublicKey +type keysType struct { + // map of kid -> pub key. + m map[string]*rsa.PublicKey + + // when either kid header does not exist in the jwt token, + // or the kid is not present in the map, + // we fallback to the first (usually current) key. + first *rsa.PublicKey +} + +func (kt *keysType) getKey(kid string) *rsa.PublicKey { + if key := kt.m[kid]; key != nil { + return key + } + return kt.first +} const ( authenticationPubKeySecretPath = "secret/authentication/public-key" jwtAlg = "RS256" ) +// JWTHeaderKeyID is the JWT header for the key id, +// as defined in RFC 7517 section 4.5. +const JWTHeaderKeyID = "kid" + // ErrNoPublicKeysLoaded is an error returned by ValidateToken indicates that // the function is called before any public keys are loaded from secrets. var ErrNoPublicKeysLoaded = errors.New("edgecontext.ValidateToken: no public keys loaded") @@ -28,7 +49,7 @@ var ErrEmptyToken = errors.New("edgecontext.ValidateToken: empty JWT token") // ValidateToken parses and validates a jwt token, and return the decoded // AuthenticationToken. func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) { - keys, ok := impl.keysValue.Load().(keysType) + keys, ok := impl.keysValue.Load().(*keysType) if !ok { // This would only happen when all previous middleware parsing failed. return nil, ErrNoPublicKeysLoaded @@ -48,8 +69,9 @@ func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) { tok, err := jwt.ParseWithClaims( token, &AuthenticationToken{}, - func(_ *jwt.Token) (interface{}, error) { - return keys, nil + func(jt *jwt.Token) (interface{}, error) { + kid, _ := jt.Header[JWTHeaderKeyID].(string) + return keys.getKey(kid), nil }, ) if err != nil { @@ -85,26 +107,55 @@ func (impl *Impl) validatorMiddleware(next secrets.SecretHandlerFunc) secrets.Se return } - all := versioned.GetAll() - keys := make(keysType, 0, len(all)) - for i, v := range all { - key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v)) - if err != nil { - impl.logger.Log(context.Background(), fmt.Sprintf( - "Failed to parse key #%d: %v", + keys := parseVersionedKeys(context.Background(), versioned, impl.logger) + if keys != nil { + impl.keysValue.Store(keys) + } + } +} + +func parseVersionedKeys(ctx context.Context, versioned secrets.VersionedSecret, logger log.Wrapper) *keysType { + all := versioned.GetAll() + keys := &keysType{ + m: make(map[string]*rsa.PublicKey, len(all)), + } + for i, v := range all { + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v)) + if err != nil { + logger.Log(ctx, fmt.Sprintf( + "Failed to parse key #%d: %v", + i, + err, + )) + } else { + if keys.first == nil { + keys.first = key + } + if fingerprint, err := RSAPublicKeyFingerprint(key); err != nil { + logger.Log(ctx, fmt.Sprintf( + "Failed to get fingerprint of key #%d: %v", i, err, )) } else { - keys = append(keys, key) + keys.m[fingerprint] = key } } + } + if keys.first == nil { + logger.Log(ctx, "No valid keys in secrets store.") + return nil + } + return keys +} - if len(keys) == 0 { - impl.logger.Log(context.Background(), "No valid keys in secrets store.") - return - } - - impl.keysValue.Store(keys) +// RSAPublicKeyFingerprint calculates the fingerprint of an RSA public key, +// using ssh.FingerprintSHA256: +// https://pkg.go.dev/golang.org/x/crypto/ssh#FingerprintSHA256 +func RSAPublicKeyFingerprint(pubKey *rsa.PublicKey) (string, error) { + key, err := ssh.NewPublicKey(pubKey) + if err != nil { + return "", err } + return ssh.FingerprintSHA256(key), nil } diff --git a/lib/go/edgecontext/validator_middleware_test.go b/lib/go/edgecontext/validator_middleware_test.go new file mode 100644 index 0000000..af44a65 --- /dev/null +++ b/lib/go/edgecontext/validator_middleware_test.go @@ -0,0 +1,181 @@ +package edgecontext + +import ( + "context" + "testing" + + "github.com/reddit/baseplate.go/log" + "github.com/reddit/baseplate.go/secrets" +) + +const ( + validKey1 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +-----END PUBLIC KEY-----` + + fingerprint1 = "SHA256:lZ0hkWRsDpapeBu2ekX9WY2oYInHwdRaXTwtBecDicI" + + validKey2 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAycU1W/hMRWNLkaJPEwWg +j36URuSaRTV0BEvY+L0nRseCnEdlIsj8LCI+ydk3HlJqj3QicuCP9U0W5JAP4PYB +Xs+dV/J38fqdYfI1myXRG2wU5USziF3OC3YYZIXiPe41IltP7LSUmyRO/F6jAcUj +ZmRP2sxhIjY/77nQbx1F3ZMF2i91CRyaIfyd2pC8pwA4VElBTZaP9j3xXEsA8VIX +F/PSVcDsm3GoxVkwQbJTr54GedsRMoex574rvt8iujiNQ7Cb0uXWFIfnlD1thnne +4ws5ekuVhT6lq1KDB2z4e/pN2cOEzzSmfJJK1AWS79R4sAO8Fm/8cpWx6MRhlAbv +HwIDAQAB +-----END PUBLIC KEY-----` + + fingerprint2 = "SHA256:EM4Jt7RjoQIPqpRFTadBCQkdzu+G4tq1RWd3f+I6nRg" + + validKey3 = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA06q+yHMtXDj3qa3qELcg +bS/48HWbylEi+smx+xa8yupMTMtne6WFvxiS3lU/+TXQj+hdHzwpLj+W24QCON1o +JqxYDLWVJ2YpmrwkU/IDbhoPKfpYchy6Zmg2bnr93FDcvc4oL2/UYaiG+3w8fS+D +BcHug7ILLmY5RnwqzdcYfQ5waX2QCK75kmtB+TBqtS3xAr2m2omdla91YeARSu3O +lVjB6h9QNfbR6KCZRalMWlNGpp0tG0faU9mEescY4zfqt2inQFAr+MuXjJhg0tW8 +kO6LskiW1+SbBlNrJeQDXUjC/vz6/8X1DvDeczd9tqbAxfV57yRjIxkfsDYxehai +6QIDAQAB +-----END PUBLIC KEY-----` + + fingerprint3 = "SHA256:DGsuFb8nHgtg88dwIsTnGL3J8Hx+yCksl0WEBCbm5Zc" + + invalidKey = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +` +) + +func compareUnorderedFingerprints(tb testing.TB, got, want []string) { + tb.Helper() + + tb.Logf("compareUnorderedFingerprints: got %v, want %v", got, want) + + if len(got) != len(want) { + tb.Errorf("len mismatch: got %d, want %d", len(got), len(want)) + } + + for _, s := range got { + var found bool + for _, t := range want { + if s == t { + found = true + break + } + } + if !found { + tb.Errorf("%q in got not found in want", s) + } + } + + for _, t := range want { + var found bool + for _, s := range got { + if s == t { + found = true + break + } + } + if !found { + tb.Errorf("%q in want not found in got", t) + } + } +} + +func TestParseVersionedKeys(t *testing.T) { + for _, c := range []struct { + label string + secret secrets.VersionedSecret + nopLogger bool + expectNil bool + firstFingerprint string + fingerprints []string + }{ + { + label: "all-valid", + secret: secrets.VersionedSecret{ + Current: []byte(validKey1), + Previous: []byte(validKey2), + Next: []byte(validKey3), + }, + firstFingerprint: fingerprint1, + fingerprints: []string{ + fingerprint1, + fingerprint2, + fingerprint3, + }, + }, + { + label: "invalid-current", + secret: secrets.VersionedSecret{ + Current: []byte(invalidKey), + Previous: []byte(validKey2), + Next: []byte(validKey3), + }, + nopLogger: true, + firstFingerprint: fingerprint2, + fingerprints: []string{ + fingerprint2, + fingerprint3, + }, + }, + { + label: "only-current", + secret: secrets.VersionedSecret{ + Current: []byte(validKey1), + }, + firstFingerprint: fingerprint1, + fingerprints: []string{ + fingerprint1, + }, + }, + } { + t.Run(c.label, func(t *testing.T) { + var logger log.Wrapper + if c.nopLogger { + logger = log.NopWrapper + } else { + logger = log.TestWrapper(t) + } + keys := parseVersionedKeys(context.Background(), c.secret, logger) + if c.expectNil { + if keys != nil { + t.Errorf("Expected nil result, got %v", keys) + return + } + } else { + if keys == nil { + t.Error("Unexpected nil result") + return + } + } + fingerprints := make([]string, 0, len(keys.m)) + for k := range keys.m { + fingerprints = append(fingerprints, k) + } + compareUnorderedFingerprints(t, fingerprints, c.fingerprints) + + fingerprint, err := RSAPublicKeyFingerprint(keys.first) + if err != nil { + t.Errorf("Unable to calculate fingerprint of keys.first: %v", err) + } + if fingerprint != c.firstFingerprint { + t.Errorf("keys.first fingerprint got %q, want %q", fingerprint, c.firstFingerprint) + } + }) + } +} diff --git a/lib/go/edgecontext/validator_test.go b/lib/go/edgecontext/validator_test.go index c6a80ba..44aceb7 100644 --- a/lib/go/edgecontext/validator_test.go +++ b/lib/go/edgecontext/validator_test.go @@ -2,6 +2,10 @@ package edgecontext_test import ( "testing" + + "github.com/golang-jwt/jwt/v4" + + "github.com/reddit/edgecontext/lib/go/edgecontext" ) // copied from https://github.com/reddit/edgecontext.py/blob/420e58728ee7085a2f91c5db45df233142b251f9/tests/edge_context_tests.py#L54 @@ -18,3 +22,31 @@ func TestValidToken(t *testing.T) { t.Errorf("subject expected %q, got %q", expected, actual) } } + +const ( + testPubKeyPEM = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtzMnDEQPd75QZByogNlB +NY2auyr4sy8UNTDARs79Edq/Jw5tb7ub412mOB61mVrcuFZW6xfmCRt0ILgoaT66 +Tp1RpuEfghD+e7bYZ+Q2pckC1ZaVPIVVf/ZcCZ0tKQHoD8EpyyFINKjCh516VrCx +KuOm2fALPB/xDwDBEdeVJlh5/3HHP2V35scdvDRkvr2qkcvhzoy0+7wUWFRZ2n6H +TFrxMHQoHg0tutAJEkjsMw9xfN7V07c952SHNRZvu80V5EEpnKw/iYKXUjCmoXm8 +tpJv5kXH6XPgfvOirSbTfuo+0VGqVIx9gcomzJ0I5WfGTD22dAxDiRT7q7KZnNgt +TwIDAQAB +-----END PUBLIC KEY-----` + + expectedFingerprint = "SHA256:lZ0hkWRsDpapeBu2ekX9WY2oYInHwdRaXTwtBecDicI" +) + +func TestFingerprint(t *testing.T) { + pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(testPubKeyPEM)) + if err != nil { + t.Fatalf("Unable to parse pub key from PEM: %v", err) + } + fingerprint, err := edgecontext.RSAPublicKeyFingerprint(pubKey) + if err != nil { + t.Errorf("Unable to calculate fingerprint from pub key: %v", err) + } + if fingerprint != expectedFingerprint { + t.Errorf("Fingerprint got %q, want %q", fingerprint, expectedFingerprint) + } +}