Skip to content
This repository has been archived by the owner on May 21, 2022. It is now read-only.

Native support for key rotation in verifications #372

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string {
return m.Name
}

// Implements the Verify method from SigningMethod
// For this verify method, key must be an ecdsa.PublicKey struct
// Implements the Verify method from SigningMethod.
// For this verify method, key must be in types of either *ecdsa.PublicKey or
// []*ecdsa.PublicKey (for rotation keys).
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -64,35 +65,42 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
return err
}

// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
}

if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
}

r := big.NewInt(0).SetBytes(sig[:m.KeySize])
s := big.NewInt(0).SetBytes(sig[m.KeySize:])

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
} else {
return ErrECDSAVerification
// Get the keys
var keys []*ecdsa.PublicKey
switch v := key.(type) {
case *ecdsa.PublicKey:
keys = append(keys, v)
case []*ecdsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

var lastErr error
for _, ecdsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
}
lastErr = ErrECDSAVerification
}
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
47 changes: 47 additions & 0 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,53 @@ func TestECDSAVerify(t *testing.T) {
}
}

func TestECDSAVerifyKeyRotation(t *testing.T) {
targetName := "Basic ES256"
for _, data := range ecdsaTestData {
if data.name != targetName {
continue
}

var err error

key, _ := ioutil.ReadFile("test/ec256-public.pem")
var ecdsaKey *ecdsa.PublicKey
if ecdsaKey, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

key, _ = ioutil.ReadFile("test/ec384-public.pem")
var invalidKey1 *ecdsa.PublicKey
if invalidKey1, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

key, _ = ioutil.ReadFile("test/ec512-public.pem")
var invalidKey2 *ecdsa.PublicKey
if invalidKey2, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, ecdsaKey, invalidKey2})
if err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, invalidKey2})
if err == nil {
t.Errorf("[%v] Key list with only invalid keys passed validation", data.name)
}
}
}

func TestECDSASign(t *testing.T) {
for _, data := range ecdsaTestData {
var err error
Expand Down
39 changes: 24 additions & 15 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string {

// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}

// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
Expand All @@ -64,17 +58,32 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac
return ErrHashUnavailable
}

// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid
// Verify the keys are the right types
var keys [][]byte
switch v := key.(type) {
case []byte:
keys = append(keys, v)
case [][]byte:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

// No validation errors. Signature is good.
return nil
var lastErr error
for _, keyBytes := range keys {
// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if hmac.Equal(sig, hasher.Sum(nil)) {
// No validation errors. Signature is good.
return nil
}
lastErr = ErrSignatureInvalid
}
return lastErr
}

// Implements the Sign method from SigningMethod for this signing method.
Expand Down
29 changes: 29 additions & 0 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,35 @@ func TestHMACVerify(t *testing.T) {
}
}

func TestHMACVerifyKeyRotation(t *testing.T) {
invalidKey1 := []byte("foo")
invalidKey2 := []byte("bar")
for _, data := range hmacTestData {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, hmacTestKey, invalidKey2})
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, invalidKey2})
if err == nil {
t.Errorf("[%v] Key list with only invalid keys passed validation", data.name)
}
}
}

func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData {
if data.valid {
Expand Down
37 changes: 25 additions & 12 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string {
}

// Implements the Verify method from SigningMethod
// For this signing method, must be an *rsa.PublicKey structure.
// For this signing method, key must be in types of either *rsa.PublicKey or
// []*rsa.PublicKey (for rotation keys).
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -55,22 +56,34 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
return err
}

var rsaKey *rsa.PublicKey
var ok bool
if !m.Hash.Available() {
return ErrHashUnavailable
}

if rsaKey, ok = key.(*rsa.PublicKey); !ok {
var keys []*rsa.PublicKey
switch v := key.(type) {
case *rsa.PublicKey:
keys = append(keys, v)
case []*rsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
var lastErr error
for _, rsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
lastErr = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
if lastErr == nil {
return nil
}
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
44 changes: 28 additions & 16 deletions rsa_pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func init() {
}

// Implements the Verify method from SigningMethod
// For this verify method, key must be an rsa.PublicKey struct
// For this verify method, key must be in the types of either *rsa.PublicKey or
// []*rsa.PublicKey (for rotation keys).
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -90,27 +91,38 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
return err
}

var rsaKey *rsa.PublicKey
switch k := key.(type) {
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

opts := m.Options
if m.VerifyOptions != nil {
opts = m.VerifyOptions
var keys []*rsa.PublicKey
switch v := key.(type) {
case *rsa.PublicKey:
keys = append(keys, v)
case []*rsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
var lastErr error
for _, rsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

opts := m.Options
if m.VerifyOptions != nil {
opts = m.VerifyOptions
}

lastErr = rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
if lastErr == nil {
return nil
}
}
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
12 changes: 12 additions & 0 deletions rsa_pss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ func TestRSAPSSVerify(t *testing.T) {
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{rsaPSSKey})
if err != nil {
t.Errorf("[%v] Error while verifying key list: %v", data.name, err)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
}
}

Expand Down
13 changes: 13 additions & 0 deletions rsa_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"crypto/rsa"
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"strings"
Expand Down Expand Up @@ -59,6 +60,18 @@ func TestRSAVerify(t *testing.T) {
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{key})
if err != nil {
t.Errorf("[%v] Error while verifying key list: %v", data.name, err)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
}
}

Expand Down